使用 LangGraph 构建带记忆的对话 Agent

准备工作

为了简单起见,我会在 Google Colab 中运行代码。因为 Google 提供了免费的 API 额度,那么后续尽量都使用 Gemini 模型。目前只需要安装三个依赖:langgraphlangchain-openaiipython。创建对话模型代码如下:

1
2
3
4
5
6
7
8
import os
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
api_key = os.getenv("GEMINI_API_KEY"),
)

这里出于兼容的考虑,我使用了 langchain_openai 来调用 Google Gemini 而不使用 langchain-google-genai

基本概念

在编写实际代码之前,有一些概念需要我们了解:

  1. State(状态)是个共享的数据结构,作用是保存整个应用当前的信息与上下文。
  2. Node(节点)是图(graph)中执行特定任务的独立函数或操作。每个节点会接收输入(通常是当前的状态),对输入进行处理,之后产出输出结果,或者更新状态。
  3. Graph 是一种总体架构,用来规划不同任务(节点)是如何连接与执行的。
  4. Edges(边)是节点(Nodes)之间的连接,用来决定执行流程的走向。
  5. START 是一个虚拟的入口点,用来标记工作流开始的地方。它自己不执行任何具体操作,是整个图(Graph)执行时 “指定的起始位置”。
  6. END 节点用于标识工作流的结束。当流程执行到达这个节点时,整个图(Graph)的运行就会停止,意味着所有预设的流程都已完成。
  7. StateGraph 是一个类,作用是构建和编译图(Graph)的结构

创建 StateGraph

在官方教程中,创建 StateGraph 之前先要创建状态(State):

1
2
3
4
5
6
7
8
from langgraph.graph.message import add_messages
from typing import Annotated

class State(TypedDict):
# Messages have the type "list".
# The `add_messages` function in the annotation defines how this state key should be updated
# (in this case, it appends messages to the list, rather than overwriting them)
messages: Annotated[list, add_messages]

上面代码的含义就是使用一个列表来保存用户和模型之间的对话记录,状态中的 add_messages 方法会将 LLM 的响应消息追加到状态中已有的任何消息之后。还有一种更简单的使用方式是使用 langgraph.graph 中的 MessagesState 来代替上面创建的 State 类。

对于 StateGraph,它负责管理节点、边以及状态,控制组件之间的数据流转。综上,当前的代码可以被简化成如下形式:

1
2
3
from langgraph.graph import MessagesState, StateGraph

workflow = StateGraph(state_schema=MessagesState) # 这里用 MessagesState 来代替上面的 State 类

创建节点

接下去让我们来创建一个节点,负责获取用户输入,并生成模型回复。

1
2
3
def call_model(state: MessagesState):
response = llm.invoke(state["messages"])
return {"messages": [response]}

上述代码中,call_model 方法不是接收一个代表用户输入的字符串参数,而是接收一个 MessagesState 类型的参数,并返回一个包含更新后的消息列表的字典,key 为 messages。这是所有 LangGraph 节点方法的基本模式。

状态的变化

状态的变化是由图的节点驱动的。每个节点函数接收当前的完整 State 作为输入,节点函数执行后,并不需要返回一个全新的、完整的 State 对象,而只需要返回一个包含待更新字段的字典(或部分状态对象),这被称为“状态更新”。LangGraph 框架会接收这个“状态更新”,并根据预设的规则将其应用到当前的 State 上,从而生成一个新的 State 快照。

Reducer 函数定义了如何将节点返回的“状态更新”合并到现有的 State 中。State 中的每一个字段都可以有自己独立的 Reducer。如果在定义 State Schema 时没有为某个字段指定 Reducer,那么它会使用默认行为:覆盖(Override)。对于像聊天记录这样的常见场景,LangGraph 提供了预置的 Reducer,上述代码中的 add_messages 是一个很好的例子,另一个例子如下:

1
2
3
4
5
6
from typing import Annotated, TypedDict
from operator import add

class State(TypedDict):
foo: int
bar: Annotated[list[str], add]

foo 字段使用默认行为,将节点返回的更新值覆盖到 foo 字段中;bar 字段标注了 add,这将节点返回的更新值追加到 bar 字段中。

创建流程

接下去我们来构建一个顺序对话流程。

1
2
3
4
5
6
7
from langgraph.graph import END, START

workflow.add_node("chatbot", call_model) # 添加节点,"chatbot" 为节点名称
workflow.add_edge(START, "chatbot") # 添加起始边
workflow.add_edge("chatbot", END) # 添加结束边

graph = workflow.compile() # 编译

我们可以将上述流程转换为图像查看。

1
2
3
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))

Chat Agent Graph

对话测试

想要测试对话效果,我们还需额外添加一些代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from langchain_core.messages import AIMessage, HumanMessage

def chat(query: str):
input_message: list[HumanMessage] = [HumanMessage(content=query)]
for event in graph.stream({"messages": input_message}, checkpoint_config):
for value in event.values():
if isinstance(value["messages"][-1], AIMessage):
print(value["messages"][-1].content)

while True:
user_input = input("User: ")
if user_input.lower() in ["quit", "exit", "q"]:
break
chat(user_input)

现在运行代码即可与模型进行对话。不过你会发现,当前模型无法记住对话内容。

为什么无法记忆?

想让模型记住之前的对话,就需要在调用模型时把过去的历史对话消息都发送给模型。但是当前在每一轮 While 循环下,当顺序流程结束时,保存在 MessagesState 中的所有对话消息都会被丢弃,新一轮循环开始时,模型无法获取之前的对话消息,所以模型无法记住之前的对话。

添加记忆

我们只需要添加和修改少量代码就可以为上述流程添加记忆,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 添加
from langgraph.checkpoint.memory import InMemorySaver

memory = InMemorySaver()
checkpoint_config = {"configurable": {"thread_id": "111"}} # thread_id 用来区分不同的用户对话

# ....

graph = workflow.compile(checkpointer=memory) # 修改:加上 checkpointer=memory

# ....

def chat(query: str):
input_message: list[HumanMessage] = [HumanMessage(content=query)]
for event in graph.stream({"messages": input_message}, checkpoint_config): # 修改:加上 checkpoint_config
for value in event.values():
if isinstance(value["messages"][-1], AIMessage):
print(value["messages"][-1].content)

封装

让我们把上述代码封装成一个 ChatAgent 类,方便后续使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph, END, START
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.messages import AIMessage, HumanMessage
import uuid
from functools import cached_property


class ChatAgent:
def __init__(self):
self.llm = ChatOpenAI(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
api_key = api_key,
)
self.workflow = StateGraph(state_schema=MessagesState)
self.memory = InMemorySaver()
self.checkpoint_config = {"configurable": {"thread_id": uuid.uuid4().hex}}

def call_model(self, state: MessagesState):
response = self.llm.invoke(state["messages"])
return {"messages": [response]}

@cached_property
def graph(self):
self.workflow.add_node("chatbot", self.call_model)
self.workflow.add_edge(START, "chatbot")
self.workflow.add_edge("chatbot", END)

return self.workflow.compile(checkpointer=self.memory)

def chat(self, query: str):
input_message: list[HumanMessage] = [HumanMessage(content=query)]
for event in self.graph.stream({"messages": input_message}, self.checkpoint_config):
for value in event.values():
if isinstance(value["messages"][-1], AIMessage):
print(value["messages"][-1].content)


if __name__ == "__main__":
agent = ChatAgent()

while True:
user_input = input("User: ")
if user_input.lower() in ["quit", "exit", "q"]:
break
agent.chat(user_input)

参考