mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
parent
12e0c28a6e
commit
a7b4175091
@ -1,10 +1,18 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
BaseMessageChunk,
|
||||||
|
HumanMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from langchain_standard_tests.unit_tests.chat_models import (
|
from langchain_standard_tests.unit_tests.chat_models import (
|
||||||
ChatModelTests,
|
ChatModelTests,
|
||||||
@ -12,6 +20,21 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def magic_function(input: int) -> int:
|
||||||
|
"""Applies a magic function to an input."""
|
||||||
|
return input + 2
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_tool_call_message(message: AIMessage) -> None:
|
||||||
|
assert isinstance(message, AIMessage)
|
||||||
|
assert len(message.tool_calls) == 1
|
||||||
|
tool_call = message.tool_calls[0]
|
||||||
|
assert tool_call["name"] == "magic_function"
|
||||||
|
assert tool_call["args"] == {"input": 3}
|
||||||
|
assert tool_call["id"] is not None
|
||||||
|
|
||||||
|
|
||||||
class ChatModelIntegrationTests(ChatModelTests):
|
class ChatModelIntegrationTests(ChatModelTests):
|
||||||
def test_invoke(self, model: BaseChatModel) -> None:
|
def test_invoke(self, model: BaseChatModel) -> None:
|
||||||
result = model.invoke("Hello")
|
result = model.invoke("Hello")
|
||||||
@ -98,6 +121,24 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
result = custom_model.invoke("hi")
|
result = custom_model.invoke("hi")
|
||||||
assert isinstance(result, AIMessage)
|
assert isinstance(result, AIMessage)
|
||||||
|
|
||||||
|
def test_tool_calling(self, model: BaseChatModel) -> None:
|
||||||
|
if not self.has_tool_calling:
|
||||||
|
pytest.skip("Test requires tool calling.")
|
||||||
|
model_with_tools = model.bind_tools([magic_function])
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
query = "What is the value of magic_function(3)? Use the tool."
|
||||||
|
result = model_with_tools.invoke(query)
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
_validate_tool_call_message(result)
|
||||||
|
|
||||||
|
# Test stream
|
||||||
|
full: Optional[BaseMessageChunk] = None
|
||||||
|
for chunk in model_with_tools.stream(query):
|
||||||
|
full = chunk if full is None else full + chunk # type: ignore
|
||||||
|
assert isinstance(full, AIMessage)
|
||||||
|
_validate_tool_call_message(full)
|
||||||
|
|
||||||
def test_tool_message_histories_string_content(
|
def test_tool_message_histories_string_content(
|
||||||
self,
|
self,
|
||||||
model: BaseChatModel,
|
model: BaseChatModel,
|
||||||
|
Loading…
Reference in New Issue
Block a user