mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
Add `base_url` alias and `XAI_API_BASE` env variable support to
`ChatXAI.xai_api_base`, aligning the xAI integration with the pattern
used across other partner packages (OpenAI, Groq, Fireworks, etc.).
Previously the base URL was a plain string field with no alias or
env-var lookup, making it inconsistent with the rest of the ecosystem
and harder to configure in deployment environments.
## Changes
- Add `alias="base_url"` and `default_factory=from_env("XAI_API_BASE",
default="https://api.x.ai/v1/")` to `ChatXAI.xai_api_base`, matching the
convention in `langchain_openai`, `langchain_groq`, and
`langchain_fireworks`
167 lines · 5.2 KiB · Python
import json
|
|
|
|
import pytest # type: ignore[import-not-found]
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
FunctionMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
from langchain_openai.chat_models.base import (
|
|
_convert_dict_to_message,
|
|
_convert_message_to_dict,
|
|
)
|
|
from pydantic import SecretStr
|
|
|
|
from langchain_xai import ChatXAI
|
|
|
|
# Model name used when constructing ChatXAI in most tests of this module.
MODEL_NAME: str = "grok-4"
|
|
|
|
|
|
def test_initialization() -> None:
    """Smoke test: ChatXAI can be constructed from only a model name."""
    ChatXAI(model=MODEL_NAME)
|
|
|
|
|
|
def test_profile() -> None:
    """A freshly constructed ChatXAI exposes a truthy (non-empty) ``profile``."""
    # Use the shared module constant rather than repeating the model literal,
    # so a model bump only has to happen in one place.
    model = ChatXAI(model=MODEL_NAME)
    assert model.profile
|
|
|
|
|
|
def test_xai_model_param() -> None:
    """Both ``model=`` and ``model_name=`` populate ``model_name``.

    Also checks that ``_get_ls_params()`` reports ``"xai"`` as ``ls_provider``.
    """
    via_model_kw = ChatXAI(model="foo")
    assert via_model_kw.model_name == "foo"

    via_field_name = ChatXAI(model_name="foo")  # type: ignore[call-arg]
    assert via_field_name.model_name == "foo"

    provider = via_field_name._get_ls_params().get("ls_provider")
    assert provider == "xai"
|
|
|
|
|
|
def test_chat_xai_invalid_streaming_params() -> None:
    """Test that ChatXAI rejects ``streaming=True`` combined with ``n=5``.

    Construction itself is expected to raise ``ValueError``; no request is
    made. The other arguments are ordinary values, so the failure is
    presumably triggered by asking for several completions while streaming.
    """
    with pytest.raises(ValueError):
        ChatXAI(
            model=MODEL_NAME,
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )
|
|
|
|
|
|
def test_chat_xai_extra_kwargs() -> None:
    """Unknown constructor keywords are collected into ``model_kwargs``."""
    # A recognised field (max_tokens) stays a field; the unknown ``foo``
    # is moved into model_kwargs.
    chat = ChatXAI(model=MODEL_NAME, foo=3, max_tokens=10)  # type: ignore[call-arg]
    assert chat.max_tokens == 10
    assert chat.model_kwargs == {"foo": 3}

    # Unknown keywords are merged with an explicitly passed model_kwargs.
    chat = ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
    assert chat.model_kwargs == {"foo": 3, "bar": 2}

    # The same key supplied both ways is ambiguous and must be rejected.
    with pytest.raises(ValueError):
        ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
|
|
|
|
|
|
def test_chat_xai_base_url_alias() -> None:
    """``base_url`` is accepted as an alias that sets ``xai_api_base``.

    The alias must populate the field instead of being treated as an unknown
    keyword and collected into ``model_kwargs``.
    """
    target_url = "http://example.test/v1"
    chat = ChatXAI(
        model=MODEL_NAME,
        api_key=SecretStr("test-api-key"),
        base_url=target_url,
    )
    assert chat.xai_api_base == target_url
    assert chat.model_kwargs == {}
|
|
|
|
|
|
def test_chat_xai_api_base_from_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """``xai_api_base`` is read from ``XAI_API_BASE`` when no URL is passed."""
    monkeypatch.setenv("XAI_API_BASE", "http://env.example.test/v1")

    llm = ChatXAI(
        model=MODEL_NAME,
        api_key=SecretStr("test-api-key"),
    )

    assert llm.xai_api_base == "http://env.example.test/v1"


def test_chat_xai_api_base_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without ``base_url`` or ``XAI_API_BASE``, the built-in default is used."""
    # Remove any value inherited from the developer's or CI environment so the
    # default branch of the env-var lookup is exercised deterministically.
    monkeypatch.delenv("XAI_API_BASE", raising=False)

    llm = ChatXAI(
        model=MODEL_NAME,
        api_key=SecretStr("test-api-key"),
    )

    assert llm.xai_api_base == "https://api.x.ai/v1/"
|
|
|
|
|
|
def test_function_dict_to_message_function_message() -> None:
    """A ``role="function"`` dict converts to a ``FunctionMessage``."""
    fn_name = "test_function"
    payload = json.dumps({"result": "Example #1"})

    converted = _convert_dict_to_message(
        {"role": "function", "name": fn_name, "content": payload}
    )

    assert isinstance(converted, FunctionMessage)
    assert converted.name == fn_name
    assert converted.content == payload
|
|
|
|
|
|
def test_convert_dict_to_message_human() -> None:
    """A ``user`` dict round-trips to and from ``HumanMessage``."""
    raw = {"role": "user", "content": "foo"}
    expected = HumanMessage(content="foo")

    assert _convert_dict_to_message(raw) == expected
    # The reverse conversion must reproduce the original dict exactly.
    assert _convert_message_to_dict(expected) == raw
|
|
|
|
|
|
def test__convert_dict_to_message_human_with_name() -> None:
    """A named ``user`` dict round-trips, keeping the ``name`` field."""
    raw = {"role": "user", "content": "foo", "name": "test"}
    expected = HumanMessage(content="foo", name="test")

    converted = _convert_dict_to_message(raw)
    assert converted == expected

    back = _convert_message_to_dict(expected)
    assert back == raw
|
|
|
|
|
|
def test_convert_dict_to_message_ai() -> None:
    """An ``assistant`` dict round-trips to and from ``AIMessage``."""
    as_dict = {"role": "assistant", "content": "foo"}
    as_message = AIMessage(content="foo")

    # dict -> message
    assert _convert_dict_to_message(as_dict) == as_message
    # message -> dict
    assert _convert_message_to_dict(as_message) == as_dict
|
|
|
|
|
|
def test_convert_dict_to_message_ai_with_name() -> None:
    """A named ``assistant`` dict round-trips, keeping the ``name`` field."""
    as_dict = {"role": "assistant", "content": "foo", "name": "test"}
    as_message = AIMessage(content="foo", name="test")

    converted = _convert_dict_to_message(as_dict)
    assert converted == as_message

    assert _convert_message_to_dict(as_message) == as_dict
|
|
|
|
|
|
def test_convert_dict_to_message_system() -> None:
    """A ``system`` dict round-trips to and from ``SystemMessage``."""
    payload = {"role": "system", "content": "foo"}
    want = SystemMessage(content="foo")

    got = _convert_dict_to_message(payload)
    assert got == want
    # Converting back must yield the same dict we started from.
    assert _convert_message_to_dict(want) == payload
|
|
|
|
|
|
def test_convert_dict_to_message_system_with_name() -> None:
    """A named ``system`` dict round-trips, keeping the ``name`` field."""
    payload = {"role": "system", "content": "foo", "name": "test"}
    want = SystemMessage(content="foo", name="test")

    assert _convert_dict_to_message(payload) == want
    assert _convert_message_to_dict(want) == payload
|
|
|
|
|
|
def test_convert_dict_to_message_tool() -> None:
    """A ``tool`` dict round-trips to and from ``ToolMessage``.

    The ``tool_call_id`` must survive conversion in both directions.
    """
    payload = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
    want = ToolMessage(content="foo", tool_call_id="bar")

    got = _convert_dict_to_message(payload)
    assert got == want

    round_tripped = _convert_message_to_dict(want)
    assert round_tripped == payload
|
|
|
|
|
|
def test_stream_usage_metadata() -> None:
    """``stream_usage`` defaults to True and can be disabled explicitly."""
    default_model = ChatXAI(model=MODEL_NAME)
    assert default_model.stream_usage is True

    opted_out = ChatXAI(model=MODEL_NAME, stream_usage=False)
    assert opted_out.stream_usage is False
|