mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-20 01:49:51 +00:00
**Description:** Adding chat completions to the Together AI package, which is our most popular API. Also staying backwards compatible with the old API so folks can continue to use the completions API as well. Also moved the embedding API to use the OpenAI library to standardize it further. **Twitter handle:** @nutlope - [x] **Add tests and docs**: If you're adding a new integration, please include - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
193 lines
5.6 KiB
Python
193 lines
5.6 KiB
Python
import json
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
FunctionMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
from langchain_openai.chat_models.base import (
|
|
_convert_dict_to_message,
|
|
_convert_message_to_dict,
|
|
)
|
|
|
|
from langchain_together import ChatTogether
|
|
|
|
|
|
def test_initialization() -> None:
    """Constructing ``ChatTogether`` with no arguments should not raise."""
    # The instance is discarded on purpose; only successful construction
    # is being checked here.
    _ = ChatTogether()
|
|
|
|
|
|
def test_together_model_param() -> None:
    """``model`` and ``model_name`` keywords both populate ``model_name``."""
    # Same order as before: first the ``model`` spelling, then ``model_name``.
    for init_kwargs in ({"model": "foo"}, {"model_name": "foo"}):
        llm = ChatTogether(**init_kwargs)
        assert llm.model_name == "foo"
|
|
|
|
|
|
def test_function_dict_to_message_function_message() -> None:
    """A ``function``-role dict converts to a ``FunctionMessage``.

    The resulting message must keep both the function name and the
    JSON-encoded content that were supplied in the dict.
    """
    fn_name = "test_function"
    payload = json.dumps({"result": "Example #1"})

    converted = _convert_dict_to_message(
        {"role": "function", "name": fn_name, "content": payload}
    )

    assert isinstance(converted, FunctionMessage)
    assert converted.name == fn_name
    assert converted.content == payload
|
|
|
|
|
|
def test_convert_dict_to_message_human() -> None:
    """Round-trip a ``user`` dict to ``HumanMessage`` and back."""
    raw = {"role": "user", "content": "foo"}
    human = HumanMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(raw) == human
    # message -> dict
    assert _convert_message_to_dict(human) == raw
|
|
|
|
|
|
def test__convert_dict_to_message_human_with_name() -> None:
    """Round-trip a named ``user`` dict to ``HumanMessage`` and back."""
    raw = {"role": "user", "content": "foo", "name": "test"}
    human = HumanMessage(content="foo", name="test")

    # dict -> message
    assert _convert_dict_to_message(raw) == human
    # message -> dict
    assert _convert_message_to_dict(human) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_ai() -> None:
    """Round-trip an ``assistant`` dict to ``AIMessage`` and back."""
    raw = {"role": "assistant", "content": "foo"}
    ai = AIMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(raw) == ai
    # message -> dict
    assert _convert_message_to_dict(ai) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_ai_with_name() -> None:
    """Round-trip a named ``assistant`` dict to ``AIMessage`` and back."""
    raw = {"role": "assistant", "content": "foo", "name": "test"}
    ai = AIMessage(content="foo", name="test")

    # dict -> message
    assert _convert_dict_to_message(raw) == ai
    # message -> dict
    assert _convert_message_to_dict(ai) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_system() -> None:
    """Round-trip a ``system`` dict to ``SystemMessage`` and back."""
    raw = {"role": "system", "content": "foo"}
    system = SystemMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(raw) == system
    # message -> dict
    assert _convert_message_to_dict(system) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_system_with_name() -> None:
    """Round-trip a named ``system`` dict to ``SystemMessage`` and back."""
    raw = {"role": "system", "content": "foo", "name": "test"}
    system = SystemMessage(content="foo", name="test")

    # dict -> message
    assert _convert_dict_to_message(raw) == system
    # message -> dict
    assert _convert_message_to_dict(system) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_tool() -> None:
    """Round-trip a ``tool`` dict (with ``tool_call_id``) to ``ToolMessage``."""
    raw = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
    tool_msg = ToolMessage(content="foo", tool_call_id="bar")

    # dict -> message
    assert _convert_dict_to_message(raw) == tool_msg
    # message -> dict
    assert _convert_message_to_dict(tool_msg) == raw
|
|
|
|
|
|
@pytest.fixture
def mock_completion() -> dict:
    """Provide a canned ``chat.completion`` payload as a plain dict.

    The payload holds exactly one choice, finished with ``"stop"``, whose
    assistant message has content ``"Bab"`` and name ``"KimSolar"``.
    """
    assistant_message = {
        "role": "assistant",
        "content": "Bab",
        "name": "KimSolar",
    }
    sole_choice = {
        "index": 0,
        "message": assistant_message,
        "finish_reason": "stop",
    }
    payload = {
        "id": "chatcmpl-7fcZavknQda3SQ",
        "object": "chat.completion",
        "created": 1689989000,
        "model": "meta-llama/Llama-3-8b-chat-hf",
        "choices": [sole_choice],
    }
    return payload
|
|
|
|
|
|
def test_together_invoke(mock_completion: dict) -> None:
    """``invoke`` should call the patched client's ``create`` and use its reply.

    The model's ``client`` attribute is swapped for a mock whose ``create``
    returns the canned completion; the test checks both that the reply
    content comes from that payload and that ``create`` was actually called.
    """
    llm = ChatTogether()
    recorded_calls: list = []

    def fake_create(*args: Any, **kwargs: Any) -> Any:
        # Remember the invocation so the test can prove the stub was hit.
        recorded_calls.append((args, kwargs))
        return mock_completion

    stub_client = MagicMock()
    stub_client.create = fake_create

    with patch.object(llm, "client", stub_client):
        res = llm.invoke("bab")
        assert res.content == "Bab"

    # Non-empty only if ``fake_create`` ran at least once.
    assert recorded_calls
|
|
|
|
|
|
async def test_together_ainvoke(mock_completion: dict) -> None:
    """``ainvoke`` should await the patched async client's ``create``.

    The model's ``async_client`` attribute is swapped for a mock whose
    ``create`` coroutine returns the canned completion; the test checks both
    that the reply content comes from that payload and that ``create`` ran.
    """
    llm = ChatTogether()
    recorded_calls: list = []

    async def fake_create(*args: Any, **kwargs: Any) -> Any:
        # Remember the invocation so the test can prove the stub was awaited.
        recorded_calls.append((args, kwargs))
        return mock_completion

    stub_client = AsyncMock()
    stub_client.create = fake_create

    with patch.object(llm, "async_client", stub_client):
        res = await llm.ainvoke("bab")
        assert res.content == "Bab"

    # Non-empty only if the body of ``fake_create`` ran at least once.
    assert recorded_calls
|
|
|
|
|
|
def test_together_invoke_name(mock_completion: dict) -> None:
    """Message ``name`` should survive both the request and the response.

    A named ``HumanMessage`` is sent through a mocked ``client``; the test
    inspects the keyword arguments passed to ``create`` and then checks that
    the returned message carries the name from the canned completion.
    """
    llm = ChatTogether()

    stub_client = MagicMock()
    stub_client.create.return_value = mock_completion

    outgoing = [HumanMessage(content="Foo", name="Zorba")]

    with patch.object(llm, "client", stub_client):
        res = llm.invoke(outgoing)

        last_call = stub_client.create.call_args
        assert not last_call.args  # no positional args

        sent_messages = last_call.kwargs["messages"]
        assert len(sent_messages) == 1
        sent = sent_messages[0]
        assert sent["role"] == "user"
        assert sent["content"] == "Foo"
        assert sent["name"] == "Zorba"

        # check return type has name
        assert res.content == "Bab"
        assert res.name == "KimSolar"
|