타임트리

[LangGraph] Delete Messages 본문

LLM/LangGraph

[LangGraph] Delete Messages

sean_j 2024. 12. 31. 02:38

메세지 삭제 방법

일반적으로 State에는 messages라는 키로 리스트에 메세지 이력을 append 해가며 이력을 관리하게 된다.

때로는 메세지를 삭제할 필요가 있을 수 있다. config별로 대화 내역을 쌓게 되는데, 내용이 너무 길어지면 컨텍스트가 너무 길어질 수 있다. 이외에도 필요없는 이력은 삭제하고 싶을 수 있다.

이를 위해 RemoveMessage라는 reducer를 사용할 수 있다. RemoveMessage에 동일한 ID를 갖는 메세지를 자동으로 삭제해준다.

간단한 웹 서치 그래프 정의

먼저 간단하게 web search를 모방하는 search 함수를 tool로 정의하고 이를 binding한 LLM을 사용하여 그래프를 만들어보자.

 

# web search graph 구축
from typing import Annotated, TypedDict
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition

# checkpoint 저장을 위한 memory 객체 초기화
memory = MemorySaver()


# mimic web search tool
@tool
def search(query: str):
    """Call to surf the web"""
    return "서울의 날씨는 영하 200도 입니다!"


tools = [search]
tool_node = ToolNode(tools)

llm = ChatOpenAI(model_name="gpt-4o-mini")
llm_with_tools = llm.bind_tools(tools=tools)


def should_continue(state: MessagesState):
    """다음에 실행할 노드 반환"""
    last_message = state["messages"][-1]
    # 만약 tool call이 없으면 그래프 끝내기
    if not last_message.tool_calls:
        return END
    # 아니면 계속 진행
    return "search_tool"


# model 호출하는 함수 정의
def call_model(state: MessagesState):
    response = llm_with_tools.invoke(state["messages"])
    # list 반환, add_messages reducer는 기존 list에 더해줌
    return {"messages": [response]}


flow = StateGraph(MessagesState)

flow.add_node("agent", call_model)
flow.add_node("search_tool", tool_node)

flow.add_edge(START, "agent")
flow.add_edge("search_tool", "agent")

flow.add_conditional_edges(
    "agent", should_continue, {"search_tool": "search_tool", END: END}
)


graph = flow.compile(checkpointer=memory)

 

 

이 그래프를 compile할 때, InMemory 객체인 memory를 checkpointer로 지정했기 때문에 기존 대화 내역을 기록하고 있다. 아래를 통해 확인해보자.

config = {"configurable": {"thread_id": "1"}}
input = {"messages": [("user", "안녕하세요~ 제 이름은 Sean입니다.")]}

events = graph.stream(input, config, stream_mode="values")

for event in events:
    event["messages"][-1].pretty_print()
================================ Human Message =================================

안녕하세요~ 제 이름은 Sean입니다.
================================== Ai Message ==================================

안녕하세요, Sean님! 어떻게 도와드릴까요?
# 2번째 대화 추가
input = {"messages": [("user", "제 이름이 뭐라고 했는지 기억하시나요?")]}

events = graph.stream(input, config, stream_mode="values")

for event in events:
    event["messages"][-1].pretty_print()
================================ Human Message =================================

제 이름이 뭐라고 했는지 기억하시나요?
================================== Ai Message ==================================

네, Sean님이라고 하셨습니다! 어떻게 도와드릴까요?

 

해당 내역은 MemorySaver 객체의 get_tuple(config) 메서드를 통해 접근할 수 있다. (또는 현재 그래프의 get_state 메서드로 가져올 수도 있다)

# memory.get_tuple(config).checkpoint["channel_values"]
graph.get_state(config).values["messages"]
[HumanMessage(content='안녕하세요~ 제 이름은 Sean입니다.', ...),
 AIMessage(content='안녕하세요, Sean님! 어떻게 도와드릴까요?',...),
 HumanMessage(content='제 이름이 뭐라고 했는지 기억하시나요?',...),
 AIMessage(content='네, Sean님이라고 하셨습니다! 어떻게 도와드릴까요?',...)]

메세지 수동으로 삭제하기

먼저 메세지를 수동으로 삭제해보자. 메세지 삭제는 RemoveMessage reducer를 통해 삭제할 수 있다. add_message reducer가 State의 값을 바꿔치기하는 것이 아니라, list에 append해주는 것처럼 State의 messages 키의 값에 RemoveMessage(id)로 상태를 업데이트해 주면 값을 바꿔치지 않고 해당 id를 가진 메세지를 삭제해준다.

즉, 수동으로 메세지를 삭제하기 위해서는 update_state 메서드를 호출하면서, values 인자에 RemoveMessage(삭제하고자하는 메세지의 id)를 전달하면 해당 메세지가 삭제된다.

from langchain_core.messages import RemoveMessage

# 첫 번째 ID를 기반으로 제거하고 graph 상태 없데이트
graph.update_state(config, values={"messages": RemoveMessage(id=messages[0].id)})

 

다시 graph의 해당 config의 메세지 목록을 가져와보자.

  • 첫 번째 메세지가 삭제된 것을 확인할 수 있다!
# 현재 thread state 확인
messages = graph.get_state(config).values["messages"]
messages
[AIMessage(content='안녕하세요, Sean님! 어떻게 도와드릴까요?',...),
 HumanMessage(content='제 이름이 뭐라고 했는지 기억하시나요?',...),
 AIMessage(content='네, Sean님이라고 하셨습니다! 어떻게 도와드릴까요?',...)]

메세지를 동적으로 삭제하기

Graph 내부에서 노드를 추가해 해당 노드를 방문할 시 메세지를 삭제하도록 할 수도 있다.
그래프 실행이 종료될 때, 오래된 메시지를 삭제하도록 그래프를 수정해보자 (3개 메시지만 살리고 나머지는 삭제).

delete_messages 라는 Node를 만드는 것!

 

아래처럼, 그래프가 종료되기 전에 delete_messages 노드를 들려 최근 3개의 메세지만 살리고 나머지는 삭제하도록 만들어보자.

 

from typing import Literal
from langchain_core.messages import RemoveMessage


# 메세지 개수가 3개가 넘어갈 경우, 오래된 메세지 삭제해서 최신 3개만 유지
def delete_messages(state):
    messages = state["messages"]
    if len(messages) > 3:
        return {"messages": [RemoveMessage(id=message.id) for message in messages[:-3]]}


# 바로 종료하지 않고 delete_messages를 호출하도록 로직 수정
def should_continue(state: MessagesState) -> Literal["search_tool", "delete_messages"]:
    last_message = state["messages"][-1]
    if not last_message.tool_calls:
        return "delete_messages"
    return "search_tool"


flow = StateGraph(MessagesState)

flow.add_node("agent", call_model)
flow.add_node("search_tool", search)
# delete_messages 노드 추가
flow.add_node("delete_messages", delete_messages)

flow.add_edge(START, "agent")
flow.add_edge("search_tool", "agent")
flow.add_edge("delete_messages", END)

flow.add_conditional_edges(
    "agent",
    should_continue,
    {"delete_messages": "delete_messages", "search_tool": "search_tool"},
)

graph = flow.compile(checkpointer=memory)

 

이제 LLM이 응답을 모두 완성했다면, 바로 END 노드로 가서 그래프가 종료되는 것이 아니라, delete_messages 노드를 거쳐 최근 3개 메세지만 이력으로 갖게 된다.

 

호출을 2번해서 확인해보자!

from langchain_core.messages import HumanMessage

config = {"configurable": {"thread_id": "2"}}
input = {"messages": [HumanMessage(content="안녕하세요, 제 이름은 Sean 입니다!")]}

events = graph.stream(input, config, stream_mode="values") # 노드마다 상태 출력

for event in events:
    print([(message.type, message.content) for message in event["messages"]])
[('human', '안녕하세요, 제 이름은 Sean 입니다!')] 
[('human', '안녕하세요, 제 이름은 Sean 입니다!'), ('ai', '안녕하세요, Sean님! 어떻게 도와드릴까요?')]

 

아래 결과를 살펴보면, ('human', '제 이름이 뭐라고 했었죠?')이 input으로 들어오게 되고, 중간에 agent 노드에서 ('ai', 'Sean님이라고 말씀하셨습니다!')가 추가된다.

 

그리고 마지막으로 가장 첫 번째 메세지가 삭제되어 총 3개의 메세지가 담겨서 반환되는 것을 확인할 수 있다.

input = {"messages": [HumanMessage(content="제 이름이 뭐라고 했었죠?")]}
events = graph.stream(input, config, stream_mode="values") # 노드마다 상태 출력

for event in events:
    print([(message.type, message.content) for message in event["messages"]])
[('human', '안녕하세요, 제 이름은 Sean 입니다!'), ('ai', '안녕하세요, Sean님! 어떻게 도와드릴까요?'), ('human', '제 이름이 뭐라고 했었죠?')] 
[('human', '안녕하세요, 제 이름은 Sean 입니다!'), ('ai', '안녕하세요, Sean님! 어떻게 도와드릴까요?'), ('human', '제 이름이 뭐라고 했었죠?'), ('ai', 'Sean님이라고 말씀하셨습니다!')]
[('ai', '안녕하세요, Sean님! 어떻게 도와드릴까요?'), ('human', '제 이름이 뭐라고 했었죠?'), ('ai', 'Sean님이라고 말씀하셨습니다!')]

 

 

 

---

출처

LangGraph. "How to delete messages". https://langchain-ai.github.io/langgraph/how-tos/memory/delete-messages/

위키독스 - <랭체인LangChain 노트> - LangChain 한국어 튜토리얼🇰🇷  (https://wikidocs.net/book/14314)