diff --git a/libs/experimental/langchain_experimental/llms/ollama_functions.py b/libs/experimental/langchain_experimental/llms/ollama_functions.py index 7bd04f918f0..f71f91c01a9 100644 --- a/libs/experimental/langchain_experimental/llms/ollama_functions.py +++ b/libs/experimental/langchain_experimental/llms/ollama_functions.py @@ -1,4 +1,5 @@ import json +import uuid from operator import itemgetter from typing import ( Any, @@ -18,7 +19,7 @@ from typing import ( from langchain_community.chat_models.ollama import ChatOllama from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput -from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.messages import AIMessage, BaseMessage, ToolCall from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.output_parsers.pydantic import PydanticOutputParser @@ -94,16 +95,19 @@ def parse_response(message: BaseMessage) -> str: """Extract `function_call` from `AIMessage`.""" if isinstance(message, AIMessage): kwargs = message.additional_kwargs - if "function_call" in kwargs: + tool_calls = message.tool_calls + if len(tool_calls) > 0: + tool_call = tool_calls[-1] + args = tool_call.get("args") + return json.dumps(args) + elif "function_call" in kwargs: if "arguments" in kwargs["function_call"]: return kwargs["function_call"]["arguments"] raise ValueError( f"`arguments` missing from `function_call` within AIMessage: {message}" ) - raise ValueError( - "`function_call` missing from `additional_kwargs` " - f"within AIMessage: {message}" - ) + else: + raise ValueError("`tool_calls` missing from AIMessage: {message}") raise ValueError(f"`message` is not an instance of `AIMessage`: {message}") @@ -293,10 +297,9 @@ class OllamaFunctions(ChatOllama): "matching function in `functions`." ) del kwargs["function_call"] - elif not functions: - functions.append(DEFAULT_RESPONSE_FUNCTION) if _is_pydantic_class(functions[0]): functions = [convert_to_ollama_tool(fn) for fn in functions] + functions.insert(0, DEFAULT_RESPONSE_FUNCTION) system_message_prompt_template = SystemMessagePromptTemplate.from_template( self.tool_system_prompt_template ) @@ -318,7 +321,6 @@ class OllamaFunctions(ChatOllama): Response: {chat_generation_content}""" ) called_tool_name = parsed_chat_result["tool"] - called_tool_arguments = parsed_chat_result["tool_input"] called_tool = next( (fn for fn in functions if fn["name"] == called_tool_name), None ) @@ -328,26 +330,39 @@ class OllamaFunctions(ChatOllama): f"{chat_generation_content}" ) if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]: + if ( + "tool_input" in parsed_chat_result + and "response" in parsed_chat_result["tool_input"] + ): + response = parsed_chat_result["tool_input"]["response"] + elif "response" in parsed_chat_result: + response = parsed_chat_result["response"] + else: + raise ValueError( + f"Failed to parse a response from {self.model} output: " + f"{chat_generation_content}" + ) return ChatResult( generations=[ ChatGeneration( message=AIMessage( - content=called_tool_arguments["response"], + content=response, ) ) ] ) + called_tool_arguments = parsed_chat_result["tool_input"] + response_message_with_functions = AIMessage( content="", - additional_kwargs={ - "function_call": { - "name": called_tool_name, - "arguments": json.dumps(called_tool_arguments) - if called_tool_arguments - else "", - }, - }, + tool_calls=[ + ToolCall( + name=called_tool_name, + args=called_tool_arguments if called_tool_arguments else {}, + id=f"call_{str(uuid.uuid4()).replace('-', '')}", + ) + ], ) return ChatResult( diff --git a/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py b/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py index fb63ee5d388..6c413a361da 100644 --- a/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py +++ b/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py @@ -52,9 +52,51 @@ class TestOllamaFunctions(unittest.TestCase): res = model.invoke("What's the weather in San Francisco?") - function_call = res.additional_kwargs.get("function_call") - assert function_call - self.assertEqual(function_call.get("name"), "get_current_weather") + self.assertIsInstance(res, AIMessage) + res = AIMessage(**res.__dict__) + tool_calls = res.tool_calls + assert tool_calls + tool_call = tool_calls[0] + assert tool_call + self.assertEqual("get_current_weather", tool_call.get("name")) + + def test_default_ollama_functions_default_response(self) -> None: + base_model = OllamaFunctions(model="llama3", format="json") + + # bind functions + model = base_model.bind_tools( + tools=[ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, " + "e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + } + ] + ) + + res = model.invoke("What is the capital of France?") + + self.assertIsInstance(res, AIMessage) + res = AIMessage(**res.__dict__) + tool_calls = res.tool_calls + if len(tool_calls) > 0: + tool_call = tool_calls[0] + assert tool_call + self.assertEqual("__conversational_response", tool_call.get("name")) def test_ollama_structured_output(self) -> None: model = OllamaFunctions(model="phi3")