Compare commits

...

3 Commits

Author          SHA1         Message    Date
Chester Curme   ac9bab38dc   add tests  2024-04-18 12:09:32 -04:00
Chester Curme   8211728c6f   format     2024-04-18 12:09:26 -04:00
Chester Curme   a989f73ce6   merge      2024-04-18 11:32:18 -04:00
4 changed files with 137 additions and 1 deletion

View File

@@ -1,6 +1,7 @@
import json
import os
import re
import uuid
import warnings
from operator import itemgetter
from typing import (
@@ -37,11 +38,13 @@ from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.tool import default_tool_parser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
@@ -94,6 +97,39 @@ def _format_image(image_url: str) -> Dict:
    }


def _process_ai_message(ai_message: AIMessage) -> AIMessage:
    """Process an AIMessage to ensure tool_calls are in the correct format."""
    if "function_call" in ai_message.additional_kwargs and not ai_message.tool_calls:
        tool_call_dict = {"function": ai_message.additional_kwargs["function_call"]}
        tool_call_dict["id"] = uuid.uuid4().hex[:]
        tool_calls, _ = default_tool_parser([tool_call_dict])
        return ai_message.copy(update={"tool_calls": tool_calls})
    else:
        return ai_message


def _process_function_message(
    function_message: FunctionMessage,
    message_history: Sequence[BaseMessage],
) -> HumanMessage:
    """Process function messages to Anthropic tool_result."""
    for message in message_history[::-1]:
        if isinstance(message, AIMessage) and message.tool_calls:
            for tool_call in message.tool_calls:
                if tool_call["name"] == function_message.name:
                    if isinstance(function_message.content, str):
                        return HumanMessage(
                            [
                                {
                                    "type": "tool_result",
                                    "content": function_message.content,
                                    "tool_use_id": tool_call["id"],
                                }
                            ]
                        )
    return HumanMessage(function_message.content)
def _merge_messages(
    messages: Sequence[BaseMessage],
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
@@ -101,6 +137,8 @@ def _merge_messages(
    merged: list = []
    for curr in messages:
        curr = curr.copy(deep=True)
        if isinstance(curr, AIMessage):
            curr = _process_ai_message(curr)
        if isinstance(curr, ToolMessage):
            if isinstance(curr.content, str):
                curr = HumanMessage(
@@ -114,6 +152,8 @@ def _merge_messages(
                )
            else:
                curr = HumanMessage(curr.content)
        if isinstance(curr, FunctionMessage):
            curr = _process_function_message(curr, merged)
        last = merged[-1] if merged else None
        if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
            if isinstance(last.content, str):

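Taken together, the chat_models.py changes teach ChatAnthropic to accept legacy OpenAI-style function-calling messages: _process_ai_message backfills tool_calls from a function_call in additional_kwargs, and _process_function_message rewrites a FunctionMessage into an Anthropic tool_result block addressed to the matching tool call. The snippet below is a rough sketch of that intended behavior, not part of the diff; it assumes this branch is installed and pokes at the private helpers directly.

# Illustrative sketch only: exercising the new private helpers directly.
import json

from langchain_core.messages import AIMessage, FunctionMessage

from langchain_anthropic.chat_models import (
    _process_ai_message,
    _process_function_message,
)

# A legacy AIMessage carrying an OpenAI-style function_call and no tool_calls.
legacy_ai = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {"name": "bar", "arguments": json.dumps({"baz": "buzz"})}
    },
)

# _process_ai_message generates an id and parses the function_call into tool_calls.
processed_ai = _process_ai_message(legacy_ai)
assert processed_ai.tool_calls and processed_ai.tool_calls[0]["name"] == "bar"

# _process_function_message matches the FunctionMessage to that tool call by name
# and wraps it as a HumanMessage holding an Anthropic tool_result content block.
function_result = FunctionMessage(name="bar", content="blurb")
converted = _process_function_message(function_result, [processed_ai])
assert converted.content[0]["type"] == "tool_result"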
View File

@@ -8,6 +8,7 @@ from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
@@ -311,6 +312,41 @@ def test_anthropic_with_empty_text_block() -> None:
    model.invoke(messages)


def test_function_message() -> None:
    @tool
    def my_adder_tool(a: int, b: int) -> int:
        """Takes two integers, a and b, and returns their sum."""
        return a + b

    model = ChatAnthropic(model="claude-3-sonnet-20240229").bind_tools([my_adder_tool])
    text_question = "What is 1 + 2"
    function_name = "my_adder_tool"
    function_call = {
        "name": function_name,
        "arguments": json.dumps({"a": "1", "b": "2"}),
    }
    function_answer = json.dumps({"result": 3})

    message1 = HumanMessage(content=text_question)
    message2 = AIMessage(
        content="",
        additional_kwargs={
            "function_call": function_call,
        },
    )
    message3 = FunctionMessage(
        name=function_name,
        content=function_answer,
    )
    messages = [
        message1,
        message2,
        message3,
    ]
    result = model.invoke(messages)
    assert isinstance(result, AIMessage)
def test_with_structured_output() -> None:
    llm = ChatAnthropic(
        model="claude-3-opus-20240229",

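For comparison only (not part of the diff): the same exchange written against the newer tool-calling interface that ChatAnthropic already handled, with tool_calls on the AIMessage and a ToolMessage carrying the result. The legacy function_call/FunctionMessage path exercised above is meant to yield an equivalent Anthropic request; the tool name mirrors the test and the "call_123" id is made up for illustration.

# Illustrative equivalent using the newer tool_calls / ToolMessage format.
import json

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

messages = [
    HumanMessage(content="What is 1 + 2"),
    AIMessage(
        content="",
        tool_calls=[
            {"name": "my_adder_tool", "args": {"a": 1, "b": 2}, "id": "call_123"}
        ],
    ),
    ToolMessage(content=json.dumps({"result": 3}), tool_call_id="call_123"),
]
# model.invoke(messages) on a ChatAnthropic instance bound to my_adder_tool
# would send a tool_use / tool_result exchange, just as the legacy path does.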
View File

@@ -253,3 +253,11 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
    def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
        return self


class AnyStr(str):
    def __init__(self) -> None:
        super().__init__()

    def __eq__(self, other: object) -> bool:
        return isinstance(other, str)

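AnyStr is a small equality helper: it subclasses str and reports itself equal to any string, which lets the unit tests below compare expected payloads that contain randomly generated tool_use ids. A quick illustration using the class defined above (the id value is hypothetical):

# AnyStr() compares equal to any str, so generated ids don't break assertions.
expected = {"type": "tool_use", "id": AnyStr(), "name": "bar"}
actual = {"type": "tool_use", "id": "9c1f3e6a2b4d", "name": "bar"}
assert expected == actual  # passes: AnyStr() == "9c1f3e6a2b4d" is True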
View File

@@ -5,7 +5,13 @@ from typing import Any, Callable, Dict, Literal, Type, cast
import pytest
from anthropic.types import ContentBlock, Message, Usage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages import (
    AIMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.tools import BaseTool
@@ -16,6 +22,7 @@ from langchain_anthropic.chat_models import (
    _merge_messages,
    convert_to_anthropic_tool,
)
from tests.unit_tests._utils import AnyStr

os.environ["ANTHROPIC_API_KEY"] = "foo"
@@ -419,3 +426,48 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
    )
    actual = _format_messages(messages)
    assert expected == actual


def test__format_messages_with_function_calls() -> None:
    system = SystemMessage("fuzz")
    human = HumanMessage("foo")
    ai = AIMessage(
        "thought",
        additional_kwargs={
            "function_call": {"arguments": '{"baz":"buzz"}', "name": "bar"}
        },
    )
    function = FunctionMessage(
        content="blurb",
        name="bar",
    )
    messages = [system, human, ai, function]
    expected = (
        "fuzz",
        [
            {"role": "user", "content": "foo"},
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "thought",
                    },
                    {
                        "type": "tool_use",
                        "name": "bar",
                        "id": AnyStr(),
                        "input": {"baz": "buzz"},
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "tool_result", "content": "blurb", "tool_use_id": AnyStr()}
                ],
            },
        ],
    )
    actual = _format_messages(messages)
    assert expected == actual