[experimental][llms][OllamaFunctions] tool calling related fixes (#22339)

Fixes issues with tool calling to handle tool objects correctly. Added
support to handle ToolMessage correctly.
Added additional checks for error conditions.

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Karim Lalani
2024-06-12 15:34:43 -05:00
committed by GitHub
parent d04e899b56
commit 276be6cdd4
2 changed files with 115 additions and 21 deletions

View File

@@ -2,6 +2,8 @@
import unittest
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.tools.pubmed.tool import PubmedQueryRun
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel, Field
@@ -22,7 +24,7 @@ class TestOllamaFunctions(unittest.TestCase):
"""
def test_default_ollama_functions(self) -> None:
base_model = OllamaFunctions(model="llama3", format="json")
base_model = OllamaFunctions(model="phi3", format="json")
# bind functions
model = base_model.bind_tools(
@@ -60,8 +62,22 @@ class TestOllamaFunctions(unittest.TestCase):
assert tool_call
self.assertEqual("get_current_weather", tool_call.get("name"))
def test_ollama_functions_tools(self) -> None:
base_model = OllamaFunctions(model="phi3", format="json")
model = base_model.bind_tools(
tools=[PubmedQueryRun(), DuckDuckGoSearchResults(max_results=2)]
)
res = model.invoke("What causes lung cancer?")
self.assertIsInstance(res, AIMessage)
res = AIMessage(**res.__dict__)
tool_calls = res.tool_calls
assert tool_calls
tool_call = tool_calls[0]
assert tool_call
self.assertEqual("pub_med", tool_call.get("name"))
def test_default_ollama_functions_default_response(self) -> None:
base_model = OllamaFunctions(model="llama3", format="json")
base_model = OllamaFunctions(model="phi3", format="json")
# bind functions
model = base_model.bind_tools(