타임트리

[LangGrpah] Tool Binding 본문

LLM/LangGraph

[LangGrpah] Tool Binding

sean_j 2024. 12. 22. 03:02

LangGraph에서 Tool을 사용하기 위해서는 ToolNode가 정의되어야 한다


LangChain에서 Tool은 결국 invoke 메서드를 갖는 하나의 클래스로 정의
만약 Custom이 필요한 경우, langchain_core.toolsBaseTool를 상속받는 클래스를 정의하면 됨!

  • 추상 메서드 _run을 정의
  • 클래스 변수 name, description, args_schema 3가지를 정의

LLM에 Tool을 주는 방법은 다음과 같다!

STEP1. Tool 리스트 정의

from utils.custom_tools import TavilySearch

# 검색도구 생성
web_search = TavilySearch(max_results=1)
tools = [tool]

STEP2. llm에 tool 리스트 바인딩

from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

llm_with_tools = llm.bind_tools(tools)

[!Note]
llm_with_tools 처럼 tool을 바인딩한 경우, LLM의 응답은 Tool이 필요할 때 다음과 같이 응답을 뱉어준다.

  1. query="A회사에 재직 중인 Sean에 대해 알려줘." 인 경우 ➡️ Web 검색이 필요하다고 판단!
    1. content = "" 로 반환
    2. addtional_kwargs의 tool_calls 키에 대한 값이 저장
    3. finish_reason은 tool_calls
    4. AIMessage에서 tool_calls로 파싱 가능! (→ response["messages"][-1].tool_calls)
  2. 안녕!인 경우 ➡️ Web 검색이 필요하지 않다고 판단!
    1. 기존 응답처럼 content 필드에 응답이 반환됨!

STEP3. 도구 노드 (Tool Node) 정의

  • 다음으로, 도구가 호출될 경우 실제로 실행할 수 있는 함수를 만들어야 한다!
  • 가장 최근의 메세지를 확인하고 메시지에 tool_calls가 포함되어 있으면 Tool을 호출하는 BasicToolNode를 구현해보자.
  • 여기서는 직접 구현하지만, 나중에는 LangGraph에 빌트인된 ToolNode로 대체할 수 있다.

결국 해당 객체를 __call__할 때,

  1. 가장 최근의 메세지를 가져와 LLM이 뱉어준 Tool과 Args로 실행한 결과를 받아온 뒤
  2. LangChain의 ToolMessage() 객체로 결과를 id에 매핑해서 반환하는 것!

 

우선 응답(response)에서 tool_calls를 파싱해보자.

>>> response.tool_calls
[{'name': 'tavliy_web_search', 
  'args': {'query': 'Sean A회사'}, 
  'id': 'call_yp8RqgCkUPJjhBh77NOQT8fw', 
  'type': 'tool_call'}]

 

그럼 위처럼 파싱해서 name 값으로 Tool 객체를 부르고 여기서 args 인자를 넣어주면 될 것 같다!
그럼 key 값을 통해 Tool 객체를 반환받을 수 있게 dictionary 타입의 tool_list를 만들자.

  • LangChainBaseTool을 상속받았으므로, 클래스 변수 tool.name에는 해당 tool의 이름이 들어있다.
>>> tools_list = {tool.name: tool for tool in tools}
>>> tools_list
{'tavliy_web_search': TavilySearch(client=<tavily.tavily.TavilyClient object at 0x0000016297DC5ED0>, max_results=1)}

 

이제, LLM(llm_with_tools)의 응답을 저장하고 있는 가장 최근 message의tool_calls`로 응답을 파싱해와서, 도구를 호출하고 결과를 저장하자.

  • 이때, 호출 결과로 LangChainToolMessage로 저장하자 (tool_callid를 맞춰주자)
import json
from langchain_core.messages import ToolMessages

class BasicToolNode:
    """Run tools requested in the last AIMessage node"""

    def __init__(self, tools: list) -> None:
        # 도구 리스트
        self.tools_list = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        # 메시지가 존재할 경우 가장 최근 메시지 1개 추출
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        # 도구 호출 결과
        outputs = []
        for tool_call in message.tool_calls:
            # 도구 호출 후 결과 저장
            tool_result = self.tools_list[tool_call["name"]].invoke(tool_call["args"])
            outputs.append(
                # 도구 호출 결과를 메시지로 저장
                ToolMessage(
                    content=json.dumps(
                        tool_result, ensure_ascii=False
                    ),  # 도구 호출 결과를 문자열로 변환
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}

조건부 엣지 정의 (Conditional Edge)

위에서 Tool을 호출하고 결과를 반환하는 Tool Node도 직접 정의했다.
이제 남은 건, LLM이 응답을 보고 도구를 호출할지 말지 결정하는데 이를 반영해줘야 한다!
즉, LLM의 응답을 확인하고 tool_calls

챗봇의 출력에서 tool_calls를 확인하는 route_tools라는 라우터 함수를 정의하자

 

 

route_tools는 다음과 같이 작동하게 만들자.

  1. state로부터 messages를 가져오자. 이때, 빈 리스트가 반환되면 state에 메세지가 없다고 예외 처리하자.
  2. 가져온 AIMessage가 tool_calls 속성을 가지면 tool 노드로, 아니면 end 노드로 분기를 태우자.
def route_tools(state: State) -> State:
    # 가장 최근 메세지 추출하기
    if messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        # staet에 메세지가 없으면 에러
        raise ValueError(f"tool_edge에 대한 입력 상태에서 메시지를 찾을 수 없습니다.\n{state}")

    # AIMessage 내 tool_calls 속성 존재 여부에 따라 분기
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tool"
    else:
        return "end"

그래프 정의

이제 그래프를 정의하기 위한 모든 요소가 갖춰졌다. 아래와 같이 그래프를 작성하자!

from langgraph.graph import StateGraph, START, END

workflow = StateGraph(State)

tool_node = BasicToolNode(tools=tools)

# 노드 추가
workflow.add_node("chat", chat)
workflow.add_node("tool", tool_node)

# 엣지 추가
workflow.add_edge(START, "chat")
workflow.add_conditional_edges(
    source="chat",
    path=route_tools,
    path_map={
        "tool": "tool",
        "end": END
    }
)
# tool에서 chat으로 보내는 엣지 추가
workflow.add_edge("tool", "chat")

# Compile
graph = workflow.compile()

 

시각화를 해보면 원하는대로 결과가 출력됨을 알 수 있다.

from IPython.display import display, Image

display(Image(graph.get_graph().draw_mermaid_png()))

 

stream 메서드로 각 Step의 message를 확인해 볼 수 있다! (stream_mode="values"로, 각 노드의 output state 확인)

input_state = {"messages": "A회사에 다니는 Sean이란 사람에 대해 알려줘."}

for event in graph.stream(input_state, stream_mode="values"):
    for k, v in event.items():
        print(f"\n==============\nSTEP: {k}\n==============\n")
        # display_message_tree(value["messages"][-1])
        print(v[-1])

 

 

---

참고

1. <랭체인LangChain 노트> - LangChain 한국어 튜토리얼🇰🇷