langchain/libs/community/tests/unit_tests/chat_models/test_writer.py
2024-11-25 19:28:01 +00:00

509 lines
17 KiB
Python

import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest import mock
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_tests.unit_tests import ChatModelUnitTests
from pydantic import SecretStr

from langchain_community.chat_models.writer import ChatWriter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
"""Classes for mocking Writer responses."""


# The mock response classes are plain (non-slotted) dataclasses: the generated
# __init__ keeps the same positional/keyword signature, and every instance still
# carries a __dict__ holding exactly its fields.
@dataclass
class ChoiceDelta:
    """Content fragment carried by one streaming chunk choice."""

    content: str
@dataclass
class ChunkChoice:
    """One choice entry of a streaming ``ChatCompletionChunk``."""

    index: int
    finish_reason: str
    delta: ChoiceDelta
@dataclass
class ChatCompletionChunk:
    """A single chunk yielded by a mocked streaming chat call.

    Field names (including ``id`` and ``object``) intentionally mirror the
    response attributes the code under test reads, even though they shadow
    builtins.
    """

    id: str
    object: str
    created: int
    model: str
    choices: List[ChunkChoice]
@dataclass
class ToolCallFunction:
    """Function name and JSON-encoded arguments of a requested tool call."""

    name: str
    arguments: str
@dataclass
class ChoiceMessageToolCall:
    """A tool call attached to an assistant ``ChoiceMessage``."""

    id: str
    type: str
    function: ToolCallFunction
@dataclass
class Usage:
    """Token counts attached to a non-streaming ``Chat`` completion."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
@dataclass
class ChoiceMessage:
    """Message payload of a non-streaming choice.

    ``tool_calls`` defaults to ``None`` for plain-text replies.
    """

    role: str
    content: str
    tool_calls: Optional[List[ChoiceMessageToolCall]] = None
@dataclass
class Choice:
    """One choice entry of a non-streaming ``Chat`` completion."""

    index: int
    finish_reason: str
    message: ChoiceMessage
@dataclass
class Chat:
    """Non-streaming chat completion returned by the mocked client.

    Field order matches the original constructor: id, object, created,
    system_fingerprint, model, usage, choices.
    """

    id: str
    object: str
    created: int
    system_fingerprint: str
    model: str
    usage: Usage
    choices: List[Choice]
@pytest.mark.requires("writerai")
class TestChatWriterCustom:
    """Unit tests for ``ChatWriter`` driven by hand-built mock responses.

    The model's ``client`` / ``async_client`` attributes are patched with mocks,
    so no real Writer API request is made.
    """

    # The fixtures below only build data and are requested explicitly by the
    # tests that use them, so they are deliberately NOT ``autouse``: an autouse
    # fixture would be instantiated for every test in the class.

    @pytest.fixture
    def mock_unstreaming_completion(self) -> Chat:
        """Non-streaming completion with a single plain-text assistant reply."""
        return Chat(
            id="chat-12345",
            object="chat.completion",
            created=1699000000,
            model="palmyra-x-004",
            system_fingerprint="v1",
            usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18),
            choices=[
                Choice(
                    index=0,
                    finish_reason="stop",
                    message=ChoiceMessage(
                        role="assistant",
                        content="Hello! How can I help you?",
                    ),
                )
            ],
        )

    @pytest.fixture
    def mock_tool_call_choice_response(self) -> Chat:
        """Completion whose single choice requests a ``GetWeather`` tool call."""
        return Chat(
            id="chat-12345",
            object="chat.completion",
            created=1699000000,
            model="palmyra-x-004",
            system_fingerprint="v1",
            usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61),
            choices=[
                Choice(
                    index=0,
                    finish_reason="tool_calls",
                    message=ChoiceMessage(
                        role="assistant",
                        content="",
                        tool_calls=[
                            ChoiceMessageToolCall(
                                id="call_abc123",
                                type="function",
                                function=ToolCallFunction(
                                    name="GetWeather",
                                    arguments='{"location": "London"}',
                                ),
                            )
                        ],
                    ),
                )
            ],
        )

    @pytest.fixture
    def mock_streaming_chunks(self) -> List[ChatCompletionChunk]:
        """Two streaming chunks whose contents concatenate to the full reply."""
        return [
            ChatCompletionChunk(
                id="chat-12345",
                object="chat.completion",
                created=1699000000,
                model="palmyra-x-004",
                choices=[
                    ChunkChoice(
                        index=0,
                        finish_reason="stop",
                        delta=ChoiceDelta(content="Hello! "),
                    )
                ],
            ),
            ChatCompletionChunk(
                id="chat-12345",
                object="chat.completion",
                created=1699000000,
                model="palmyra-x-004",
                choices=[
                    ChunkChoice(
                        index=0,
                        finish_reason="stop",
                        delta=ChoiceDelta(content="How can I help you?"),
                    )
                ],
            ),
        ]

    def test_writer_model_param(self) -> None:
        """Test different ways to initialize the chat model.

        Covers ``model_name`` and the ``model`` keyword, plain-string and
        ``SecretStr`` API keys, and an explicit ``temperature`` override.
        """
        test_cases: List[Dict[str, Any]] = [
            {"model_name": "palmyra-x-004", "api_key": "key"},
            {"model": "palmyra-x-004", "api_key": "key"},
            # Uses a SecretStr key so this case is distinct from the first one.
            {"model_name": "palmyra-x-004", "api_key": SecretStr("key")},
            {"model": "palmyra-x-004", "temperature": 0.5, "api_key": "key"},
        ]
        for case in test_cases:
            chat = ChatWriter(**case)
            assert chat.model_name == "palmyra-x-004"
            assert chat.temperature == (0.5 if "temperature" in case else 0.7)

    def test_convert_writer_to_langchain_human(self) -> None:
        """Test converting a human message dict to a LangChain message."""
        message = {"role": "user", "content": "Hello"}
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, HumanMessage)
        assert result.content == "Hello"

    def test_convert_writer_to_langchain_ai(self) -> None:
        """Test converting an AI message dict to a LangChain message."""
        message = {"role": "assistant", "content": "Hello"}
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, AIMessage)
        assert result.content == "Hello"

    def test_convert_writer_to_langchain_system(self) -> None:
        """Test converting a system message dict to a LangChain message."""
        message = {"role": "system", "content": "You are a helpful assistant"}
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, SystemMessage)
        assert result.content == "You are a helpful assistant"

    def test_convert_writer_to_langchain_tool_call(self) -> None:
        """Test converting a tool result message dict to a LangChain message."""
        content = json.dumps({"result": 42})
        message = {
            "role": "tool",
            "name": "get_number",
            "content": content,
            "tool_call_id": "call_abc123",
        }
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, ToolMessage)
        assert result.name == "get_number"
        assert result.content == content

    def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
        """Test converting an assistant message dict that carries tool calls."""
        message = {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "London"}',
                    },
                }
            ],
        }
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, AIMessage)
        assert result.tool_calls
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0]["name"] == "get_weather"
        assert result.tool_calls[0]["args"]["location"] == "London"

    def test_sync_completion(self, mock_unstreaming_completion: Chat) -> None:
        """Test basic chat completion with mocked response."""
        chat = ChatWriter(api_key=SecretStr("key"))
        mock_client = MagicMock()
        mock_client.chat.chat.return_value = mock_unstreaming_completion
        with mock.patch.object(chat, "client", mock_client):
            message = HumanMessage(content="Hi there!")
            response = chat.invoke([message])
            assert isinstance(response, AIMessage)
            assert response.content == "Hello! How can I help you?"

    @pytest.mark.asyncio
    async def test_async_completion(self, mock_unstreaming_completion: Chat) -> None:
        """Test async chat completion with mocked response."""
        chat = ChatWriter(api_key=SecretStr("key"))
        mock_async_client = AsyncMock()
        mock_async_client.chat.chat.return_value = mock_unstreaming_completion
        with mock.patch.object(chat, "async_client", mock_async_client):
            message = HumanMessage(content="Hi there!")
            response = await chat.ainvoke([message])
            assert isinstance(response, AIMessage)
            assert response.content == "Hello! How can I help you?"

    def test_sync_streaming(
        self, mock_streaming_chunks: List[ChatCompletionChunk]
    ) -> None:
        """Test sync streaming with callback handler."""
        callback_handler = FakeCallbackHandler()
        callback_manager = CallbackManager([callback_handler])
        # ``callback_manager=`` is a deprecated alias on chat models; passing the
        # manager through ``callbacks=`` is equivalent and avoids the warning.
        chat = ChatWriter(
            api_key=SecretStr("key"),
            callbacks=callback_manager,
            max_tokens=10,
        )
        mock_client = MagicMock()
        mock_response = MagicMock()
        # Iterating the mocked streaming response yields the prepared chunks.
        mock_response.__iter__.return_value = mock_streaming_chunks
        mock_client.chat.chat.return_value = mock_response
        with mock.patch.object(chat, "client", mock_client):
            message = HumanMessage(content="Hi")
            response = chat.stream([message])
            response_message = ""
            for chunk in response:
                response_message += str(chunk.content)
            assert callback_handler.llm_streams > 0
            assert response_message == "Hello! How can I help you?"

    @pytest.mark.asyncio
    async def test_async_streaming(
        self, mock_streaming_chunks: List[ChatCompletionChunk]
    ) -> None:
        """Test async streaming with callback handler."""
        callback_handler = FakeCallbackHandler()
        callback_manager = CallbackManager([callback_handler])
        # See test_sync_streaming: ``callbacks=`` replaces the deprecated
        # ``callback_manager=`` alias with identical behavior.
        chat = ChatWriter(
            api_key=SecretStr("key"),
            callbacks=callback_manager,
            max_tokens=10,
        )
        mock_async_client = AsyncMock()
        mock_response = AsyncMock()
        # ``async for`` over the mocked streaming response yields the chunks.
        mock_response.__aiter__.return_value = mock_streaming_chunks
        mock_async_client.chat.chat.return_value = mock_response
        with mock.patch.object(chat, "async_client", mock_async_client):
            message = HumanMessage(content="Hi")
            response = chat.astream([message])
            response_message = ""
            async for chunk in response:
                response_message += str(chunk.content)
            assert callback_handler.llm_streams > 0
            assert response_message == "Hello! How can I help you?"

    def test_sync_tool_calling(self, mock_tool_call_choice_response: Chat) -> None:
        """Test synchronous tool calling functionality."""
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            """Get the weather in a location."""

            location: str = Field(..., description="The location to get weather for")

        chat = ChatWriter(api_key=SecretStr("key"))
        mock_client = MagicMock()
        mock_client.chat.chat.return_value = mock_tool_call_choice_response
        chat_with_tools = chat.bind_tools(
            tools=[GetWeather],
            tool_choice="GetWeather",
        )
        # The bound runnable wraps ``chat``, so patching ``chat.client`` also
        # affects ``chat_with_tools``.
        with mock.patch.object(chat, "client", mock_client):
            response = chat_with_tools.invoke("What's the weather in London?")
            assert isinstance(response, AIMessage)
            assert response.tool_calls
            assert response.tool_calls[0]["name"] == "GetWeather"
            assert response.tool_calls[0]["args"]["location"] == "London"

    @pytest.mark.asyncio
    async def test_async_tool_calling(
        self, mock_tool_call_choice_response: Chat
    ) -> None:
        """Test asynchronous tool calling functionality."""
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            """Get the weather in a location."""

            location: str = Field(..., description="The location to get weather for")

        mock_async_client = AsyncMock()
        mock_async_client.chat.chat.return_value = mock_tool_call_choice_response
        chat = ChatWriter(api_key=SecretStr("key"))
        chat_with_tools = chat.bind_tools(
            tools=[GetWeather],
            tool_choice="GetWeather",
        )
        with mock.patch.object(chat, "async_client", mock_async_client):
            response = await chat_with_tools.ainvoke("What's the weather in London?")
            assert isinstance(response, AIMessage)
            assert response.tool_calls
            assert response.tool_calls[0]["name"] == "GetWeather"
            assert response.tool_calls[0]["args"]["location"] == "London"
@pytest.mark.requires("writerai")
class TestChatWriterStandart(ChatModelUnitTests):
    """Run LangChain's shared ``ChatModelUnitTests`` suite against ChatWriter.

    The properties below tell the shared suite how to construct the model and
    which optional capabilities it should exercise or skip.
    """

    # --- How to build the model -----------------------------------------

    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        """The chat model class under test."""
        return ChatWriter

    @property
    def chat_model_params(self) -> Dict:
        """Keyword arguments passed when the suite instantiates the model."""
        params: Dict = {
            "api_key": "fake-api-key",
            "model_name": "palmyra-x-004",
        }
        return params

    @property
    def init_from_env_params(self) -> Tuple[dict, dict, dict]:
        """Environment variables, init kwargs, and expected attributes used
        when testing initialization from the environment."""
        env_vars = {"WRITER_API_KEY": "key"}
        init_kwargs = {"api_key": "key"}
        expected_attrs = {"api_key": "key"}
        return env_vars, init_kwargs, expected_attrs

    # --- Declared capabilities ------------------------------------------

    @property
    def has_tool_calling(self) -> bool:
        """Tool/function calling is declared as supported."""
        return True

    @property
    def tool_choice_value(self) -> Optional[str]:
        """``tool_choice`` value the suite uses when binding tools."""
        return "auto"

    @property
    def has_structured_output(self) -> bool:
        """Structured output is declared as not yet supported."""
        return False

    @property
    def returns_usage_metadata(self) -> bool:
        """Token usage information is declared as returned."""
        return True

    @property
    def supported_usage_metadata_details(
        self,
    ) -> Dict[
        Literal["invoke", "stream"],
        List[
            Literal[
                "audio_input",
                "audio_output",
                "reasoning_output",
                "cache_read_input",
                "cache_creation_input",
            ]
        ],
    ]:
        """Usage-metadata detail categories declared per call mode."""
        # NOTE(review): these categories look carried over from a template;
        # confirm they match what the Writer API actually reports.
        return {
            "invoke": ["cache_creation_input"],
            "stream": ["reasoning_output"],
        }

    # --- Unsupported input kinds ----------------------------------------

    @property
    def supports_image_inputs(self) -> bool:
        """Image inputs are declared as unsupported."""
        return False

    @property
    def supports_video_inputs(self) -> bool:
        """Video inputs are declared as unsupported."""
        return False

    @property
    def supports_anthropic_inputs(self) -> bool:
        """Anthropic-style inputs are declared as unsupported."""
        return False

    @property
    def supports_image_tool_message(self) -> bool:
        """Image content in tool messages is declared as unsupported."""
        return False