https://www.pexels.com/photo/flashing-sparks-coming-from-the-angle-grinder-9665360/ |
In this blog post, we look at LangGraph, specifically its conditional edges.
The code is in a GitHub repo. Let's go through the graph setup to understand what we have; the code can also be found here. The core of it is about 60 lines of code. This is cool because LangChain Core and LangGraph provide the plumbing needed to create more complex workflows.
import argparse
import asyncio
import os
import sys

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.base import BaseMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from langgraph_tools.hosting import container
from langgraph_tools.messages.message_builder import MessageBuilder, MessageKind
from langgraph_tools.protocols.i_azure_openai_service import IAzureOpenAIService
from langgraph_tools.tools import count_words, get_entities, summarize

###############################################################################
# Everything above is imports; the real work starts here.

# Tools the model is allowed to call. Their implementations live in
# langgraph_tools.tools and are deliberately simple (get_entities is reused
# from an earlier blog post).
tools = [count_words, get_entities, summarize]

# Look up the Azure OpenAI service in the hosting `container`, take its chat
# model, and bind the tool list above so the model can request those tools.
llm_model_with_tools = (
    container[IAzureOpenAIService]
    .get_model()
    .bind_tools(tools)
)

# Condition function telling LangGraph whether to route to the "tools" node
# or to the END node.
def should_continue(state: MessagesState):
    """Choose the next node after the "agent" node runs.

    Routes to the "tools" node when the newest message is an ``AIMessage``
    that requested at least one tool call; otherwise the graph ends.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        ``"tools"`` or ``END``.
    """
    last_message = state["messages"][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tools"
    return END


def call_model(state: MessagesState):
    """Body of the "agent" node: run the tool-bound model on the history.

    Returns a partial state update. ``MessagesState.messages`` uses the
    ``add_messages`` reducer, so the response is appended to the history
    rather than replacing it.
    """
    response = llm_model_with_tools.invoke(state["messages"])
    return {"messages": [response]}


def create_graph() -> CompiledStateGraph:
    """Build and compile the agent <-> tools loop.

    Topology: START -> agent; agent -> tools or END (decided by
    ``should_continue``); tools -> agent.
    """
    # Use LangGraph's built-in state rather than a custom one:
    #   class MessagesState(TypedDict):
    #       messages: Annotated[list[AnyMessage], add_messages]
    graph = StateGraph(MessagesState)

    graph.add_node("agent", call_model)
    graph.add_node("tools", ToolNode(tools))

    graph.add_edge(START, "agent")
    # After the agent runs, either execute the requested tools or stop.
    graph.add_conditional_edges("agent", should_continue, ["tools", END])
    # Feed tool results back to the agent for another model turn.
    graph.add_edge("tools", "agent")

    return graph.compile()


async def invoke(text: str, actions: list[MessageKind]) -> str:
    """Run the graph on ``text`` with the requested ``actions``.

    Args:
        text: Content to process.
        actions: Operations to request. They are de-duplicated with ``set``
            before being passed to ``MessageBuilder.build``.

    Returns:
        The content of the last message in the final state, or
        ``"no results"`` if the stream produced no state.
    """
    app = create_graph()
    message = MessageBuilder().build(set(actions), text)

    # stream_mode="values" yields the full state after each step; only the
    # newest message is kept so its content can be returned at the end.
    last_message: BaseMessage | None = None
    async for value in app.astream({"messages": [message]}, stream_mode="values"):
        last_message = value["messages"][-1]

    return str(last_message.content) if last_message is not None else "no results"


########################################################################
# Command-line handling: argparse determines which file holds the content
# and which actions (count words, extract entities, summarize) to take.
def parse_args() -> tuple[str, list[MessageKind]]:
    """Parse command-line arguments.

    Returns:
        ``(file, actions)``: the file name given with ``-f`` (the script
        resolves it under ``data/``) and the list of selected actions.

    Exits with status 1 (message on stderr) if no action flag was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", help="file to process", required=True)
    parser.add_argument("-c", "--count-words", help="count words", action="store_true")
    parser.add_argument("-s", "--summarize", help="summarize", action="store_true")
    parser.add_argument(
        "-e", "--extract-entities", help="extract entities", action="store_true"
    )
    args = parser.parse_args()

    actions: list[MessageKind] = []
    if args.count_words:
        actions.append("word_count")
    if args.summarize:
        actions.append("summarize")
    if args.extract_entities:
        actions.append("extract_entities")

    if not actions:
        print("At least one action must be selected", file=sys.stderr)
        sys.exit(1)
    return args.file, actions


if __name__ == "__main__":
    file, actions = parse_args()
    path = os.path.join("data", file)
    # Explicit encoding so reading the input does not depend on the locale;
    # report an unreadable file cleanly instead of with a traceback.
    try:
        with open(path, "r", encoding="utf-8") as fp:
            text = fp.read()
    except OSError as err:
        print(f"Cannot read {path}: {err}", file=sys.stderr)
        sys.exit(1)
    result = asyncio.run(invoke(text, actions))
    print(result)
    sys.exit(0)
Addendum
You can also see another version of this code written in Python, and here is how we execute it with the LangChain AgentExecutor.
Comments
Post a Comment