anthropic[minor]: tool use (#20016)

This commit is contained in:
Bagatur
2024-04-04 13:22:48 -07:00
committed by GitHub
parent 3aacd11846
commit 209de0a561
13 changed files with 1021 additions and 196 deletions

View File

@@ -1,13 +1,17 @@
"""Test chat model integration."""
import os
from typing import Any, Callable, Dict, Literal, Type
import pytest
from anthropic.types import ContentBlock, Message, Usage
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.chat_models import _merge_messages, convert_to_anthropic_tool
os.environ["ANTHROPIC_API_KEY"] = "foo"
@@ -83,3 +87,175 @@ def test__format_output() -> None:
llm = ChatAnthropic(model="test", anthropic_api_key="test")
actual = llm._format_output(anthropic_msg)
assert expected == actual
def test__merge_messages() -> None:
messages = [
SystemMessage("foo"),
HumanMessage("bar"),
AIMessage(
[
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "b"},
"type": "tool_use",
"id": "1",
"text": None,
"name": "buz",
},
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "c"},
"type": "tool_use",
"id": "2",
"text": None,
"name": "blah",
},
]
),
ToolMessage("buz output", tool_call_id="1"),
ToolMessage("blah output", tool_call_id="2"),
HumanMessage("next thing"),
]
expected = [
SystemMessage("foo"),
HumanMessage("bar"),
AIMessage(
[
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "b"},
"type": "tool_use",
"id": "1",
"text": None,
"name": "buz",
},
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "c"},
"type": "tool_use",
"id": "2",
"text": None,
"name": "blah",
},
]
),
HumanMessage(
[
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
{"type": "text", "text": "next thing"},
]
),
]
actual = _merge_messages(messages)
assert expected == actual
@pytest.fixture()
def pydantic() -> Type[BaseModel]:
class dummy_function(BaseModel):
"""dummy function"""
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
return dummy_function
@pytest.fixture()
def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
"""
pass
return dummy_function
@pytest.fixture()
def dummy_tool() -> BaseTool:
class Schema(BaseModel):
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
class DummyFunction(BaseTool):
args_schema: Type[BaseModel] = Schema
name: str = "dummy_function"
description: str = "dummy function"
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
return DummyFunction()
@pytest.fixture()
def json_schema() -> Dict:
return {
"title": "dummy_function",
"description": "dummy function",
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
"required": ["arg1", "arg2"],
}
@pytest.fixture()
def openai_function() -> Dict:
return {
"name": "dummy_function",
"description": "dummy function",
"parameters": {
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
"required": ["arg1", "arg2"],
},
}
def test_convert_to_anthropic_tool(
pydantic: Type[BaseModel],
function: Callable,
dummy_tool: BaseTool,
json_schema: Dict,
openai_function: Dict,
) -> None:
expected = {
"name": "dummy_function",
"description": "dummy function",
"input_schema": {
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
"required": ["arg1", "arg2"],
},
}
for fn in (pydantic, function, dummy_tool, json_schema, expected, openai_function):
actual = convert_to_anthropic_tool(fn) # type: ignore
assert actual == expected