mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
anthropic[minor]: tool use (#20016)
This commit is contained in:
@@ -1,13 +1,17 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Literal, Type
|
||||
|
||||
import pytest
|
||||
from anthropic.types import ContentBlock, Message, Usage
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
|
||||
from langchain_anthropic.chat_models import _merge_messages, convert_to_anthropic_tool
|
||||
|
||||
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
||||
|
||||
@@ -83,3 +87,175 @@ def test__format_output() -> None:
|
||||
llm = ChatAnthropic(model="test", anthropic_api_key="test")
|
||||
actual = llm._format_output(anthropic_msg)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test__merge_messages() -> None:
|
||||
messages = [
|
||||
SystemMessage("foo"),
|
||||
HumanMessage("bar"),
|
||||
AIMessage(
|
||||
[
|
||||
{"text": "baz", "type": "text"},
|
||||
{
|
||||
"tool_input": {"a": "b"},
|
||||
"type": "tool_use",
|
||||
"id": "1",
|
||||
"text": None,
|
||||
"name": "buz",
|
||||
},
|
||||
{"text": "baz", "type": "text"},
|
||||
{
|
||||
"tool_input": {"a": "c"},
|
||||
"type": "tool_use",
|
||||
"id": "2",
|
||||
"text": None,
|
||||
"name": "blah",
|
||||
},
|
||||
]
|
||||
),
|
||||
ToolMessage("buz output", tool_call_id="1"),
|
||||
ToolMessage("blah output", tool_call_id="2"),
|
||||
HumanMessage("next thing"),
|
||||
]
|
||||
expected = [
|
||||
SystemMessage("foo"),
|
||||
HumanMessage("bar"),
|
||||
AIMessage(
|
||||
[
|
||||
{"text": "baz", "type": "text"},
|
||||
{
|
||||
"tool_input": {"a": "b"},
|
||||
"type": "tool_use",
|
||||
"id": "1",
|
||||
"text": None,
|
||||
"name": "buz",
|
||||
},
|
||||
{"text": "baz", "type": "text"},
|
||||
{
|
||||
"tool_input": {"a": "c"},
|
||||
"type": "tool_use",
|
||||
"id": "2",
|
||||
"text": None,
|
||||
"name": "blah",
|
||||
},
|
||||
]
|
||||
),
|
||||
HumanMessage(
|
||||
[
|
||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
||||
{"type": "text", "text": "next thing"},
|
||||
]
|
||||
),
|
||||
]
|
||||
actual = _merge_messages(messages)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def pydantic() -> Type[BaseModel]:
|
||||
class dummy_function(BaseModel):
|
||||
"""dummy function"""
|
||||
|
||||
arg1: int = Field(..., description="foo")
|
||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def function() -> Callable:
|
||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||
"""dummy function
|
||||
|
||||
Args:
|
||||
arg1: foo
|
||||
arg2: one of 'bar', 'baz'
|
||||
"""
|
||||
pass
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_tool() -> BaseTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||
|
||||
class DummyFunction(BaseTool):
|
||||
args_schema: Type[BaseModel] = Schema
|
||||
name: str = "dummy_function"
|
||||
description: str = "dummy function"
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
return DummyFunction()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def json_schema() -> Dict:
|
||||
return {
|
||||
"title": "dummy_function",
|
||||
"description": "dummy function",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"description": "foo", "type": "integer"},
|
||||
"arg2": {
|
||||
"description": "one of 'bar', 'baz'",
|
||||
"enum": ["bar", "baz"],
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def openai_function() -> Dict:
|
||||
return {
|
||||
"name": "dummy_function",
|
||||
"description": "dummy function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"description": "foo", "type": "integer"},
|
||||
"arg2": {
|
||||
"description": "one of 'bar', 'baz'",
|
||||
"enum": ["bar", "baz"],
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_convert_to_anthropic_tool(
|
||||
pydantic: Type[BaseModel],
|
||||
function: Callable,
|
||||
dummy_tool: BaseTool,
|
||||
json_schema: Dict,
|
||||
openai_function: Dict,
|
||||
) -> None:
|
||||
expected = {
|
||||
"name": "dummy_function",
|
||||
"description": "dummy function",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"description": "foo", "type": "integer"},
|
||||
"arg2": {
|
||||
"description": "one of 'bar', 'baz'",
|
||||
"enum": ["bar", "baz"],
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
},
|
||||
}
|
||||
|
||||
for fn in (pydantic, function, dummy_tool, json_schema, expected, openai_function):
|
||||
actual = convert_to_anthropic_tool(fn) # type: ignore
|
||||
assert actual == expected
|
||||
|
Reference in New Issue
Block a user