luping85's picture
Update agent.py
a7928df verified
raw
history blame
2.13 kB
import os
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph
from langchain_openai import ChatOpenAI
# from langchain_google_genai import ChatGoogleGenerativeAI
from tools import tavily_search_tool, repl_tool
# --- Model configuration -----------------------------------------------------
# The key is read from the environment and passed through unchanged; it is None
# when the variable is unset (how ChatOpenAI then behaves is up to the library).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Deterministic (temperature 0) OpenAI chat model used for every assistant turn.
chat_model = ChatOpenAI(
    model="gpt-4.1",
    temperature=0,
    api_key=OPENAI_API_KEY,
)

# Disabled alternative backend, kept for quick switching:
# GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
# chat_model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=GOOGLE_API_KEY)

# --- Tools -------------------------------------------------------------------
# The same list is bound to the model (so it can emit tool calls) and is also
# used to build the graph's ToolNode (so those calls can be executed).
tools = [tavily_search_tool, repl_tool]
chat_model_with_tools = chat_model.bind_tools(tools)
# Instruction prompt that `assistant` prepends (as a SystemMessage) to every
# model call. It asks the model to reply with
# "FINAL ANSWER: [YOUR FINAL ANSWER]" and constrains how numbers, strings and
# comma-separated lists in that answer are written.
# The text, including its leading and trailing newline, is sent to the model
# verbatim, so any edit here changes model behavior.
# NOTE(review): the wording (e.g. "depending of whether") looks copied from an
# external benchmark prompt (GAIA?). It is presumably kept verbatim on purpose;
# confirm before correcting typos.
SYSTEM_MESSAGE = """
You are a general AI assistant. I will ask you a question. Answer the question with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to
write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities),
and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be
put in the list is a number or a string.
"""
class AgentState(TypedDict):
    """State passed between the nodes of the agent graph.

    It holds a single channel, ``messages``: the running conversation that
    the assistant and tool nodes read from and write updates to.
    """

    # `add_messages` is attached as this key's reducer. LangGraph uses it to
    # combine a node's returned {"messages": [...]} with the existing list
    # instead of overwriting it (see the langgraph `add_messages` docs).
    messages: Annotated[list[AnyMessage], add_messages]
def assistant(state: AgentState):
    """Run one model turn over the conversation so far.

    Prepends the system prompt to the current messages, invokes the
    tool-bound chat model, and returns its reply as a partial state update.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A dict ``{"messages": [reply]}`` holding the single model reply,
        which the ``messages`` reducer merges into the state.
    """
    # The system prompt is not stored in state; it is rebuilt and placed
    # first on every call.
    prompt = [SystemMessage(content=SYSTEM_MESSAGE), *state["messages"]]
    reply = chat_model_with_tools.invoke(prompt)
    return {"messages": [reply]}
graph_builder = StateGraph(AgentState)

# Nodes: the LLM turn, and a ToolNode built from the same `tools` list the
# model was bound to, so every tool the model can call can be executed.
graph_builder.add_node("assistant", assistant)
graph_builder.add_node("tools", ToolNode(tools))

# Edges: every run starts at the assistant. After each assistant turn,
# langgraph's prebuilt `tools_condition` picks the next step (presumably
# "tools" when the reply requests a tool call, otherwise the end; confirm in
# the langgraph docs). Tool results always flow back to the assistant.
graph_builder.add_edge(START, "assistant")
graph_builder.add_conditional_edges("assistant", tools_condition)
graph_builder.add_edge("tools", "assistant")

# Compiled, runnable agent exported by this module.
graph = graph_builder.compile()