mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
[experimental][llms][ollama_functions] Update OllamaFunctions to send tool_calls
attribute (#21625)
Update OllamaFunctions to return `tool_calls` for AIMessages when used for tool calling.
This commit is contained in:
parent
d61bdeba25
commit
a1899439fc
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -18,7 +19,7 @@ from typing import (
|
|||||||
from langchain_community.chat_models.ollama import ChatOllama
|
from langchain_community.chat_models.ollama import ChatOllama
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import LanguageModelInput
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.messages import AIMessage, BaseMessage
|
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
from langchain_core.output_parsers.json import JsonOutputParser
|
from langchain_core.output_parsers.json import JsonOutputParser
|
||||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||||
@ -94,16 +95,19 @@ def parse_response(message: BaseMessage) -> str:
|
|||||||
"""Extract `function_call` from `AIMessage`."""
|
"""Extract `function_call` from `AIMessage`."""
|
||||||
if isinstance(message, AIMessage):
|
if isinstance(message, AIMessage):
|
||||||
kwargs = message.additional_kwargs
|
kwargs = message.additional_kwargs
|
||||||
if "function_call" in kwargs:
|
tool_calls = message.tool_calls
|
||||||
|
if len(tool_calls) > 0:
|
||||||
|
tool_call = tool_calls[-1]
|
||||||
|
args = tool_call.get("args")
|
||||||
|
return json.dumps(args)
|
||||||
|
elif "function_call" in kwargs:
|
||||||
if "arguments" in kwargs["function_call"]:
|
if "arguments" in kwargs["function_call"]:
|
||||||
return kwargs["function_call"]["arguments"]
|
return kwargs["function_call"]["arguments"]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`arguments` missing from `function_call` within AIMessage: {message}"
|
f"`arguments` missing from `function_call` within AIMessage: {message}"
|
||||||
)
|
)
|
||||||
raise ValueError(
|
else:
|
||||||
"`function_call` missing from `additional_kwargs` "
|
raise ValueError("`tool_calls` missing from AIMessage: {message}")
|
||||||
f"within AIMessage: {message}"
|
|
||||||
)
|
|
||||||
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
|
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
|
||||||
|
|
||||||
|
|
||||||
@ -293,10 +297,9 @@ class OllamaFunctions(ChatOllama):
|
|||||||
"matching function in `functions`."
|
"matching function in `functions`."
|
||||||
)
|
)
|
||||||
del kwargs["function_call"]
|
del kwargs["function_call"]
|
||||||
elif not functions:
|
|
||||||
functions.append(DEFAULT_RESPONSE_FUNCTION)
|
|
||||||
if _is_pydantic_class(functions[0]):
|
if _is_pydantic_class(functions[0]):
|
||||||
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
||||||
|
functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
|
||||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
||||||
self.tool_system_prompt_template
|
self.tool_system_prompt_template
|
||||||
)
|
)
|
||||||
@ -318,7 +321,6 @@ class OllamaFunctions(ChatOllama):
|
|||||||
Response: {chat_generation_content}"""
|
Response: {chat_generation_content}"""
|
||||||
)
|
)
|
||||||
called_tool_name = parsed_chat_result["tool"]
|
called_tool_name = parsed_chat_result["tool"]
|
||||||
called_tool_arguments = parsed_chat_result["tool_input"]
|
|
||||||
called_tool = next(
|
called_tool = next(
|
||||||
(fn for fn in functions if fn["name"] == called_tool_name), None
|
(fn for fn in functions if fn["name"] == called_tool_name), None
|
||||||
)
|
)
|
||||||
@ -328,26 +330,39 @@ class OllamaFunctions(ChatOllama):
|
|||||||
f"{chat_generation_content}"
|
f"{chat_generation_content}"
|
||||||
)
|
)
|
||||||
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
||||||
|
if (
|
||||||
|
"tool_input" in parsed_chat_result
|
||||||
|
and "response" in parsed_chat_result["tool_input"]
|
||||||
|
):
|
||||||
|
response = parsed_chat_result["tool_input"]["response"]
|
||||||
|
elif "response" in parsed_chat_result:
|
||||||
|
response = parsed_chat_result["response"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to parse a response from {self.model} output: "
|
||||||
|
f"{chat_generation_content}"
|
||||||
|
)
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
generations=[
|
generations=[
|
||||||
ChatGeneration(
|
ChatGeneration(
|
||||||
message=AIMessage(
|
message=AIMessage(
|
||||||
content=called_tool_arguments["response"],
|
content=response,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
called_tool_arguments = parsed_chat_result["tool_input"]
|
||||||
|
|
||||||
response_message_with_functions = AIMessage(
|
response_message_with_functions = AIMessage(
|
||||||
content="",
|
content="",
|
||||||
additional_kwargs={
|
tool_calls=[
|
||||||
"function_call": {
|
ToolCall(
|
||||||
"name": called_tool_name,
|
name=called_tool_name,
|
||||||
"arguments": json.dumps(called_tool_arguments)
|
args=called_tool_arguments if called_tool_arguments else {},
|
||||||
if called_tool_arguments
|
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
||||||
else "",
|
)
|
||||||
},
|
],
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
|
@ -52,9 +52,51 @@ class TestOllamaFunctions(unittest.TestCase):
|
|||||||
|
|
||||||
res = model.invoke("What's the weather in San Francisco?")
|
res = model.invoke("What's the weather in San Francisco?")
|
||||||
|
|
||||||
function_call = res.additional_kwargs.get("function_call")
|
self.assertIsInstance(res, AIMessage)
|
||||||
assert function_call
|
res = AIMessage(**res.__dict__)
|
||||||
self.assertEqual(function_call.get("name"), "get_current_weather")
|
tool_calls = res.tool_calls
|
||||||
|
assert tool_calls
|
||||||
|
tool_call = tool_calls[0]
|
||||||
|
assert tool_call
|
||||||
|
self.assertEqual("get_current_weather", tool_call.get("name"))
|
||||||
|
|
||||||
|
def test_default_ollama_functions_default_response(self) -> None:
|
||||||
|
base_model = OllamaFunctions(model="llama3", format="json")
|
||||||
|
|
||||||
|
# bind functions
|
||||||
|
model = base_model.bind_tools(
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, "
|
||||||
|
"e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
res = model.invoke("What is the capital of France?")
|
||||||
|
|
||||||
|
self.assertIsInstance(res, AIMessage)
|
||||||
|
res = AIMessage(**res.__dict__)
|
||||||
|
tool_calls = res.tool_calls
|
||||||
|
if len(tool_calls) > 0:
|
||||||
|
tool_call = tool_calls[0]
|
||||||
|
assert tool_call
|
||||||
|
self.assertEqual("__conversational_response", tool_call.get("name"))
|
||||||
|
|
||||||
def test_ollama_structured_output(self) -> None:
|
def test_ollama_structured_output(self) -> None:
|
||||||
model = OllamaFunctions(model="phi3")
|
model = OllamaFunctions(model="phi3")
|
||||||
|
Loading…
Reference in New Issue
Block a user