[experimental][llms][ollama_functions] Update OllamaFunctions to send tool_calls attribute (#21625)

Update OllamaFunctions to return `tool_calls` for AIMessages when used
for tool calling.
Author: Karim Lalani (committed by GitHub), 2024-05-29 08:38:33 -05:00
Parent: d61bdeba25
Commit: a1899439fc
2 changed files with 78 additions and 21 deletions
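In practice this means the call is surfaced as structured `tool_calls` rather than only as a `function_call` entry in `additional_kwargs`. A minimal sketch of the new behavior (the tool schema and prompt mirror the integration tests below; assumes a running Ollama server with a `llama3` model pulled):

```python
from langchain_experimental.llms.ollama_functions import OllamaFunctions

model = OllamaFunctions(model="llama3", format="json").bind_tools(
    tools=[
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ]
)

res = model.invoke("What's the weather in San Francisco?")
# Previously the call appeared only in res.additional_kwargs["function_call"];
# it now arrives as res.tool_calls, e.g.:
# [{'name': 'get_current_weather',
#   'args': {'location': 'San Francisco, CA'}, 'id': 'call_...'}]
print(res.tool_calls)
```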

langchain_experimental/llms/ollama_functions.py

```diff
@@ -1,4 +1,5 @@
 import json
+import uuid
 from operator import itemgetter
 from typing import (
     Any,
```
```diff
@@ -18,7 +19,7 @@ from typing import (
 from langchain_community.chat_models.ollama import ChatOllama
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.messages import AIMessage, BaseMessage, ToolCall
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.json import JsonOutputParser
 from langchain_core.output_parsers.pydantic import PydanticOutputParser
```
```diff
@@ -94,16 +95,19 @@ def parse_response(message: BaseMessage) -> str:
     """Extract `function_call` from `AIMessage`."""
     if isinstance(message, AIMessage):
         kwargs = message.additional_kwargs
-        if "function_call" in kwargs:
+        tool_calls = message.tool_calls
+        if len(tool_calls) > 0:
+            tool_call = tool_calls[-1]
+            args = tool_call.get("args")
+            return json.dumps(args)
+        elif "function_call" in kwargs:
             if "arguments" in kwargs["function_call"]:
                 return kwargs["function_call"]["arguments"]
             raise ValueError(
                 f"`arguments` missing from `function_call` within AIMessage: {message}"
             )
-        raise ValueError(
-            "`function_call` missing from `additional_kwargs` "
-            f"within AIMessage: {message}"
-        )
+        else:
+            raise ValueError(f"`tool_calls` missing from AIMessage: {message}")
     raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
```
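With this change `parse_response` prefers `tool_calls`, serializing the last call's `args` to JSON, and only falls back to the legacy `function_call` entry. A quick sketch on a hand-built message (values illustrative):

```python
from langchain_core.messages import AIMessage
from langchain_experimental.llms.ollama_functions import parse_response

msg = AIMessage(
    content="",
    tool_calls=[
        {
            "name": "get_current_weather",
            "args": {"location": "San Francisco, CA"},
            "id": "call_abc123",  # illustrative id
        }
    ],
)
print(parse_response(msg))  # -> '{"location": "San Francisco, CA"}'
```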
```diff
@@ -293,10 +297,9 @@ class OllamaFunctions(ChatOllama):
                     "matching function in `functions`."
                 )
             del kwargs["function_call"]
-        elif not functions:
-            functions.append(DEFAULT_RESPONSE_FUNCTION)
         if _is_pydantic_class(functions[0]):
             functions = [convert_to_ollama_tool(fn) for fn in functions]
+        functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
         system_message_prompt_template = SystemMessagePromptTemplate.from_template(
             self.tool_system_prompt_template
         )
```
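Behavior note: `DEFAULT_RESPONSE_FUNCTION` used to be appended only when no functions were supplied; it is now always inserted at the front of the tool list, so the model can decline to call a user tool and answer conversationally (the `__conversational_response` path exercised by the new test below). An abbreviated illustration of the resulting ordering (schemas shortened; the full `DEFAULT_RESPONSE_FUNCTION` lives elsewhere in this module):

```python
# Abbreviated schemas for illustration only.
DEFAULT_RESPONSE_FUNCTION = {"name": "__conversational_response"}
functions = [{"name": "get_current_weather"}]  # user tools, already normalized

functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
assert [fn["name"] for fn in functions] == [
    "__conversational_response",
    "get_current_weather",
]
```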
```diff
@@ -318,7 +321,6 @@ class OllamaFunctions(ChatOllama):
                 Response: {chat_generation_content}"""
             )
         called_tool_name = parsed_chat_result["tool"]
-        called_tool_arguments = parsed_chat_result["tool_input"]
         called_tool = next(
             (fn for fn in functions if fn["name"] == called_tool_name), None
         )
```
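The `called_tool_arguments` assignment is not simply deleted here: the default-response branch in the next hunk no longer assumes `tool_input` is present, so the assignment reappears after that branch, once a real tool call is known.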
```diff
@@ -328,26 +330,39 @@ class OllamaFunctions(ChatOllama):
                 f"{chat_generation_content}"
             )
         if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
+            if (
+                "tool_input" in parsed_chat_result
+                and "response" in parsed_chat_result["tool_input"]
+            ):
+                response = parsed_chat_result["tool_input"]["response"]
+            elif "response" in parsed_chat_result:
+                response = parsed_chat_result["response"]
+            else:
+                raise ValueError(
+                    f"Failed to parse a response from {self.model} output: "
+                    f"{chat_generation_content}"
+                )
             return ChatResult(
                 generations=[
                     ChatGeneration(
                         message=AIMessage(
-                            content=called_tool_arguments["response"],
+                            content=response,
                         )
                     )
                 ]
             )
 
+        called_tool_arguments = parsed_chat_result["tool_input"]
+
         response_message_with_functions = AIMessage(
             content="",
-            additional_kwargs={
-                "function_call": {
-                    "name": called_tool_name,
-                    "arguments": json.dumps(called_tool_arguments)
-                    if called_tool_arguments
-                    else "",
-                },
-            },
+            tool_calls=[
+                ToolCall(
+                    name=called_tool_name,
+                    args=called_tool_arguments if called_tool_arguments else {},
+                    id=f"call_{str(uuid.uuid4()).replace('-', '')}",
+                )
+            ],
         )
 
         return ChatResult(
```
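The assistant message now carries first-class `tool_calls` built from `ToolCall`, a `TypedDict` in `langchain_core.messages`, with an OpenAI-style `call_` id derived from a UUID. A small sketch of what gets constructed (values illustrative):

```python
import uuid

from langchain_core.messages import ToolCall

# ToolCall is a TypedDict, so this builds a plain dict:
tool_call = ToolCall(
    name="get_current_weather",
    args={"location": "San Francisco, CA"},
    id=f"call_{uuid.uuid4().hex}",  # hex == str(uuid4()).replace("-", "")
)
print(tool_call)
```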

test_ollama_functions.py

```diff
@@ -52,9 +52,51 @@ class TestOllamaFunctions(unittest.TestCase):
         res = model.invoke("What's the weather in San Francisco?")
 
-        function_call = res.additional_kwargs.get("function_call")
-        assert function_call
-        self.assertEqual(function_call.get("name"), "get_current_weather")
+        self.assertIsInstance(res, AIMessage)
+        res = AIMessage(**res.__dict__)
+        tool_calls = res.tool_calls
+        assert tool_calls
+        tool_call = tool_calls[0]
+        assert tool_call
+        self.assertEqual("get_current_weather", tool_call.get("name"))
+
+    def test_default_ollama_functions_default_response(self) -> None:
+        base_model = OllamaFunctions(model="llama3", format="json")
+
+        # bind functions
+        model = base_model.bind_tools(
+            tools=[
+                {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, "
+                                "e.g. San Francisco, CA",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                }
+            ]
+        )
+
+        res = model.invoke("What is the capital of France?")
+
+        self.assertIsInstance(res, AIMessage)
+        res = AIMessage(**res.__dict__)
+        tool_calls = res.tool_calls
+        if len(tool_calls) > 0:
+            tool_call = tool_calls[0]
+            assert tool_call
+            self.assertEqual("__conversational_response", tool_call.get("name"))
 
     def test_ollama_structured_output(self) -> None:
         model = OllamaFunctions(model="phi3")
```