mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
parent
12e0c28a6e
commit
a7b4175091
@ -1,10 +1,18 @@
|
||||
import base64
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_standard_tests.unit_tests.chat_models import (
|
||||
ChatModelTests,
|
||||
@ -12,6 +20,21 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def magic_function(input: int) -> int:
|
||||
"""Applies a magic function to an input."""
|
||||
return input + 2
|
||||
|
||||
|
||||
def _validate_tool_call_message(message: AIMessage) -> None:
|
||||
assert isinstance(message, AIMessage)
|
||||
assert len(message.tool_calls) == 1
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call["name"] == "magic_function"
|
||||
assert tool_call["args"] == {"input": 3}
|
||||
assert tool_call["id"] is not None
|
||||
|
||||
|
||||
class ChatModelIntegrationTests(ChatModelTests):
|
||||
def test_invoke(self, model: BaseChatModel) -> None:
|
||||
result = model.invoke("Hello")
|
||||
@ -98,6 +121,24 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
result = custom_model.invoke("hi")
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
def test_tool_calling(self, model: BaseChatModel) -> None:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
model_with_tools = model.bind_tools([magic_function])
|
||||
|
||||
# Test invoke
|
||||
query = "What is the value of magic_function(3)? Use the tool."
|
||||
result = model_with_tools.invoke(query)
|
||||
assert isinstance(result, AIMessage)
|
||||
_validate_tool_call_message(result)
|
||||
|
||||
# Test stream
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in model_with_tools.stream(query):
|
||||
full = chunk if full is None else full + chunk # type: ignore
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message(full)
|
||||
|
||||
def test_tool_message_histories_string_content(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
|
Loading…
Reference in New Issue
Block a user