mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
132 lines
4.3 KiB
Python
132 lines
4.3 KiB
Python
import json
|
|
|
|
import pytest # type: ignore[import-not-found]
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
FunctionMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
from langchain_openai.chat_models.base import (
|
|
_convert_dict_to_message,
|
|
_convert_message_to_dict,
|
|
)
|
|
|
|
from langchain_xai import ChatXAI
|
|
|
|
# Model identifier passed as ``model=`` by the tests in this module that
# construct ChatXAI.
MODEL_NAME = "grok-4"
|
|
|
|
|
|
def test_initialization() -> None:
    """Check that ChatXAI can be constructed from just a model name."""
    # The test passes as long as construction does not raise.
    _ = ChatXAI(model=MODEL_NAME)
|
|
|
|
|
|
def test_xai_model_param() -> None:
    """Check that ``model`` and ``model_name`` both populate ``model_name``.

    Also checks that the params returned by ``_get_ls_params`` report
    ``"xai"`` under the ``ls_provider`` key.
    """
    built_with_model = ChatXAI(model="foo")
    assert built_with_model.model_name == "foo"

    built_with_model_name = ChatXAI(model_name="foo")  # type: ignore[call-arg]
    assert built_with_model_name.model_name == "foo"

    tracing_params = built_with_model_name._get_ls_params()
    assert tracing_params["ls_provider"] == "xai"
|
|
|
|
|
|
def test_chat_xai_invalid_streaming_params() -> None:
    """Test that an invalid streaming configuration is rejected.

    Constructing ``ChatXAI`` with ``streaming=True`` together with ``n=5``
    is expected to raise ``ValueError`` at initialization time.
    """
    with pytest.raises(ValueError):
        ChatXAI(
            model=MODEL_NAME,
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )
|
|
|
|
|
|
def test_chat_xai_extra_kwargs() -> None:
    """Check how unrecognized constructor kwargs land in ``model_kwargs``."""
    # An unknown keyword (foo) is collected into model_kwargs, while a known
    # one (max_tokens) is set as a regular attribute.
    chat = ChatXAI(model=MODEL_NAME, foo=3, max_tokens=10)  # type: ignore[call-arg]
    assert chat.max_tokens == 10
    assert chat.model_kwargs == {"foo": 3}

    # Unknown keywords are merged with an explicitly supplied model_kwargs.
    chat = ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
    merged_kwargs = {"foo": 3, "bar": 2}
    assert chat.model_kwargs == merged_kwargs

    # Supplying the same key both ways is an error.
    with pytest.raises(ValueError):
        ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
|
|
|
|
|
|
def test_function_dict_to_message_function_message() -> None:
    """Check that a ``"function"`` role dict converts to a FunctionMessage."""
    fn_name = "test_function"
    payload = json.dumps({"result": "Example #1"})
    raw_message = {"role": "function", "name": fn_name, "content": payload}

    converted = _convert_dict_to_message(raw_message)

    assert isinstance(converted, FunctionMessage)
    assert converted.name == fn_name
    assert converted.content == payload
|
|
|
|
|
|
def test_convert_dict_to_message_human() -> None:
    """Round-trip a ``"user"`` dict to a HumanMessage and back."""
    raw = {"role": "user", "content": "foo"}
    expected = HumanMessage(content="foo")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test__convert_dict_to_message_human_with_name() -> None:
    """Round-trip a named ``"user"`` dict to a HumanMessage and back."""
    raw = {"role": "user", "content": "foo", "name": "test"}
    expected = HumanMessage(content="foo", name="test")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_ai() -> None:
    """Round-trip an ``"assistant"`` dict to an AIMessage and back."""
    raw = {"role": "assistant", "content": "foo"}
    expected = AIMessage(content="foo")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_ai_with_name() -> None:
    """Round-trip a named ``"assistant"`` dict to an AIMessage and back."""
    raw = {"role": "assistant", "content": "foo", "name": "test"}
    expected = AIMessage(content="foo", name="test")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_system() -> None:
    """Round-trip a ``"system"`` dict to a SystemMessage and back."""
    raw = {"role": "system", "content": "foo"}
    expected = SystemMessage(content="foo")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_system_with_name() -> None:
    """Round-trip a named ``"system"`` dict to a SystemMessage and back."""
    raw = {"role": "system", "content": "foo", "name": "test"}
    expected = SystemMessage(content="foo", name="test")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test_convert_dict_to_message_tool() -> None:
    """Round-trip a ``"tool"`` dict (with ``tool_call_id``) to a ToolMessage and back."""
    raw = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
    expected = ToolMessage(content="foo", tool_call_id="bar")

    assert _convert_dict_to_message(raw) == expected
    assert _convert_message_to_dict(expected) == raw
|