mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:40:06 +00:00
[experimental][llms][OllamaFunctions] tool calling related fixes (#22339)
Fixes issues with tool calling to handle tool objects correctly. Added support to handle ToolMessage correctly. Added additional checks for error conditions. --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.tools.pubmed.tool import PubmedQueryRun
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
@@ -22,7 +24,7 @@ class TestOllamaFunctions(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_default_ollama_functions(self) -> None:
|
||||
base_model = OllamaFunctions(model="llama3", format="json")
|
||||
base_model = OllamaFunctions(model="phi3", format="json")
|
||||
|
||||
# bind functions
|
||||
model = base_model.bind_tools(
|
||||
@@ -60,8 +62,22 @@ class TestOllamaFunctions(unittest.TestCase):
|
||||
assert tool_call
|
||||
self.assertEqual("get_current_weather", tool_call.get("name"))
|
||||
|
||||
def test_ollama_functions_tools(self) -> None:
|
||||
base_model = OllamaFunctions(model="phi3", format="json")
|
||||
model = base_model.bind_tools(
|
||||
tools=[PubmedQueryRun(), DuckDuckGoSearchResults(max_results=2)]
|
||||
)
|
||||
res = model.invoke("What causes lung cancer?")
|
||||
self.assertIsInstance(res, AIMessage)
|
||||
res = AIMessage(**res.__dict__)
|
||||
tool_calls = res.tool_calls
|
||||
assert tool_calls
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call
|
||||
self.assertEqual("pub_med", tool_call.get("name"))
|
||||
|
||||
def test_default_ollama_functions_default_response(self) -> None:
|
||||
base_model = OllamaFunctions(model="llama3", format="json")
|
||||
base_model = OllamaFunctions(model="phi3", format="json")
|
||||
|
||||
# bind functions
|
||||
model = base_model.bind_tools(
|
||||
|
Reference in New Issue
Block a user