LangGraph - tools

https://www.pexels.com/photo/flashing-sparks-coming-from-the-angle-grinder-9665360/
https://www.pexels.com/photo/flashing-sparks-coming-from-the-angle-grinder-9665360/

In this blog, we look at LangGraph, specifically its conditional edges.

The code is in a GitHub repo. Let's go through the graph setup to understand what we have; the code can also be found here. The core of it is about 60 lines of code. This is possible because LangChain Core and LangGraph provide the plumbing needed to build more complex workflows.

import argparse
import asyncio
import os
import sys

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.base import BaseMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from langgraph_tools.hosting import container
from langgraph_tools.messages.message_builder import MessageBuilder, MessageKind
from langgraph_tools.protocols.i_azure_openai_service import IAzureOpenAIService
from langgraph_tools.tools import count_words, get_entities, summarize

###############################################################################
# Everything above this line is imports; the interesting code starts here.

# The three tools the model may call: count_words, get_entities and summarize.
# They are kept deliberately simple (see langgraph_tools.tools).
# The get_entities tool is borrowed from my previous blog post :-)
tools = [
    count_words,
    get_entities,
    summarize,
]

# Look up the Azure OpenAI service in the hosting container, get its chat
# model, and bind the tool list above to that model.
llm_model_with_tools = container[IAzureOpenAIService].get_model().bind_tools(
    tools
)


def should_continue(state: MessagesState):
    """Decide where the graph goes after the "agent" node.

    Returns "tools" when the newest message is an AIMessage that carries one
    or more tool calls; in every other case returns END to stop the run.
    """
    newest = state["messages"][-1]
    if not isinstance(newest, AIMessage):
        return END
    if not newest.tool_calls:
        return END
    return "tools"


def call_model(state: MessagesState):
    """Node function for "agent": send the conversation to the tool-bound model.

    The model's reply is returned as a one-element "messages" list; the
    MessagesState reducer appends it to the existing history.
    """
    reply = llm_model_with_tools.invoke(state["messages"])
    return {"messages": [reply]}


def create_graph() -> CompiledStateGraph:
    """Build and compile the agent/tools loop.

    State is LangGraph's built-in MessagesState rather than a custom class;
    it is defined roughly as:

        class MessagesState(TypedDict):
            messages: Annotated[list[AnyMessage], add_messages]

    Wiring:
        START   -> "agent"
        "agent" -> "tools" or END   (chosen by should_continue)
        "tools" -> "agent"
    """
    builder = StateGraph(MessagesState)

    # Nodes: the model call, and LangGraph's prebuilt node for running tools.
    builder.add_node("agent", call_model)
    builder.add_node("tools", ToolNode(tools))

    # Edges: always begin at the agent; after the agent, either run the
    # requested tools or finish; after the tools, go back to the agent.
    builder.add_edge(START, "agent")
    builder.add_conditional_edges("agent", should_continue, ["tools", END])
    builder.add_edge("tools", "agent")

    return builder.compile()


async def invoke(text: str, actions: list[MessageKind]) -> str:
    """Run the compiled graph over ``text`` and return the final message content.

    Args:
        text: The document content to process.
        actions: The requested operations. They are deduplicated with ``set``
            and passed to ``MessageBuilder`` to build the initial message.

    Returns:
        The content of the last message in the final graph state, converted
        to ``str``, or "no results" if the stream yielded no state at all.
    """
    app = create_graph()
    message = MessageBuilder().build(set(actions), text)

    # With stream_mode="values" every chunk is the full state after a step,
    # so only the last chunk's last message is kept.
    last_message: BaseMessage | None = None
    async for state in app.astream({"messages": [message]}, stream_mode="values"):
        last_message = state["messages"][-1]

    if last_message is None:
        return "no results"
    return str(last_message.content)


########################################################################
# Below, argparse is used to determine which file contains the content
# and which actions (count words, extract entities, summarize) to take.

def parse_args() -> tuple[str, list[MessageKind]]:
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", help="file to process", required=True)
    parser.add_argument("-c", "--count-words", help="count words", action="store_true")
    parser.add_argument("-s", "--summarize", help="summarize", action="store_true")
    parser.add_argument(
        "-e", "--extract-entities", help="summarize", action="store_true"
    )
    args = parser.parse_args()

    actions = []
    if args.count_words:
        actions.append("word_count")
    if args.summarize:
        actions.append("summarize")
    if args.extract_entities:
        actions.append("extract_entities")

    if not actions:
        print("At least one action must be selected")
        sys.exit(1)

    return args.file, actions


if __name__ == "__main__":
    file, actions = parse_args()

    # Input files are looked up under the "data" directory, relative to the
    # current working directory.
    path = os.path.join("data", file)

    # Read the whole file and close it before the (potentially slow) graph run
    # rather than holding it open for the entire run. The encoding is explicit
    # so behavior does not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as fp:
        text = fp.read()

    result = asyncio.run(invoke(text, actions))
    print(result)
    sys.exit(0)


Addendum

You can also see another version of this code in a separate Python file, which shows how to execute it with the LangChain AgentExecutor.



Comments