[experimental][llms][ollama_functions] Update OllamaFunctions to send tool_calls attribute (#21625)

Update OllamaFunctions to return `tool_calls` for AIMessages when used
for tool calling.
This commit is contained in:
Karim Lalani 2024-05-29 08:38:33 -05:00 committed by GitHub
parent d61bdeba25
commit a1899439fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 21 deletions

View File

@ -1,4 +1,5 @@
import json import json
import uuid
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
@ -18,7 +19,7 @@ from typing import (
from langchain_community.chat_models.ollama import ChatOllama from langchain_community.chat_models.ollama import ChatOllama
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage from langchain_core.messages import AIMessage, BaseMessage, ToolCall
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser from langchain_core.output_parsers.pydantic import PydanticOutputParser
@ -94,16 +95,19 @@ def parse_response(message: BaseMessage) -> str:
"""Extract `function_call` from `AIMessage`.""" """Extract `function_call` from `AIMessage`."""
if isinstance(message, AIMessage): if isinstance(message, AIMessage):
kwargs = message.additional_kwargs kwargs = message.additional_kwargs
if "function_call" in kwargs: tool_calls = message.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[-1]
args = tool_call.get("args")
return json.dumps(args)
elif "function_call" in kwargs:
if "arguments" in kwargs["function_call"]: if "arguments" in kwargs["function_call"]:
return kwargs["function_call"]["arguments"] return kwargs["function_call"]["arguments"]
raise ValueError( raise ValueError(
f"`arguments` missing from `function_call` within AIMessage: {message}" f"`arguments` missing from `function_call` within AIMessage: {message}"
) )
raise ValueError( else:
"`function_call` missing from `additional_kwargs` " raise ValueError(f"`tool_calls` missing from AIMessage: {message}")
f"within AIMessage: {message}"
)
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}") raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
@ -293,10 +297,9 @@ class OllamaFunctions(ChatOllama):
"matching function in `functions`." "matching function in `functions`."
) )
del kwargs["function_call"] del kwargs["function_call"]
elif not functions:
functions.append(DEFAULT_RESPONSE_FUNCTION)
if _is_pydantic_class(functions[0]): if _is_pydantic_class(functions[0]):
functions = [convert_to_ollama_tool(fn) for fn in functions] functions = [convert_to_ollama_tool(fn) for fn in functions]
functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
system_message_prompt_template = SystemMessagePromptTemplate.from_template( system_message_prompt_template = SystemMessagePromptTemplate.from_template(
self.tool_system_prompt_template self.tool_system_prompt_template
) )
@ -318,7 +321,6 @@ class OllamaFunctions(ChatOllama):
Response: {chat_generation_content}""" Response: {chat_generation_content}"""
) )
called_tool_name = parsed_chat_result["tool"] called_tool_name = parsed_chat_result["tool"]
called_tool_arguments = parsed_chat_result["tool_input"]
called_tool = next( called_tool = next(
(fn for fn in functions if fn["name"] == called_tool_name), None (fn for fn in functions if fn["name"] == called_tool_name), None
) )
@ -328,26 +330,39 @@ class OllamaFunctions(ChatOllama):
f"{chat_generation_content}" f"{chat_generation_content}"
) )
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]: if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
if (
"tool_input" in parsed_chat_result
and "response" in parsed_chat_result["tool_input"]
):
response = parsed_chat_result["tool_input"]["response"]
elif "response" in parsed_chat_result:
response = parsed_chat_result["response"]
else:
raise ValueError(
f"Failed to parse a response from {self.model} output: "
f"{chat_generation_content}"
)
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration( ChatGeneration(
message=AIMessage( message=AIMessage(
content=called_tool_arguments["response"], content=response,
) )
) )
] ]
) )
called_tool_arguments = parsed_chat_result["tool_input"]
response_message_with_functions = AIMessage( response_message_with_functions = AIMessage(
content="", content="",
additional_kwargs={ tool_calls=[
"function_call": { ToolCall(
"name": called_tool_name, name=called_tool_name,
"arguments": json.dumps(called_tool_arguments) args=called_tool_arguments if called_tool_arguments else {},
if called_tool_arguments id=f"call_{str(uuid.uuid4()).replace('-', '')}",
else "", )
}, ],
},
) )
return ChatResult( return ChatResult(

View File

@ -52,9 +52,51 @@ class TestOllamaFunctions(unittest.TestCase):
res = model.invoke("What's the weather in San Francisco?") res = model.invoke("What's the weather in San Francisco?")
function_call = res.additional_kwargs.get("function_call") self.assertIsInstance(res, AIMessage)
assert function_call res = AIMessage(**res.__dict__)
self.assertEqual(function_call.get("name"), "get_current_weather") tool_calls = res.tool_calls
assert tool_calls
tool_call = tool_calls[0]
assert tool_call
self.assertEqual("get_current_weather", tool_call.get("name"))
def test_default_ollama_functions_default_response(self) -> None:
base_model = OllamaFunctions(model="llama3", format="json")
# bind functions
model = base_model.bind_tools(
tools=[
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, "
"e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
}
]
)
res = model.invoke("What is the capital of France?")
self.assertIsInstance(res, AIMessage)
res = AIMessage(**res.__dict__)
tool_calls = res.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[0]
assert tool_call
self.assertEqual("__conversational_response", tool_call.get("name"))
def test_ollama_structured_output(self) -> None: def test_ollama_structured_output(self) -> None:
model = OllamaFunctions(model="phi3") model = OllamaFunctions(model="phi3")