community: fixed critical bugs at Writer provider (#27879)

This commit is contained in:
Yan
2024-11-25 20:03:37 +03:00
committed by GitHub
parent 6ed2d387bb
commit c60695a1c7
8 changed files with 1205 additions and 542 deletions

View File

@@ -1,10 +0,0 @@
"""Test Writer API wrapper."""
from langchain_community.llms.writer import Writer
def test_writer_call() -> None:
"""Test valid call to Writer."""
llm = Writer()
output = llm.invoke("Say foo:")
assert isinstance(output, str)

View File

@@ -1,61 +1,251 @@
"""Unit tests for Writer chat model integration."""
import json
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest import mock
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from pydantic import SecretStr
from langchain_community.chat_models.writer import ChatWriter, _convert_dict_to_message
from langchain_community.chat_models.writer import ChatWriter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
"""Classes for mocking Writer responses."""
class ChoiceDelta:
def __init__(self, content: str):
self.content = content
class ChunkChoice:
def __init__(self, index: int, finish_reason: str, delta: ChoiceDelta):
self.index = index
self.finish_reason = finish_reason
self.delta = delta
class ChatCompletionChunk:
def __init__(
self,
id: str,
object: str,
created: int,
model: str,
choices: List[ChunkChoice],
):
self.id = id
self.object = object
self.created = created
self.model = model
self.choices = choices
class ToolCallFunction:
def __init__(self, name: str, arguments: str):
self.name = name
self.arguments = arguments
class ChoiceMessageToolCall:
def __init__(self, id: str, type: str, function: ToolCallFunction):
self.id = id
self.type = type
self.function = function
class Usage:
def __init__(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
):
self.prompt_tokens = prompt_tokens
self.completion_tokens = completion_tokens
self.total_tokens = total_tokens
class ChoiceMessage:
def __init__(
self,
role: str,
content: str,
tool_calls: Optional[List[ChoiceMessageToolCall]] = None,
):
self.role = role
self.content = content
self.tool_calls = tool_calls
class Choice:
def __init__(self, index: int, finish_reason: str, message: ChoiceMessage):
self.index = index
self.finish_reason = finish_reason
self.message = message
class Chat:
def __init__(
self,
id: str,
object: str,
created: int,
system_fingerprint: str,
model: str,
usage: Usage,
choices: List[Choice],
):
self.id = id
self.object = object
self.created = created
self.system_fingerprint = system_fingerprint
self.model = model
self.usage = usage
self.choices = choices
@pytest.mark.requires("writerai")
class TestChatWriterCustom:
"""Test case for ChatWriter"""
@pytest.fixture(autouse=True)
def mock_unstreaming_completion(self) -> Chat:
"""Fixture providing a mock API response."""
return Chat(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
system_fingerprint="v1",
usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18),
choices=[
Choice(
index=0,
finish_reason="stop",
message=ChoiceMessage(
role="assistant",
content="Hello! How can I help you?",
),
)
],
)
@pytest.fixture(autouse=True)
def mock_tool_call_choice_response(self) -> Chat:
return Chat(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
system_fingerprint="v1",
usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61),
choices=[
Choice(
index=0,
finish_reason="tool_calls",
message=ChoiceMessage(
role="assistant",
content="",
tool_calls=[
ChoiceMessageToolCall(
id="call_abc123",
type="function",
function=ToolCallFunction(
name="GetWeather",
arguments='{"location": "London"}',
),
)
],
),
)
],
)
@pytest.fixture(autouse=True)
def mock_streaming_chunks(self) -> List[ChatCompletionChunk]:
"""Fixture providing mock streaming response chunks."""
return [
ChatCompletionChunk(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
choices=[
ChunkChoice(
index=0,
finish_reason="stop",
delta=ChoiceDelta(content="Hello! "),
)
],
),
ChatCompletionChunk(
id="chat-12345",
object="chat.completion",
created=1699000000,
model="palmyra-x-004",
choices=[
ChunkChoice(
index=0,
finish_reason="stop",
delta=ChoiceDelta(content="How can I help you?"),
)
],
),
]
class TestChatWriter:
def test_writer_model_param(self) -> None:
"""Test different ways to initialize the chat model."""
test_cases: List[dict] = [
{"model_name": "palmyra-x-004", "writer_api_key": "test-key"},
{"model": "palmyra-x-004", "writer_api_key": "test-key"},
{"model_name": "palmyra-x-004", "writer_api_key": "test-key"},
{
"model_name": "palmyra-x-004",
"api_key": "key",
},
{
"model": "palmyra-x-004",
"api_key": "key",
},
{
"model_name": "palmyra-x-004",
"api_key": "key",
},
{
"model": "palmyra-x-004",
"writer_api_key": "test-key",
"temperature": 0.5,
"api_key": "key",
},
]
for case in test_cases:
chat = ChatWriter(**case)
assert chat.model_name == "palmyra-x-004"
assert chat.writer_api_key
assert chat.writer_api_key.get_secret_value() == "test-key"
assert chat.temperature == (0.5 if "temperature" in case else 0.7)
def test_convert_dict_to_message_human(self) -> None:
def test_convert_writer_to_langchain_human(self) -> None:
"""Test converting a human message dict to a LangChain message."""
message = {"role": "user", "content": "Hello"}
result = _convert_dict_to_message(message)
result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, HumanMessage)
assert result.content == "Hello"
def test_convert_dict_to_message_ai(self) -> None:
def test_convert_writer_to_langchain_ai(self) -> None:
"""Test converting an AI message dict to a LangChain message."""
message = {"role": "assistant", "content": "Hello"}
result = _convert_dict_to_message(message)
result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, AIMessage)
assert result.content == "Hello"
def test_convert_dict_to_message_system(self) -> None:
def test_convert_writer_to_langchain_system(self) -> None:
"""Test converting a system message dict to a LangChain message."""
message = {"role": "system", "content": "You are a helpful assistant"}
result = _convert_dict_to_message(message)
result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, SystemMessage)
assert result.content == "You are a helpful assistant"
def test_convert_dict_to_message_tool_call(self) -> None:
def test_convert_writer_to_langchain_tool_call(self) -> None:
"""Test converting a tool call message dict to a LangChain message."""
content = json.dumps({"result": 42})
message = {
@@ -64,12 +254,12 @@ class TestChatWriter:
"content": content,
"tool_call_id": "call_abc123",
}
result = _convert_dict_to_message(message)
result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, ToolMessage)
assert result.name == "get_number"
assert result.content == content
def test_convert_dict_to_message_with_tool_calls(self) -> None:
def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
"""Test converting an AIMessage with tool calls."""
message = {
"role": "assistant",
@@ -85,131 +275,55 @@ class TestChatWriter:
}
],
}
result = _convert_dict_to_message(message)
result = ChatWriter._convert_writer_to_langchain(message)
assert isinstance(result, AIMessage)
assert result.tool_calls
assert len(result.tool_calls) == 1
assert result.tool_calls[0]["name"] == "get_weather"
assert result.tool_calls[0]["args"]["location"] == "London"
@pytest.fixture(autouse=True)
def mock_completion(self) -> Dict[str, Any]:
"""Fixture providing a mock API response."""
return {
"id": "chat-12345",
"object": "chat.completion",
"created": 1699000000,
"model": "palmyra-x-004",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you?",
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
}
@pytest.fixture(autouse=True)
def mock_response(self) -> Dict[str, Any]:
response = {
"id": "chat-12345",
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "GetWeather",
"arguments": '{"location": "London"}',
},
}
],
},
"finish_reason": "tool_calls",
}
],
}
return response
@pytest.fixture(autouse=True)
def mock_streaming_chunks(self) -> List[Dict[str, Any]]:
"""Fixture providing mock streaming response chunks."""
return [
{
"id": "chat-12345",
"object": "chat.completion.chunk",
"created": 1699000000,
"model": "palmyra-x-004",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "Hello",
},
"finish_reason": None,
}
],
},
{
"id": "chat-12345",
"object": "chat.completion.chunk",
"created": 1699000000,
"model": "palmyra-x-004",
"choices": [
{
"index": 0,
"delta": {
"content": "!",
},
"finish_reason": "stop",
}
],
},
]
def test_sync_completion(self, mock_completion: Dict[str, Any]) -> None:
def test_sync_completion(
self, mock_unstreaming_completion: List[ChatCompletionChunk]
) -> None:
"""Test basic chat completion with mocked response."""
chat = ChatWriter(api_key=SecretStr("test-key"))
mock_client = MagicMock()
mock_client.chat.chat.return_value = mock_completion
chat = ChatWriter(api_key=SecretStr("key"))
with patch.object(chat, "client", mock_client):
mock_client = MagicMock()
mock_client.chat.chat.return_value = mock_unstreaming_completion
with mock.patch.object(chat, "client", mock_client):
message = HumanMessage(content="Hi there!")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert response.content == "Hello! How can I help you?"
async def test_async_completion(self, mock_completion: Dict[str, Any]) -> None:
@pytest.mark.asyncio
async def test_async_completion(
self, mock_unstreaming_completion: List[ChatCompletionChunk]
) -> None:
"""Test async chat completion with mocked response."""
chat = ChatWriter(api_key=SecretStr("test-key"))
mock_client = AsyncMock()
mock_client.chat.chat.return_value = mock_completion
chat = ChatWriter(api_key=SecretStr("key"))
with patch.object(chat, "async_client", mock_client):
mock_async_client = AsyncMock()
mock_async_client.chat.chat.return_value = mock_unstreaming_completion
with mock.patch.object(chat, "async_client", mock_async_client):
message = HumanMessage(content="Hi there!")
response = await chat.ainvoke([message])
assert isinstance(response, AIMessage)
assert response.content == "Hello! How can I help you?"
def test_sync_streaming(self, mock_streaming_chunks: List[Dict[str, Any]]) -> None:
def test_sync_streaming(
self, mock_streaming_chunks: List[ChatCompletionChunk]
) -> None:
"""Test sync streaming with callback handler."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatWriter(
streaming=True,
api_key=SecretStr("key"),
callback_manager=callback_manager,
max_tokens=10,
api_key=SecretStr("test-key"),
)
mock_client = MagicMock()
@@ -217,42 +331,46 @@ class TestChatWriter:
mock_response.__iter__.return_value = mock_streaming_chunks
mock_client.chat.chat.return_value = mock_response
with patch.object(chat, "client", mock_client):
with mock.patch.object(chat, "client", mock_client):
message = HumanMessage(content="Hi")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
response = chat.stream([message])
response_message = ""
for chunk in response:
response_message += str(chunk.content)
assert callback_handler.llm_streams > 0
assert response.content == "Hello!"
assert response_message == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_streaming(
self, mock_streaming_chunks: List[Dict[str, Any]]
self, mock_streaming_chunks: List[ChatCompletionChunk]
) -> None:
"""Test async streaming with callback handler."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatWriter(
streaming=True,
api_key=SecretStr("key"),
callback_manager=callback_manager,
max_tokens=10,
api_key=SecretStr("test-key"),
)
mock_client = AsyncMock()
mock_async_client = AsyncMock()
mock_response = AsyncMock()
mock_response.__aiter__.return_value = mock_streaming_chunks
mock_client.chat.chat.return_value = mock_response
mock_async_client.chat.chat.return_value = mock_response
with patch.object(chat, "async_client", mock_client):
with mock.patch.object(chat, "async_client", mock_async_client):
message = HumanMessage(content="Hi")
response = await chat.ainvoke([message])
assert isinstance(response, AIMessage)
response = chat.astream([message])
response_message = ""
async for chunk in response:
response_message += str(chunk.content)
assert callback_handler.llm_streams > 0
assert response.content == "Hello!"
assert response_message == "Hello! How can I help you?"
def test_sync_tool_calling(self, mock_response: Dict[str, Any]) -> None:
def test_sync_tool_calling(
self, mock_tool_call_choice_response: Dict[str, Any]
) -> None:
"""Test synchronous tool calling functionality."""
from pydantic import BaseModel, Field
@@ -261,23 +379,27 @@ class TestChatWriter:
location: str = Field(..., description="The location to get weather for")
mock_client = MagicMock()
mock_client.chat.chat.return_value = mock_response
chat = ChatWriter(api_key=SecretStr("key"))
chat = ChatWriter(api_key=SecretStr("test-key"), client=mock_client)
mock_client = MagicMock()
mock_client.chat.chat.return_value = mock_tool_call_choice_response
chat_with_tools = chat.bind_tools(
tools=[GetWeather],
tool_choice="GetWeather",
)
response = chat_with_tools.invoke("What's the weather in London?")
assert isinstance(response, AIMessage)
assert response.tool_calls
assert response.tool_calls[0]["name"] == "GetWeather"
assert response.tool_calls[0]["args"]["location"] == "London"
with mock.patch.object(chat, "client", mock_client):
response = chat_with_tools.invoke("What's the weather in London?")
assert isinstance(response, AIMessage)
assert response.tool_calls
assert response.tool_calls[0]["name"] == "GetWeather"
assert response.tool_calls[0]["args"]["location"] == "London"
async def test_async_tool_calling(self, mock_response: Dict[str, Any]) -> None:
@pytest.mark.asyncio
async def test_async_tool_calling(
self, mock_tool_call_choice_response: Dict[str, Any]
) -> None:
"""Test asynchronous tool calling functionality."""
from pydantic import BaseModel, Field
@@ -286,18 +408,101 @@ class TestChatWriter:
location: str = Field(..., description="The location to get weather for")
mock_client = AsyncMock()
mock_client.chat.chat.return_value = mock_response
mock_async_client = AsyncMock()
mock_async_client.chat.chat.return_value = mock_tool_call_choice_response
chat = ChatWriter(api_key=SecretStr("test-key"), async_client=mock_client)
chat = ChatWriter(api_key=SecretStr("key"))
chat_with_tools = chat.bind_tools(
tools=[GetWeather],
tool_choice="GetWeather",
)
response = await chat_with_tools.ainvoke("What's the weather in London?")
assert isinstance(response, AIMessage)
assert response.tool_calls
assert response.tool_calls[0]["name"] == "GetWeather"
assert response.tool_calls[0]["args"]["location"] == "London"
with mock.patch.object(chat, "async_client", mock_async_client):
response = await chat_with_tools.ainvoke("What's the weather in London?")
assert isinstance(response, AIMessage)
assert response.tool_calls
assert response.tool_calls[0]["name"] == "GetWeather"
assert response.tool_calls[0]["args"]["location"] == "London"
@pytest.mark.requires("writerai")
class TestChatWriterStandart(ChatModelUnitTests):
"""Test case for ChatWriter that inherits from standard LangChain tests."""
@property
def chat_model_class(self) -> Type[BaseChatModel]:
"""Return ChatWriter model class."""
return ChatWriter
@property
def chat_model_params(self) -> Dict:
"""Return any additional parameters needed."""
return {
"api_key": "fake-api-key",
"model_name": "palmyra-x-004",
}
@property
def has_tool_calling(self) -> bool:
"""Writer supports tool/function calling."""
return True
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice in tests."""
return "auto"
@property
def has_structured_output(self) -> bool:
"""Writer does not yet support structured output."""
return False
@property
def supports_image_inputs(self) -> bool:
"""Writer does not support image inputs."""
return False
@property
def supports_video_inputs(self) -> bool:
"""Writer does not support video inputs."""
return False
@property
def returns_usage_metadata(self) -> bool:
"""Writer returns token usage information."""
return True
@property
def supports_anthropic_inputs(self) -> bool:
"""Writer does not support anthropic inputs."""
return False
@property
def supports_image_tool_message(self) -> bool:
"""Writer does not support image tool message."""
return False
@property
def supported_usage_metadata_details(
self,
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
"""Return which types of usage metadata your model supports."""
return {"invoke": ["cache_creation_input"], "stream": ["reasoning_output"]}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {"WRITER_API_KEY": "key"}, {"api_key": "key"}, {"api_key": "key"}

View File

@@ -0,0 +1,202 @@
from typing import List
from unittest import mock
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.callbacks import CallbackManager
from pydantic import SecretStr
from langchain_community.llms.writer import Writer
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
"""Classes for mocking Writer responses."""
class Choice:
def __init__(self, text: str):
self.text = text
class Completion:
def __init__(self, choices: List[Choice]):
self.choices = choices
class StreamingData:
def __init__(self, value: str):
self.value = value
@pytest.mark.requires("writerai")
class TestWriterLLM:
"""Unit tests for Writer LLM integration."""
@pytest.fixture(autouse=True)
def mock_unstreaming_completion(self) -> Completion:
"""Fixture providing a mock API response."""
return Completion(choices=[Choice(text="Hello! How can I help you?")])
@pytest.fixture(autouse=True)
def mock_streaming_completion(self) -> List[StreamingData]:
"""Fixture providing mock streaming response chunks."""
return [
StreamingData(value="Hello! "),
StreamingData(value="How can I"),
StreamingData(value=" help you?"),
]
def test_sync_unstream_completion(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test basic llm call with mocked response."""
mock_client = MagicMock()
mock_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "client", mock_client):
response_text = llm.invoke(input="Hello")
assert response_text == "Hello! How can I help you?"
def test_sync_unstream_completion_with_params(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test llm call with passed params with mocked response."""
mock_client = MagicMock()
mock_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"), temperature=1)
with mock.patch.object(llm, "client", mock_client):
response_text = llm.invoke(input="Hello")
assert response_text == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_unstream_completion(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test async chat completion with mocked response."""
mock_async_client = AsyncMock()
mock_async_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "async_client", mock_async_client):
response_text = await llm.ainvoke(input="Hello")
assert response_text == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_unstream_completion_with_params(
self, mock_unstreaming_completion: Completion
) -> None:
"""Test async llm call with passed params with mocked response."""
mock_async_client = AsyncMock()
mock_async_client.completions.create.return_value = mock_unstreaming_completion
llm = Writer(api_key=SecretStr("key"), temperature=1)
with mock.patch.object(llm, "async_client", mock_async_client):
response_text = await llm.ainvoke(input="Hello")
assert response_text == "Hello! How can I help you?"
def test_sync_streaming_completion(
self, mock_streaming_completion: List[StreamingData]
) -> None:
"""Test sync streaming."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.__iter__.return_value = mock_streaming_completion
mock_client.completions.create.return_value = mock_response
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "client", mock_client):
response = llm.stream(input="Hello")
response_message = ""
for chunk in response:
response_message += chunk
assert response_message == "Hello! How can I help you?"
def test_sync_streaming_completion_with_callback_handler(
self, mock_streaming_completion: List[StreamingData]
) -> None:
"""Test sync streaming with callback handler."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.__iter__.return_value = mock_streaming_completion
mock_client.completions.create.return_value = mock_response
llm = Writer(
api_key=SecretStr("key"),
callback_manager=callback_manager,
)
with mock.patch.object(llm, "client", mock_client):
response = llm.stream(input="Hello")
response_message = ""
for chunk in response:
response_message += chunk
assert callback_handler.llm_streams == 3
assert response_message == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_streaming_completion(
self, mock_streaming_completion: Completion
) -> None:
"""Test async streaming with callback handler."""
mock_async_client = AsyncMock()
mock_response = AsyncMock()
mock_response.__aiter__.return_value = mock_streaming_completion
mock_async_client.completions.create.return_value = mock_response
llm = Writer(api_key=SecretStr("key"))
with mock.patch.object(llm, "async_client", mock_async_client):
response = llm.astream(input="Hello")
response_message = ""
async for chunk in response:
response_message += str(chunk)
assert response_message == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_async_streaming_completion_with_callback_handler(
self, mock_streaming_completion: Completion
) -> None:
"""Test async streaming with callback handler."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
mock_async_client = AsyncMock()
mock_response = AsyncMock()
mock_response.__aiter__.return_value = mock_streaming_completion
mock_async_client.completions.create.return_value = mock_response
llm = Writer(
api_key=SecretStr("key"),
callback_manager=callback_manager,
)
with mock.patch.object(llm, "async_client", mock_async_client):
response = llm.astream(input="Hello")
response_message = ""
async for chunk in response:
response_message += str(chunk)
assert callback_handler.llm_streams == 3
assert response_message == "Hello! How can I help you?"