Files
langchain/libs/partners/together/tests/unit_tests/test_chat_models.py
ccurme 181dfef118 core, standard tests, partner packages: add test for model params (#21677)
1. Adds `._get_ls_params` to `BaseChatModel`, which returns:
```python
class LangSmithParams(TypedDict, total=False):
    ls_provider: str
    ls_model_name: str
    ls_model_type: Literal["chat"]
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]
```
By default, it only returns:
```python
{"ls_model_type": "chat", "ls_stop": stop}
```

2. Add these params to inheritable metadata in
`CallbackManager.configure`

3. Implement `._get_ls_params` and populate all params for Anthropic and
all subclasses of `BaseChatOpenAI`.

Sample trace:
https://smith.langchain.com/public/d2962673-4c83-47c7-b51e-61d07aaffb1b/r

**OpenAI**:
<img width="984" alt="Screenshot 2024-05-17 at 10 03 35 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/2ef41f74-a9df-4e0e-905d-da74fa82a910">

**Anthropic**:
<img width="978" alt="Screenshot 2024-05-17 at 10 06 07 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/39701c9f-7da5-4f1a-ab14-84e9169d63e7">

**Mistral** (and all others for which params are not yet populated):
<img width="977" alt="Screenshot 2024-05-17 at 10 08 43 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/37d7d894-fec2-4300-986f-49a5f0191b03">
2024-05-17 13:51:26 -04:00

195 lines
5.7 KiB
Python

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_openai.chat_models.base import (
_convert_dict_to_message,
_convert_message_to_dict,
)
from langchain_together import ChatTogether
def test_initialization() -> None:
    """Check that ``ChatTogether`` can be constructed with no arguments."""
    # Construction alone is the check: any exception fails the test.
    _ = ChatTogether()
def test_together_model_param() -> None:
    """Check the ``model`` / ``model_name`` keywords and the tracing provider.

    Both keyword spellings must end up in ``model_name``, and the params
    returned by ``_get_ls_params`` must report ``"together"`` as the provider.
    """
    from_model = ChatTogether(model="foo")
    assert from_model.model_name == "foo"

    from_model_name = ChatTogether(model_name="foo")
    assert from_model_name.model_name == "foo"

    # The provider is read from the instance built with ``model_name=``.
    tracing_params = from_model_name._get_ls_params()
    assert tracing_params["ls_provider"] == "together"
def test_function_dict_to_message_function_message() -> None:
    """Check that a ``"function"``-role dict converts to a ``FunctionMessage``.

    The converted message must keep the ``name`` and ``content`` of the dict.
    """
    payload = json.dumps({"result": "Example #1"})
    function_name = "test_function"
    raw_message = {
        "role": "function",
        "name": function_name,
        "content": payload,
    }

    converted = _convert_dict_to_message(raw_message)

    assert isinstance(converted, FunctionMessage)
    assert converted.name == function_name
    assert converted.content == payload
def test_convert_dict_to_message_human() -> None:
    """Check the dict <-> ``HumanMessage`` conversion in both directions."""
    raw_message = {"role": "user", "content": "foo"}
    expected = HumanMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
def test__convert_dict_to_message_human_with_name() -> None:
    """Check the dict <-> ``HumanMessage`` conversion when ``name`` is set."""
    raw_message = {"role": "user", "content": "foo", "name": "test"}
    expected = HumanMessage(content="foo", name="test")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
def test_convert_dict_to_message_ai() -> None:
    """Check the dict <-> ``AIMessage`` conversion in both directions."""
    raw_message = {"role": "assistant", "content": "foo"}
    expected = AIMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
def test_convert_dict_to_message_ai_with_name() -> None:
    """Check the dict <-> ``AIMessage`` conversion when ``name`` is set."""
    raw_message = {"role": "assistant", "content": "foo", "name": "test"}
    expected = AIMessage(content="foo", name="test")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
def test_convert_dict_to_message_system() -> None:
    """Check the dict <-> ``SystemMessage`` conversion in both directions."""
    raw_message = {"role": "system", "content": "foo"}
    expected = SystemMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
def test_convert_dict_to_message_system_with_name() -> None:
    """Check the dict <-> ``SystemMessage`` conversion when ``name`` is set."""
    raw_message = {"role": "system", "content": "foo", "name": "test"}
    expected = SystemMessage(content="foo", name="test")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
def test_convert_dict_to_message_tool() -> None:
    """Check the dict <-> ``ToolMessage`` conversion in both directions."""
    raw_message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
    expected = ToolMessage(content="foo", tool_call_id="bar")

    # dict -> message
    assert _convert_dict_to_message(raw_message) == expected
    # message -> dict
    assert _convert_message_to_dict(expected) == raw_message
@pytest.fixture
def mock_completion() -> dict:
    """Return a canned chat-completion payload for the invoke tests.

    The payload has a single choice whose assistant message has content
    ``"Bab"`` and name ``"KimSolar"``.
    """
    assistant_message = {
        "role": "assistant",
        "content": "Bab",
        "name": "KimSolar",
    }
    single_choice = {
        "index": 0,
        "message": assistant_message,
        "finish_reason": "stop",
    }
    return {
        "id": "chatcmpl-7fcZavknQda3SQ",
        "object": "chat.completion",
        "created": 1689989000,
        "model": "meta-llama/Llama-3-8b-chat-hf",
        "choices": [single_choice],
    }
def test_together_invoke(mock_completion: dict) -> None:
    """Check that ``invoke`` calls the patched client and returns its content.

    Args:
        mock_completion: Canned completion payload returned by the fake
            ``create`` function.
    """
    llm = ChatTogether()
    stub_client = MagicMock()
    create_was_called = False

    def fake_create(*args: Any, **kwargs: Any) -> Any:
        # Record the call so the test can prove the stub was reached.
        nonlocal create_was_called
        create_was_called = True
        return mock_completion

    stub_client.create = fake_create

    with patch.object(llm, "client", stub_client):
        response = llm.invoke("bab")
        assert response.content == "Bab"
    assert create_was_called
async def test_together_ainvoke(mock_completion: dict) -> None:
    """Check that ``ainvoke`` calls the patched async client.

    Args:
        mock_completion: Canned completion payload returned by the fake
            async ``create`` coroutine.
    """
    llm = ChatTogether()
    stub_client = AsyncMock()
    create_was_called = False

    async def fake_create(*args: Any, **kwargs: Any) -> Any:
        # Record the call so the test can prove the stub was reached.
        nonlocal create_was_called
        create_was_called = True
        return mock_completion

    stub_client.create = fake_create

    with patch.object(llm, "async_client", stub_client):
        response = await llm.ainvoke("bab")
        assert response.content == "Bab"
    assert create_was_called
def test_together_invoke_name(mock_completion: dict) -> None:
    """Check that message names pass through ``invoke`` in both directions.

    The outgoing request must carry the ``name`` of the human message, and
    the returned message must carry the ``name`` from the completion payload.

    Args:
        mock_completion: Canned completion payload returned by the mocked
            ``create`` method.
    """
    llm = ChatTogether()
    stub_client = MagicMock()
    stub_client.create.return_value = mock_completion

    with patch.object(llm, "client", stub_client):
        prompt = [HumanMessage(content="Foo", name="Zorba")]
        response = llm.invoke(prompt)

        positional, keyword = stub_client.create.call_args
        assert len(positional) == 0  # no positional args

        sent_messages = keyword["messages"]
        assert len(sent_messages) == 1
        sent = sent_messages[0]
        assert sent["role"] == "user"
        assert sent["content"] == "Foo"
        assert sent["name"] == "Zorba"

        # check return type has name
        assert response.content == "Bab"
        assert response.name == "KimSolar"