mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
Library name was updated after https://github.com/langchain-ai/langchain/pull/27879 branched off master.
509 lines
17 KiB
Python
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest import mock
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_tests.unit_tests import ChatModelUnitTests
from pydantic import SecretStr

from langchain_community.chat_models.writer import ChatWriter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
|
|
|
"""Classes for mocking Writer responses."""
|
|
|
|
|
|
@dataclass
class ChoiceDelta:
    """Mock of the incremental content payload carried by one streamed chunk."""

    content: str
|
|
|
|
|
|
@dataclass
class ChunkChoice:
    """Mock of a single choice within a streaming chat-completion chunk."""

    index: int
    finish_reason: str
    delta: "ChoiceDelta"
|
|
|
|
|
|
@dataclass
class ChatCompletionChunk:
    """Mock of one streamed chat-completion chunk returned by the Writer API."""

    # `id` and `object` intentionally shadow builtins to mirror the SDK schema.
    id: str
    object: str
    created: int
    model: str
    choices: List["ChunkChoice"]
|
|
|
|
|
|
@dataclass
class ToolCallFunction:
    """Mock of a tool-call function payload: the tool name plus its
    JSON-encoded arguments string."""

    name: str
    arguments: str
|
|
|
|
|
|
@dataclass
class ChoiceMessageToolCall:
    """Mock of a single tool call attached to an assistant message."""

    # `id` intentionally shadows the builtin to mirror the SDK schema.
    id: str
    type: str
    function: "ToolCallFunction"
|
|
|
|
|
|
@dataclass
class Usage:
    """Mock of the token-usage block of a Writer chat response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
|
|
|
|
|
@dataclass
class ChoiceMessage:
    """Mock of the message object inside a (non-streaming) completion choice."""

    role: str
    content: str
    # None when the assistant reply carries no tool calls.
    tool_calls: Optional[List["ChoiceMessageToolCall"]] = None
|
|
|
|
|
|
@dataclass
class Choice:
    """Mock of a single choice of a non-streaming chat completion."""

    index: int
    finish_reason: str
    message: "ChoiceMessage"
|
|
|
|
|
|
@dataclass
class Chat:
    """Mock of a complete (non-streaming) Writer chat-completion response."""

    # `id` and `object` intentionally shadow builtins to mirror the SDK schema.
    id: str
    object: str
    created: int
    system_fingerprint: str
    model: str
    usage: "Usage"
    choices: List["Choice"]
|
|
|
|
|
|
@pytest.mark.requires("writerai")
class TestChatWriterCustom:
    """Test case for ChatWriter.

    Every test patches ``chat.client`` / ``chat.async_client`` with a mock, so
    no network access or real API key is needed.
    """

    @pytest.fixture(autouse=True)
    def mock_unstreaming_completion(self) -> Chat:
        """Fixture providing a mock API response."""
        return Chat(
            id="chat-12345",
            object="chat.completion",
            created=1699000000,
            model="palmyra-x-004",
            system_fingerprint="v1",
            usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18),
            choices=[
                Choice(
                    index=0,
                    finish_reason="stop",
                    message=ChoiceMessage(
                        role="assistant",
                        content="Hello! How can I help you?",
                    ),
                )
            ],
        )

    @pytest.fixture(autouse=True)
    def mock_tool_call_choice_response(self) -> Chat:
        """Fixture providing a mock response whose single choice finishes with
        a ``GetWeather`` tool call instead of plain content."""
        return Chat(
            id="chat-12345",
            object="chat.completion",
            created=1699000000,
            model="palmyra-x-004",
            system_fingerprint="v1",
            usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61),
            choices=[
                Choice(
                    index=0,
                    finish_reason="tool_calls",
                    message=ChoiceMessage(
                        role="assistant",
                        content="",
                        tool_calls=[
                            ChoiceMessageToolCall(
                                id="call_abc123",
                                type="function",
                                function=ToolCallFunction(
                                    name="GetWeather",
                                    arguments='{"location": "London"}',
                                ),
                            )
                        ],
                    ),
                )
            ],
        )

    @pytest.fixture(autouse=True)
    def mock_streaming_chunks(self) -> List[ChatCompletionChunk]:
        """Fixture providing mock streaming response chunks."""
        # Two chunks whose concatenated deltas spell the full reply used in
        # the streaming assertions below.
        return [
            ChatCompletionChunk(
                id="chat-12345",
                object="chat.completion",
                created=1699000000,
                model="palmyra-x-004",
                choices=[
                    ChunkChoice(
                        index=0,
                        finish_reason="stop",
                        delta=ChoiceDelta(content="Hello! "),
                    )
                ],
            ),
            ChatCompletionChunk(
                id="chat-12345",
                object="chat.completion",
                created=1699000000,
                model="palmyra-x-004",
                choices=[
                    ChunkChoice(
                        index=0,
                        finish_reason="stop",
                        delta=ChoiceDelta(content="How can I help you?"),
                    )
                ],
            ),
        ]

    def test_writer_model_param(self) -> None:
        """Test different ways to initialize the chat model."""
        test_cases: List[dict] = [
            {
                "model_name": "palmyra-x-004",
                "api_key": "key",
            },
            {
                "model": "palmyra-x-004",
                "api_key": "key",
            },
            # NOTE(review): identical to the first case — presumably meant to
            # cover another alias spelling; confirm or drop the duplicate.
            {
                "model_name": "palmyra-x-004",
                "api_key": "key",
            },
            {
                "model": "palmyra-x-004",
                "temperature": 0.5,
                "api_key": "key",
            },
        ]

        for case in test_cases:
            chat = ChatWriter(**case)
            assert chat.model_name == "palmyra-x-004"
            # 0.7 is the expected default when temperature is not supplied.
            assert chat.temperature == (0.5 if "temperature" in case else 0.7)

    def test_convert_writer_to_langchain_human(self) -> None:
        """Test converting a human message dict to a LangChain message."""
        message = {"role": "user", "content": "Hello"}
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, HumanMessage)
        assert result.content == "Hello"

    def test_convert_writer_to_langchain_ai(self) -> None:
        """Test converting an AI message dict to a LangChain message."""
        message = {"role": "assistant", "content": "Hello"}
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, AIMessage)
        assert result.content == "Hello"

    def test_convert_writer_to_langchain_system(self) -> None:
        """Test converting a system message dict to a LangChain message."""
        message = {"role": "system", "content": "You are a helpful assistant"}
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, SystemMessage)
        assert result.content == "You are a helpful assistant"

    def test_convert_writer_to_langchain_tool_call(self) -> None:
        """Test converting a tool call message dict to a LangChain message."""
        content = json.dumps({"result": 42})
        message = {
            "role": "tool",
            "name": "get_number",
            "content": content,
            "tool_call_id": "call_abc123",
        }
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, ToolMessage)
        assert result.name == "get_number"
        assert result.content == content

    def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
        """Test converting an AIMessage with tool calls."""
        message = {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "London"}',
                    },
                }
            ],
        }
        result = ChatWriter._convert_writer_to_langchain(message)
        assert isinstance(result, AIMessage)
        assert result.tool_calls
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0]["name"] == "get_weather"
        # JSON arguments string is expected to be parsed into a dict.
        assert result.tool_calls[0]["args"]["location"] == "London"

    def test_sync_completion(
        self, mock_unstreaming_completion: Chat
    ) -> None:
        """Test basic chat completion with mocked response."""
        chat = ChatWriter(api_key=SecretStr("key"))

        mock_client = MagicMock()
        mock_client.chat.chat.return_value = mock_unstreaming_completion

        with mock.patch.object(chat, "client", mock_client):
            message = HumanMessage(content="Hi there!")
            response = chat.invoke([message])
            assert isinstance(response, AIMessage)
            assert response.content == "Hello! How can I help you?"

    @pytest.mark.asyncio
    async def test_async_completion(
        self, mock_unstreaming_completion: Chat
    ) -> None:
        """Test async chat completion with mocked response."""
        chat = ChatWriter(api_key=SecretStr("key"))

        mock_async_client = AsyncMock()
        mock_async_client.chat.chat.return_value = mock_unstreaming_completion

        with mock.patch.object(chat, "async_client", mock_async_client):
            message = HumanMessage(content="Hi there!")
            response = await chat.ainvoke([message])
            assert isinstance(response, AIMessage)
            assert response.content == "Hello! How can I help you?"

    def test_sync_streaming(
        self, mock_streaming_chunks: List[ChatCompletionChunk]
    ) -> None:
        """Test sync streaming with callback handler."""
        callback_handler = FakeCallbackHandler()
        callback_manager = CallbackManager([callback_handler])

        chat = ChatWriter(
            api_key=SecretStr("key"),
            callback_manager=callback_manager,
            max_tokens=10,
        )

        # The mocked client yields the canned chunks when iterated.
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.__iter__.return_value = mock_streaming_chunks
        mock_client.chat.chat.return_value = mock_response

        with mock.patch.object(chat, "client", mock_client):
            message = HumanMessage(content="Hi")
            response = chat.stream([message])
            response_message = ""
            for chunk in response:
                response_message += str(chunk.content)
            # Callback must have observed at least one streamed token.
            assert callback_handler.llm_streams > 0
            assert response_message == "Hello! How can I help you?"

    @pytest.mark.asyncio
    async def test_async_streaming(
        self, mock_streaming_chunks: List[ChatCompletionChunk]
    ) -> None:
        """Test async streaming with callback handler."""
        callback_handler = FakeCallbackHandler()
        callback_manager = CallbackManager([callback_handler])

        chat = ChatWriter(
            api_key=SecretStr("key"),
            callback_manager=callback_manager,
            max_tokens=10,
        )

        # The mocked async client yields the canned chunks when async-iterated.
        mock_async_client = AsyncMock()
        mock_response = AsyncMock()
        mock_response.__aiter__.return_value = mock_streaming_chunks
        mock_async_client.chat.chat.return_value = mock_response

        with mock.patch.object(chat, "async_client", mock_async_client):
            message = HumanMessage(content="Hi")
            response = chat.astream([message])
            response_message = ""
            async for chunk in response:
                response_message += str(chunk.content)
            assert callback_handler.llm_streams > 0
            assert response_message == "Hello! How can I help you?"

    def test_sync_tool_calling(
        self, mock_tool_call_choice_response: Chat
    ) -> None:
        """Test synchronous tool calling functionality."""
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            """Get the weather in a location."""

            location: str = Field(..., description="The location to get weather for")

        chat = ChatWriter(api_key=SecretStr("key"))

        mock_client = MagicMock()
        mock_client.chat.chat.return_value = mock_tool_call_choice_response

        chat_with_tools = chat.bind_tools(
            tools=[GetWeather],
            tool_choice="GetWeather",
        )

        # Patch the underlying `chat` instance; `chat_with_tools` delegates to it.
        with mock.patch.object(chat, "client", mock_client):
            response = chat_with_tools.invoke("What's the weather in London?")
            assert isinstance(response, AIMessage)
            assert response.tool_calls
            assert response.tool_calls[0]["name"] == "GetWeather"
            assert response.tool_calls[0]["args"]["location"] == "London"

    @pytest.mark.asyncio
    async def test_async_tool_calling(
        self, mock_tool_call_choice_response: Chat
    ) -> None:
        """Test asynchronous tool calling functionality."""
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            """Get the weather in a location."""

            location: str = Field(..., description="The location to get weather for")

        mock_async_client = AsyncMock()
        mock_async_client.chat.chat.return_value = mock_tool_call_choice_response

        chat = ChatWriter(api_key=SecretStr("key"))

        chat_with_tools = chat.bind_tools(
            tools=[GetWeather],
            tool_choice="GetWeather",
        )

        # Patch the underlying `chat` instance; `chat_with_tools` delegates to it.
        with mock.patch.object(chat, "async_client", mock_async_client):
            response = await chat_with_tools.ainvoke("What's the weather in London?")
            assert isinstance(response, AIMessage)
            assert response.tool_calls
            assert response.tool_calls[0]["name"] == "GetWeather"
            assert response.tool_calls[0]["args"]["location"] == "London"
|
|
|
|
|
|
@pytest.mark.requires("writerai")
class TestChatWriterStandart(ChatModelUnitTests):
    """Test case for ChatWriter that inherits from standard LangChain tests.

    The properties below only configure the inherited ``ChatModelUnitTests``
    suite; the actual test methods live in the base class.
    """

    # NOTE(review): class name misspells "Standard"; kept as-is since pytest
    # collects the class by this name.

    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        """Return ChatWriter model class."""
        return ChatWriter

    @property
    def chat_model_params(self) -> Dict:
        """Return any additional parameters needed."""
        return {
            "api_key": "fake-api-key",
            "model_name": "palmyra-x-004",
        }

    @property
    def has_tool_calling(self) -> bool:
        """Writer supports tool/function calling."""
        return True

    @property
    def tool_choice_value(self) -> Optional[str]:
        """Value to use for tool choice in tests."""
        return "auto"

    @property
    def has_structured_output(self) -> bool:
        """Writer does not yet support structured output."""
        return False

    @property
    def supports_image_inputs(self) -> bool:
        """Writer does not support image inputs."""
        return False

    @property
    def supports_video_inputs(self) -> bool:
        """Writer does not support video inputs."""
        return False

    @property
    def returns_usage_metadata(self) -> bool:
        """Writer returns token usage information."""
        return True

    @property
    def supports_anthropic_inputs(self) -> bool:
        """Writer does not support anthropic inputs."""
        return False

    @property
    def supports_image_tool_message(self) -> bool:
        """Writer does not support image tool message."""
        return False

    @property
    def supported_usage_metadata_details(
        self,
    ) -> Dict[
        Literal["invoke", "stream"],
        List[
            Literal[
                "audio_input",
                "audio_output",
                "reasoning_output",
                "cache_read_input",
                "cache_creation_input",
            ]
        ],
    ]:
        """Return which types of usage metadata your model supports."""
        return {"invoke": ["cache_creation_input"], "stream": ["reasoning_output"]}

    @property
    def init_from_env_params(self) -> Tuple[dict, dict, dict]:
        """Return env vars, init args, and expected instance attrs for initializing
        from env vars."""
        return {"WRITER_API_KEY": "key"}, {"api_key": "key"}, {"api_key": "key"}
|