mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
[experimental][llms][ollama_functions] Update OllamaFunctions to send tool_calls
attribute (#21625)
Update OllamaFunctions to return `tool_calls` for AIMessages when used for tool calling.
This commit is contained in:
parent
d61bdeba25
commit
a1899439fc
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
@ -18,7 +19,7 @@ from typing import (
|
||||
from langchain_community.chat_models.ollama import ChatOllama
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
@ -94,16 +95,19 @@ def parse_response(message: BaseMessage) -> str:
|
||||
"""Extract `function_call` from `AIMessage`."""
|
||||
if isinstance(message, AIMessage):
|
||||
kwargs = message.additional_kwargs
|
||||
if "function_call" in kwargs:
|
||||
tool_calls = message.tool_calls
|
||||
if len(tool_calls) > 0:
|
||||
tool_call = tool_calls[-1]
|
||||
args = tool_call.get("args")
|
||||
return json.dumps(args)
|
||||
elif "function_call" in kwargs:
|
||||
if "arguments" in kwargs["function_call"]:
|
||||
return kwargs["function_call"]["arguments"]
|
||||
raise ValueError(
|
||||
f"`arguments` missing from `function_call` within AIMessage: {message}"
|
||||
)
|
||||
raise ValueError(
|
||||
"`function_call` missing from `additional_kwargs` "
|
||||
f"within AIMessage: {message}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("`tool_calls` missing from AIMessage: {message}")
|
||||
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
|
||||
|
||||
|
||||
@ -293,10 +297,9 @@ class OllamaFunctions(ChatOllama):
|
||||
"matching function in `functions`."
|
||||
)
|
||||
del kwargs["function_call"]
|
||||
elif not functions:
|
||||
functions.append(DEFAULT_RESPONSE_FUNCTION)
|
||||
if _is_pydantic_class(functions[0]):
|
||||
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
||||
functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
||||
self.tool_system_prompt_template
|
||||
)
|
||||
@ -318,7 +321,6 @@ class OllamaFunctions(ChatOllama):
|
||||
Response: {chat_generation_content}"""
|
||||
)
|
||||
called_tool_name = parsed_chat_result["tool"]
|
||||
called_tool_arguments = parsed_chat_result["tool_input"]
|
||||
called_tool = next(
|
||||
(fn for fn in functions if fn["name"] == called_tool_name), None
|
||||
)
|
||||
@ -328,26 +330,39 @@ class OllamaFunctions(ChatOllama):
|
||||
f"{chat_generation_content}"
|
||||
)
|
||||
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
||||
if (
|
||||
"tool_input" in parsed_chat_result
|
||||
and "response" in parsed_chat_result["tool_input"]
|
||||
):
|
||||
response = parsed_chat_result["tool_input"]["response"]
|
||||
elif "response" in parsed_chat_result:
|
||||
response = parsed_chat_result["response"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Failed to parse a response from {self.model} output: "
|
||||
f"{chat_generation_content}"
|
||||
)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(
|
||||
message=AIMessage(
|
||||
content=called_tool_arguments["response"],
|
||||
content=response,
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
called_tool_arguments = parsed_chat_result["tool_input"]
|
||||
|
||||
response_message_with_functions = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": called_tool_name,
|
||||
"arguments": json.dumps(called_tool_arguments)
|
||||
if called_tool_arguments
|
||||
else "",
|
||||
},
|
||||
},
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
name=called_tool_name,
|
||||
args=called_tool_arguments if called_tool_arguments else {},
|
||||
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
|
@ -52,9 +52,51 @@ class TestOllamaFunctions(unittest.TestCase):
|
||||
|
||||
res = model.invoke("What's the weather in San Francisco?")
|
||||
|
||||
function_call = res.additional_kwargs.get("function_call")
|
||||
assert function_call
|
||||
self.assertEqual(function_call.get("name"), "get_current_weather")
|
||||
self.assertIsInstance(res, AIMessage)
|
||||
res = AIMessage(**res.__dict__)
|
||||
tool_calls = res.tool_calls
|
||||
assert tool_calls
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call
|
||||
self.assertEqual("get_current_weather", tool_call.get("name"))
|
||||
|
||||
def test_default_ollama_functions_default_response(self) -> None:
|
||||
base_model = OllamaFunctions(model="llama3", format="json")
|
||||
|
||||
# bind functions
|
||||
model = base_model.bind_tools(
|
||||
tools=[
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, "
|
||||
"e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
res = model.invoke("What is the capital of France?")
|
||||
|
||||
self.assertIsInstance(res, AIMessage)
|
||||
res = AIMessage(**res.__dict__)
|
||||
tool_calls = res.tool_calls
|
||||
if len(tool_calls) > 0:
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call
|
||||
self.assertEqual("__conversational_response", tool_call.get("name"))
|
||||
|
||||
def test_ollama_structured_output(self) -> None:
|
||||
model = OllamaFunctions(model="phi3")
|
||||
|
Loading…
Reference in New Issue
Block a user