mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
community: fixed critical bugs at Writer provider (#27879)
This commit is contained in:
@@ -1,10 +0,0 @@
|
||||
"""Test Writer API wrapper."""
|
||||
|
||||
from langchain_community.llms.writer import Writer
|
||||
|
||||
|
||||
def test_writer_call() -> None:
|
||||
"""Test valid call to Writer."""
|
||||
llm = Writer()
|
||||
output = llm.invoke("Say foo:")
|
||||
assert isinstance(output, str)
|
@@ -1,61 +1,251 @@
|
||||
"""Unit tests for Writer chat model integration."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_community.chat_models.writer import ChatWriter, _convert_dict_to_message
|
||||
from langchain_community.chat_models.writer import ChatWriter
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
"""Classes for mocking Writer responses."""
|
||||
|
||||
|
||||
class ChoiceDelta:
|
||||
def __init__(self, content: str):
|
||||
self.content = content
|
||||
|
||||
|
||||
class ChunkChoice:
|
||||
def __init__(self, index: int, finish_reason: str, delta: ChoiceDelta):
|
||||
self.index = index
|
||||
self.finish_reason = finish_reason
|
||||
self.delta = delta
|
||||
|
||||
|
||||
class ChatCompletionChunk:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
object: str,
|
||||
created: int,
|
||||
model: str,
|
||||
choices: List[ChunkChoice],
|
||||
):
|
||||
self.id = id
|
||||
self.object = object
|
||||
self.created = created
|
||||
self.model = model
|
||||
self.choices = choices
|
||||
|
||||
|
||||
class ToolCallFunction:
|
||||
def __init__(self, name: str, arguments: str):
|
||||
self.name = name
|
||||
self.arguments = arguments
|
||||
|
||||
|
||||
class ChoiceMessageToolCall:
|
||||
def __init__(self, id: str, type: str, function: ToolCallFunction):
|
||||
self.id = id
|
||||
self.type = type
|
||||
self.function = function
|
||||
|
||||
|
||||
class Usage:
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
):
|
||||
self.prompt_tokens = prompt_tokens
|
||||
self.completion_tokens = completion_tokens
|
||||
self.total_tokens = total_tokens
|
||||
|
||||
|
||||
class ChoiceMessage:
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
content: str,
|
||||
tool_calls: Optional[List[ChoiceMessageToolCall]] = None,
|
||||
):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.tool_calls = tool_calls
|
||||
|
||||
|
||||
class Choice:
|
||||
def __init__(self, index: int, finish_reason: str, message: ChoiceMessage):
|
||||
self.index = index
|
||||
self.finish_reason = finish_reason
|
||||
self.message = message
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
object: str,
|
||||
created: int,
|
||||
system_fingerprint: str,
|
||||
model: str,
|
||||
usage: Usage,
|
||||
choices: List[Choice],
|
||||
):
|
||||
self.id = id
|
||||
self.object = object
|
||||
self.created = created
|
||||
self.system_fingerprint = system_fingerprint
|
||||
self.model = model
|
||||
self.usage = usage
|
||||
self.choices = choices
|
||||
|
||||
|
||||
@pytest.mark.requires("writerai")
|
||||
class TestChatWriterCustom:
|
||||
"""Test case for ChatWriter"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_unstreaming_completion(self) -> Chat:
|
||||
"""Fixture providing a mock API response."""
|
||||
return Chat(
|
||||
id="chat-12345",
|
||||
object="chat.completion",
|
||||
created=1699000000,
|
||||
model="palmyra-x-004",
|
||||
system_fingerprint="v1",
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18),
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
message=ChoiceMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tool_call_choice_response(self) -> Chat:
|
||||
return Chat(
|
||||
id="chat-12345",
|
||||
object="chat.completion",
|
||||
created=1699000000,
|
||||
model="palmyra-x-004",
|
||||
system_fingerprint="v1",
|
||||
usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61),
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
finish_reason="tool_calls",
|
||||
message=ChoiceMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChoiceMessageToolCall(
|
||||
id="call_abc123",
|
||||
type="function",
|
||||
function=ToolCallFunction(
|
||||
name="GetWeather",
|
||||
arguments='{"location": "London"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_streaming_chunks(self) -> List[ChatCompletionChunk]:
|
||||
"""Fixture providing mock streaming response chunks."""
|
||||
return [
|
||||
ChatCompletionChunk(
|
||||
id="chat-12345",
|
||||
object="chat.completion",
|
||||
created=1699000000,
|
||||
model="palmyra-x-004",
|
||||
choices=[
|
||||
ChunkChoice(
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
delta=ChoiceDelta(content="Hello! "),
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatCompletionChunk(
|
||||
id="chat-12345",
|
||||
object="chat.completion",
|
||||
created=1699000000,
|
||||
model="palmyra-x-004",
|
||||
choices=[
|
||||
ChunkChoice(
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
delta=ChoiceDelta(content="How can I help you?"),
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
class TestChatWriter:
|
||||
def test_writer_model_param(self) -> None:
|
||||
"""Test different ways to initialize the chat model."""
|
||||
test_cases: List[dict] = [
|
||||
{"model_name": "palmyra-x-004", "writer_api_key": "test-key"},
|
||||
{"model": "palmyra-x-004", "writer_api_key": "test-key"},
|
||||
{"model_name": "palmyra-x-004", "writer_api_key": "test-key"},
|
||||
{
|
||||
"model_name": "palmyra-x-004",
|
||||
"api_key": "key",
|
||||
},
|
||||
{
|
||||
"model": "palmyra-x-004",
|
||||
"api_key": "key",
|
||||
},
|
||||
{
|
||||
"model_name": "palmyra-x-004",
|
||||
"api_key": "key",
|
||||
},
|
||||
{
|
||||
"model": "palmyra-x-004",
|
||||
"writer_api_key": "test-key",
|
||||
"temperature": 0.5,
|
||||
"api_key": "key",
|
||||
},
|
||||
]
|
||||
|
||||
for case in test_cases:
|
||||
chat = ChatWriter(**case)
|
||||
assert chat.model_name == "palmyra-x-004"
|
||||
assert chat.writer_api_key
|
||||
assert chat.writer_api_key.get_secret_value() == "test-key"
|
||||
assert chat.temperature == (0.5 if "temperature" in case else 0.7)
|
||||
|
||||
def test_convert_dict_to_message_human(self) -> None:
|
||||
def test_convert_writer_to_langchain_human(self) -> None:
|
||||
"""Test converting a human message dict to a LangChain message."""
|
||||
message = {"role": "user", "content": "Hello"}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = ChatWriter._convert_writer_to_langchain(message)
|
||||
assert isinstance(result, HumanMessage)
|
||||
assert result.content == "Hello"
|
||||
|
||||
def test_convert_dict_to_message_ai(self) -> None:
|
||||
def test_convert_writer_to_langchain_ai(self) -> None:
|
||||
"""Test converting an AI message dict to a LangChain message."""
|
||||
message = {"role": "assistant", "content": "Hello"}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = ChatWriter._convert_writer_to_langchain(message)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content == "Hello"
|
||||
|
||||
def test_convert_dict_to_message_system(self) -> None:
|
||||
def test_convert_writer_to_langchain_system(self) -> None:
|
||||
"""Test converting a system message dict to a LangChain message."""
|
||||
message = {"role": "system", "content": "You are a helpful assistant"}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = ChatWriter._convert_writer_to_langchain(message)
|
||||
assert isinstance(result, SystemMessage)
|
||||
assert result.content == "You are a helpful assistant"
|
||||
|
||||
def test_convert_dict_to_message_tool_call(self) -> None:
|
||||
def test_convert_writer_to_langchain_tool_call(self) -> None:
|
||||
"""Test converting a tool call message dict to a LangChain message."""
|
||||
content = json.dumps({"result": 42})
|
||||
message = {
|
||||
@@ -64,12 +254,12 @@ class TestChatWriter:
|
||||
"content": content,
|
||||
"tool_call_id": "call_abc123",
|
||||
}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = ChatWriter._convert_writer_to_langchain(message)
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.name == "get_number"
|
||||
assert result.content == content
|
||||
|
||||
def test_convert_dict_to_message_with_tool_calls(self) -> None:
|
||||
def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
|
||||
"""Test converting an AIMessage with tool calls."""
|
||||
message = {
|
||||
"role": "assistant",
|
||||
@@ -85,131 +275,55 @@ class TestChatWriter:
|
||||
}
|
||||
],
|
||||
}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = ChatWriter._convert_writer_to_langchain(message)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.tool_calls
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0]["name"] == "get_weather"
|
||||
assert result.tool_calls[0]["args"]["location"] == "London"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_completion(self) -> Dict[str, Any]:
|
||||
"""Fixture providing a mock API response."""
|
||||
return {
|
||||
"id": "chat-12345",
|
||||
"object": "chat.completion",
|
||||
"created": 1699000000,
|
||||
"model": "palmyra-x-004",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you?",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
||||
}
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_response(self) -> Dict[str, Any]:
|
||||
response = {
|
||||
"id": "chat-12345",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "GetWeather",
|
||||
"arguments": '{"location": "London"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
}
|
||||
return response
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_streaming_chunks(self) -> List[Dict[str, Any]]:
|
||||
"""Fixture providing mock streaming response chunks."""
|
||||
return [
|
||||
{
|
||||
"id": "chat-12345",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1699000000,
|
||||
"model": "palmyra-x-004",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"content": "Hello",
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "chat-12345",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1699000000,
|
||||
"model": "palmyra-x-004",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "!",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def test_sync_completion(self, mock_completion: Dict[str, Any]) -> None:
|
||||
def test_sync_completion(
|
||||
self, mock_unstreaming_completion: List[ChatCompletionChunk]
|
||||
) -> None:
|
||||
"""Test basic chat completion with mocked response."""
|
||||
chat = ChatWriter(api_key=SecretStr("test-key"))
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.chat.return_value = mock_completion
|
||||
chat = ChatWriter(api_key=SecretStr("key"))
|
||||
|
||||
with patch.object(chat, "client", mock_client):
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.chat.return_value = mock_unstreaming_completion
|
||||
|
||||
with mock.patch.object(chat, "client", mock_client):
|
||||
message = HumanMessage(content="Hi there!")
|
||||
response = chat.invoke([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
assert response.content == "Hello! How can I help you?"
|
||||
|
||||
async def test_async_completion(self, mock_completion: Dict[str, Any]) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_completion(
|
||||
self, mock_unstreaming_completion: List[ChatCompletionChunk]
|
||||
) -> None:
|
||||
"""Test async chat completion with mocked response."""
|
||||
chat = ChatWriter(api_key=SecretStr("test-key"))
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.chat.return_value = mock_completion
|
||||
chat = ChatWriter(api_key=SecretStr("key"))
|
||||
|
||||
with patch.object(chat, "async_client", mock_client):
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.chat.chat.return_value = mock_unstreaming_completion
|
||||
|
||||
with mock.patch.object(chat, "async_client", mock_async_client):
|
||||
message = HumanMessage(content="Hi there!")
|
||||
response = await chat.ainvoke([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
assert response.content == "Hello! How can I help you?"
|
||||
|
||||
def test_sync_streaming(self, mock_streaming_chunks: List[Dict[str, Any]]) -> None:
|
||||
def test_sync_streaming(
|
||||
self, mock_streaming_chunks: List[ChatCompletionChunk]
|
||||
) -> None:
|
||||
"""Test sync streaming with callback handler."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
chat = ChatWriter(
|
||||
streaming=True,
|
||||
api_key=SecretStr("key"),
|
||||
callback_manager=callback_manager,
|
||||
max_tokens=10,
|
||||
api_key=SecretStr("test-key"),
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
@@ -217,42 +331,46 @@ class TestChatWriter:
|
||||
mock_response.__iter__.return_value = mock_streaming_chunks
|
||||
mock_client.chat.chat.return_value = mock_response
|
||||
|
||||
with patch.object(chat, "client", mock_client):
|
||||
with mock.patch.object(chat, "client", mock_client):
|
||||
message = HumanMessage(content="Hi")
|
||||
response = chat.invoke([message])
|
||||
|
||||
assert isinstance(response, AIMessage)
|
||||
response = chat.stream([message])
|
||||
response_message = ""
|
||||
for chunk in response:
|
||||
response_message += str(chunk.content)
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert response.content == "Hello!"
|
||||
assert response_message == "Hello! How can I help you?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming(
|
||||
self, mock_streaming_chunks: List[Dict[str, Any]]
|
||||
self, mock_streaming_chunks: List[ChatCompletionChunk]
|
||||
) -> None:
|
||||
"""Test async streaming with callback handler."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
chat = ChatWriter(
|
||||
streaming=True,
|
||||
api_key=SecretStr("key"),
|
||||
callback_manager=callback_manager,
|
||||
max_tokens=10,
|
||||
api_key=SecretStr("test-key"),
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_async_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.__aiter__.return_value = mock_streaming_chunks
|
||||
mock_client.chat.chat.return_value = mock_response
|
||||
mock_async_client.chat.chat.return_value = mock_response
|
||||
|
||||
with patch.object(chat, "async_client", mock_client):
|
||||
with mock.patch.object(chat, "async_client", mock_async_client):
|
||||
message = HumanMessage(content="Hi")
|
||||
response = await chat.ainvoke([message])
|
||||
|
||||
assert isinstance(response, AIMessage)
|
||||
response = chat.astream([message])
|
||||
response_message = ""
|
||||
async for chunk in response:
|
||||
response_message += str(chunk.content)
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert response.content == "Hello!"
|
||||
assert response_message == "Hello! How can I help you?"
|
||||
|
||||
def test_sync_tool_calling(self, mock_response: Dict[str, Any]) -> None:
|
||||
def test_sync_tool_calling(
|
||||
self, mock_tool_call_choice_response: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Test synchronous tool calling functionality."""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -261,23 +379,27 @@ class TestChatWriter:
|
||||
|
||||
location: str = Field(..., description="The location to get weather for")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.chat.return_value = mock_response
|
||||
chat = ChatWriter(api_key=SecretStr("key"))
|
||||
|
||||
chat = ChatWriter(api_key=SecretStr("test-key"), client=mock_client)
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.chat.return_value = mock_tool_call_choice_response
|
||||
|
||||
chat_with_tools = chat.bind_tools(
|
||||
tools=[GetWeather],
|
||||
tool_choice="GetWeather",
|
||||
)
|
||||
|
||||
response = chat_with_tools.invoke("What's the weather in London?")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert response.tool_calls
|
||||
assert response.tool_calls[0]["name"] == "GetWeather"
|
||||
assert response.tool_calls[0]["args"]["location"] == "London"
|
||||
with mock.patch.object(chat, "client", mock_client):
|
||||
response = chat_with_tools.invoke("What's the weather in London?")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert response.tool_calls
|
||||
assert response.tool_calls[0]["name"] == "GetWeather"
|
||||
assert response.tool_calls[0]["args"]["location"] == "London"
|
||||
|
||||
async def test_async_tool_calling(self, mock_response: Dict[str, Any]) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_tool_calling(
|
||||
self, mock_tool_call_choice_response: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Test asynchronous tool calling functionality."""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -286,18 +408,101 @@ class TestChatWriter:
|
||||
|
||||
location: str = Field(..., description="The location to get weather for")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.chat.return_value = mock_response
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.chat.chat.return_value = mock_tool_call_choice_response
|
||||
|
||||
chat = ChatWriter(api_key=SecretStr("test-key"), async_client=mock_client)
|
||||
chat = ChatWriter(api_key=SecretStr("key"))
|
||||
|
||||
chat_with_tools = chat.bind_tools(
|
||||
tools=[GetWeather],
|
||||
tool_choice="GetWeather",
|
||||
)
|
||||
|
||||
response = await chat_with_tools.ainvoke("What's the weather in London?")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert response.tool_calls
|
||||
assert response.tool_calls[0]["name"] == "GetWeather"
|
||||
assert response.tool_calls[0]["args"]["location"] == "London"
|
||||
with mock.patch.object(chat, "async_client", mock_async_client):
|
||||
response = await chat_with_tools.ainvoke("What's the weather in London?")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert response.tool_calls
|
||||
assert response.tool_calls[0]["name"] == "GetWeather"
|
||||
assert response.tool_calls[0]["args"]["location"] == "London"
|
||||
|
||||
|
||||
@pytest.mark.requires("writerai")
|
||||
class TestChatWriterStandart(ChatModelUnitTests):
|
||||
"""Test case for ChatWriter that inherits from standard LangChain tests."""
|
||||
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
"""Return ChatWriter model class."""
|
||||
return ChatWriter
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> Dict:
|
||||
"""Return any additional parameters needed."""
|
||||
return {
|
||||
"api_key": "fake-api-key",
|
||||
"model_name": "palmyra-x-004",
|
||||
}
|
||||
|
||||
@property
|
||||
def has_tool_calling(self) -> bool:
|
||||
"""Writer supports tool/function calling."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice in tests."""
|
||||
return "auto"
|
||||
|
||||
@property
|
||||
def has_structured_output(self) -> bool:
|
||||
"""Writer does not yet support structured output."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_image_inputs(self) -> bool:
|
||||
"""Writer does not support image inputs."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_video_inputs(self) -> bool:
|
||||
"""Writer does not support video inputs."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def returns_usage_metadata(self) -> bool:
|
||||
"""Writer returns token usage information."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_anthropic_inputs(self) -> bool:
|
||||
"""Writer does not support anthropic inputs."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_image_tool_message(self) -> bool:
|
||||
"""Writer does not support image tool message."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supported_usage_metadata_details(
|
||||
self,
|
||||
) -> Dict[
|
||||
Literal["invoke", "stream"],
|
||||
List[
|
||||
Literal[
|
||||
"audio_input",
|
||||
"audio_output",
|
||||
"reasoning_output",
|
||||
"cache_read_input",
|
||||
"cache_creation_input",
|
||||
]
|
||||
],
|
||||
]:
|
||||
"""Return which types of usage metadata your model supports."""
|
||||
return {"invoke": ["cache_creation_input"], "stream": ["reasoning_output"]}
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
"""Return env vars, init args, and expected instance attrs for initializing
|
||||
from env vars."""
|
||||
return {"WRITER_API_KEY": "key"}, {"api_key": "key"}, {"api_key": "key"}
|
||||
|
202
libs/community/tests/unit_tests/llms/test_writer.py
Normal file
202
libs/community/tests/unit_tests/llms/test_writer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from typing import List
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_community.llms.writer import Writer
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
"""Classes for mocking Writer responses."""
|
||||
|
||||
|
||||
class Choice:
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
|
||||
class Completion:
|
||||
def __init__(self, choices: List[Choice]):
|
||||
self.choices = choices
|
||||
|
||||
|
||||
class StreamingData:
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
|
||||
@pytest.mark.requires("writerai")
|
||||
class TestWriterLLM:
|
||||
"""Unit tests for Writer LLM integration."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_unstreaming_completion(self) -> Completion:
|
||||
"""Fixture providing a mock API response."""
|
||||
return Completion(choices=[Choice(text="Hello! How can I help you?")])
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_streaming_completion(self) -> List[StreamingData]:
|
||||
"""Fixture providing mock streaming response chunks."""
|
||||
return [
|
||||
StreamingData(value="Hello! "),
|
||||
StreamingData(value="How can I"),
|
||||
StreamingData(value=" help you?"),
|
||||
]
|
||||
|
||||
def test_sync_unstream_completion(
|
||||
self, mock_unstreaming_completion: Completion
|
||||
) -> None:
|
||||
"""Test basic llm call with mocked response."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.completions.create.return_value = mock_unstreaming_completion
|
||||
|
||||
llm = Writer(api_key=SecretStr("key"))
|
||||
|
||||
with mock.patch.object(llm, "client", mock_client):
|
||||
response_text = llm.invoke(input="Hello")
|
||||
|
||||
assert response_text == "Hello! How can I help you?"
|
||||
|
||||
def test_sync_unstream_completion_with_params(
|
||||
self, mock_unstreaming_completion: Completion
|
||||
) -> None:
|
||||
"""Test llm call with passed params with mocked response."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.completions.create.return_value = mock_unstreaming_completion
|
||||
|
||||
llm = Writer(api_key=SecretStr("key"), temperature=1)
|
||||
|
||||
with mock.patch.object(llm, "client", mock_client):
|
||||
response_text = llm.invoke(input="Hello")
|
||||
|
||||
assert response_text == "Hello! How can I help you?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unstream_completion(
|
||||
self, mock_unstreaming_completion: Completion
|
||||
) -> None:
|
||||
"""Test async chat completion with mocked response."""
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.completions.create.return_value = mock_unstreaming_completion
|
||||
|
||||
llm = Writer(api_key=SecretStr("key"))
|
||||
|
||||
with mock.patch.object(llm, "async_client", mock_async_client):
|
||||
response_text = await llm.ainvoke(input="Hello")
|
||||
|
||||
assert response_text == "Hello! How can I help you?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unstream_completion_with_params(
|
||||
self, mock_unstreaming_completion: Completion
|
||||
) -> None:
|
||||
"""Test async llm call with passed params with mocked response."""
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.completions.create.return_value = mock_unstreaming_completion
|
||||
|
||||
llm = Writer(api_key=SecretStr("key"), temperature=1)
|
||||
|
||||
with mock.patch.object(llm, "async_client", mock_async_client):
|
||||
response_text = await llm.ainvoke(input="Hello")
|
||||
|
||||
assert response_text == "Hello! How can I help you?"
|
||||
|
||||
def test_sync_streaming_completion(
|
||||
self, mock_streaming_completion: List[StreamingData]
|
||||
) -> None:
|
||||
"""Test sync streaming."""
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.__iter__.return_value = mock_streaming_completion
|
||||
mock_client.completions.create.return_value = mock_response
|
||||
|
||||
llm = Writer(api_key=SecretStr("key"))
|
||||
|
||||
with mock.patch.object(llm, "client", mock_client):
|
||||
response = llm.stream(input="Hello")
|
||||
|
||||
response_message = ""
|
||||
for chunk in response:
|
||||
response_message += chunk
|
||||
|
||||
assert response_message == "Hello! How can I help you?"
|
||||
|
||||
def test_sync_streaming_completion_with_callback_handler(
|
||||
self, mock_streaming_completion: List[StreamingData]
|
||||
) -> None:
|
||||
"""Test sync streaming with callback handler."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.__iter__.return_value = mock_streaming_completion
|
||||
mock_client.completions.create.return_value = mock_response
|
||||
|
||||
llm = Writer(
|
||||
api_key=SecretStr("key"),
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
|
||||
with mock.patch.object(llm, "client", mock_client):
|
||||
response = llm.stream(input="Hello")
|
||||
|
||||
response_message = ""
|
||||
for chunk in response:
|
||||
response_message += chunk
|
||||
|
||||
assert callback_handler.llm_streams == 3
|
||||
assert response_message == "Hello! How can I help you?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_completion(
|
||||
self, mock_streaming_completion: Completion
|
||||
) -> None:
|
||||
"""Test async streaming with callback handler."""
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.__aiter__.return_value = mock_streaming_completion
|
||||
mock_async_client.completions.create.return_value = mock_response
|
||||
|
||||
llm = Writer(api_key=SecretStr("key"))
|
||||
|
||||
with mock.patch.object(llm, "async_client", mock_async_client):
|
||||
response = llm.astream(input="Hello")
|
||||
|
||||
response_message = ""
|
||||
async for chunk in response:
|
||||
response_message += str(chunk)
|
||||
|
||||
assert response_message == "Hello! How can I help you?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_completion_with_callback_handler(
|
||||
self, mock_streaming_completion: Completion
|
||||
) -> None:
|
||||
"""Test async streaming with callback handler."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.__aiter__.return_value = mock_streaming_completion
|
||||
mock_async_client.completions.create.return_value = mock_response
|
||||
|
||||
llm = Writer(
|
||||
api_key=SecretStr("key"),
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
|
||||
with mock.patch.object(llm, "async_client", mock_async_client):
|
||||
response = llm.astream(input="Hello")
|
||||
|
||||
response_message = ""
|
||||
async for chunk in response:
|
||||
response_message += str(chunk)
|
||||
|
||||
assert callback_handler.llm_streams == 3
|
||||
assert response_message == "Hello! How can I help you?"
|
Reference in New Issue
Block a user