mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(xai): inject model_provider in response_metadata (#33543)
plus tests minor rfc
This commit is contained in:
@@ -529,6 +529,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
rtn = super()._create_chat_result(response, generation_info)
|
rtn = super()._create_chat_result(response, generation_info)
|
||||||
|
|
||||||
|
for generation in rtn.generations:
|
||||||
|
generation.message.response_metadata["model_provider"] = "xai"
|
||||||
|
|
||||||
if not isinstance(response, openai.BaseModel):
|
if not isinstance(response, openai.BaseModel):
|
||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
@@ -555,6 +558,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
|||||||
default_chunk_class,
|
default_chunk_class,
|
||||||
base_generation_info,
|
base_generation_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if generation_chunk:
|
||||||
|
generation_chunk.message.response_metadata["model_provider"] = "xai"
|
||||||
|
|
||||||
if (choices := chunk.get("choices")) and generation_chunk:
|
if (choices := chunk.get("choices")) and generation_chunk:
|
||||||
top = choices[0]
|
top = choices[0]
|
||||||
if isinstance(generation_chunk.message, AIMessageChunk) and (
|
if isinstance(generation_chunk.message, AIMessageChunk) and (
|
||||||
|
|||||||
@@ -0,0 +1,96 @@
|
|||||||
|
"""Integration tests for ChatXAI specific features."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
|
||||||
|
|
||||||
|
from langchain_xai import ChatXAI
|
||||||
|
|
||||||
|
MODEL_NAME = "grok-4-fast-reasoning"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("output_version", ["", "v1"])
|
||||||
|
def test_reasoning(output_version: Literal["", "v1"]) -> None:
|
||||||
|
"""Test reasoning features.
|
||||||
|
|
||||||
|
Note: `grok-4` does not return `reasoning_content`, but may optionally return
|
||||||
|
encrypted reasoning content if `use_encrypted_content` is set to True.
|
||||||
|
"""
|
||||||
|
# Test reasoning effort
|
||||||
|
if output_version:
|
||||||
|
chat_model = ChatXAI(
|
||||||
|
model="grok-3-mini",
|
||||||
|
reasoning_effort="low",
|
||||||
|
output_version=output_version,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chat_model = ChatXAI(
|
||||||
|
model="grok-3-mini",
|
||||||
|
reasoning_effort="low",
|
||||||
|
)
|
||||||
|
input_message = "What is 3^3?"
|
||||||
|
response = chat_model.invoke(input_message)
|
||||||
|
assert response.content
|
||||||
|
assert response.additional_kwargs["reasoning_content"]
|
||||||
|
|
||||||
|
# Test streaming
|
||||||
|
full: BaseMessageChunk | None = None
|
||||||
|
for chunk in chat_model.stream(input_message):
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
assert isinstance(full, AIMessageChunk)
|
||||||
|
assert full.additional_kwargs["reasoning_content"]
|
||||||
|
|
||||||
|
# Check that we can access reasoning content blocks
|
||||||
|
assert response.content_blocks
|
||||||
|
reasoning_content = (
|
||||||
|
block for block in response.content_blocks if block["type"] == "reasoning"
|
||||||
|
)
|
||||||
|
assert len(list(reasoning_content)) >= 1
|
||||||
|
|
||||||
|
# Test that passing message with reasoning back in works
|
||||||
|
follow_up_message = "Based on your reasoning, what is 4^4?"
|
||||||
|
followup = chat_model.invoke([input_message, response, follow_up_message])
|
||||||
|
assert followup.content
|
||||||
|
assert followup.additional_kwargs["reasoning_content"]
|
||||||
|
followup_reasoning = (
|
||||||
|
block for block in followup.content_blocks if block["type"] == "reasoning"
|
||||||
|
)
|
||||||
|
assert len(list(followup_reasoning)) >= 1
|
||||||
|
|
||||||
|
# Test passing in a ReasoningContentBlock
|
||||||
|
response_metadata = {"model_provider": "xai"}
|
||||||
|
if output_version:
|
||||||
|
response_metadata["output_version"] = output_version
|
||||||
|
msg_w_reasoning = AIMessage(
|
||||||
|
content_blocks=response.content_blocks,
|
||||||
|
response_metadata=response_metadata,
|
||||||
|
)
|
||||||
|
followup_2 = chat_model.invoke(
|
||||||
|
[msg_w_reasoning, "Based on your reasoning, what is 5^5?"]
|
||||||
|
)
|
||||||
|
assert followup_2.content
|
||||||
|
assert followup_2.additional_kwargs["reasoning_content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_search() -> None:
|
||||||
|
llm = ChatXAI(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
search_parameters={"mode": "on", "max_search_results": 3},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
|
||||||
|
assert response.content
|
||||||
|
assert response.additional_kwargs["citations"]
|
||||||
|
assert len(response.additional_kwargs["citations"]) <= 3
|
||||||
|
|
||||||
|
# Test streaming
|
||||||
|
full = None
|
||||||
|
for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
|
||||||
|
full = chunk if full is None else full + chunk
|
||||||
|
assert isinstance(full, AIMessageChunk)
|
||||||
|
assert full.additional_kwargs["citations"]
|
||||||
|
assert len(full.additional_kwargs["citations"]) <= 3
|
||||||
@@ -4,26 +4,25 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||||
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
|
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||||
ChatModelIntegrationTests, # type: ignore[import-not-found]
|
ChatModelIntegrationTests, # type: ignore[import-not-found]
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_xai import ChatXAI
|
from langchain_xai import ChatXAI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
# Initialize the rate limiter in global scope, so it can be re-used
|
# Initialize the rate limiter in global scope, so it can be re-used across tests
|
||||||
# across tests.
|
|
||||||
rate_limiter = InMemoryRateLimiter(
|
rate_limiter = InMemoryRateLimiter(
|
||||||
requests_per_second=0.5,
|
requests_per_second=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_NAME = "grok-4-fast-reasoning"
|
||||||
# Not using Grok 4 since it doesn't support reasoning params (effort) or returns
|
|
||||||
# reasoning content.
|
|
||||||
|
|
||||||
|
|
||||||
class TestXAIStandard(ChatModelIntegrationTests):
|
class TestXAIStandard(ChatModelIntegrationTests):
|
||||||
@@ -33,48 +32,28 @@ class TestXAIStandard(ChatModelIntegrationTests):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
# TODO: bump to test new Grok once they implement other features
|
|
||||||
return {
|
return {
|
||||||
"model": "grok-3",
|
"model": MODEL_NAME,
|
||||||
"rate_limiter": rate_limiter,
|
"rate_limiter": rate_limiter,
|
||||||
"stream_usage": True,
|
"stream_usage": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@pytest.mark.xfail(
|
||||||
def test_reasoning_content() -> None:
|
reason="Default model does not support stop sequences, using grok-3 instead"
|
||||||
"""Test reasoning content."""
|
|
||||||
chat_model = ChatXAI(
|
|
||||||
model="grok-3-mini",
|
|
||||||
reasoning_effort="low",
|
|
||||||
)
|
)
|
||||||
response = chat_model.invoke("What is 3^3?")
|
@override
|
||||||
assert response.content
|
def test_stop_sequence(self, model: BaseChatModel) -> None:
|
||||||
assert response.additional_kwargs["reasoning_content"]
|
"""Override to use `grok-3` which supports stop sequences."""
|
||||||
|
params = {**self.chat_model_params, "model": "grok-3"}
|
||||||
|
|
||||||
# Test streaming
|
grok3_model = ChatXAI(**params)
|
||||||
full: BaseMessageChunk | None = None
|
|
||||||
for chunk in chat_model.stream("What is 3^3?"):
|
|
||||||
full = chunk if full is None else full + chunk
|
|
||||||
assert isinstance(full, AIMessageChunk)
|
|
||||||
assert full.additional_kwargs["reasoning_content"]
|
|
||||||
|
|
||||||
|
result = grok3_model.invoke("hi", stop=["you"])
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
|
||||||
def test_web_search() -> None:
|
custom_model = ChatXAI(
|
||||||
llm = ChatXAI(
|
**params,
|
||||||
model="grok-3",
|
stop_sequences=["you"],
|
||||||
search_parameters={"mode": "auto", "max_search_results": 3},
|
)
|
||||||
)
|
result = custom_model.invoke("hi")
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
# Test invoke
|
|
||||||
response = llm.invoke("Provide me a digest of world news in the last 24 hours.")
|
|
||||||
assert response.content
|
|
||||||
assert response.additional_kwargs["citations"]
|
|
||||||
assert len(response.additional_kwargs["citations"]) <= 3
|
|
||||||
|
|
||||||
# Test streaming
|
|
||||||
full = None
|
|
||||||
for chunk in llm.stream("Provide me a digest of world news in the last 24 hours."):
|
|
||||||
full = chunk if full is None else full + chunk
|
|
||||||
assert isinstance(full, AIMessageChunk)
|
|
||||||
assert full.additional_kwargs["citations"]
|
|
||||||
assert len(full.additional_kwargs["citations"]) <= 3
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
'kwargs': dict({
|
'kwargs': dict({
|
||||||
'max_retries': 2,
|
'max_retries': 2,
|
||||||
'max_tokens': 100,
|
'max_tokens': 100,
|
||||||
'model_name': 'grok-beta',
|
'model_name': 'grok-4',
|
||||||
'request_timeout': 60.0,
|
'request_timeout': 60.0,
|
||||||
'stop': list([
|
'stop': list([
|
||||||
]),
|
]),
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ from langchain_tests.unit_tests import ( # type: ignore[import-not-found]
|
|||||||
|
|
||||||
from langchain_xai import ChatXAI
|
from langchain_xai import ChatXAI
|
||||||
|
|
||||||
|
MODEL_NAME = "grok-4"
|
||||||
|
|
||||||
|
|
||||||
class TestXAIStandard(ChatModelUnitTests):
|
class TestXAIStandard(ChatModelUnitTests):
|
||||||
@property
|
@property
|
||||||
@@ -15,7 +17,7 @@ class TestXAIStandard(ChatModelUnitTests):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
return {"model": "grok-beta"}
|
return {"model": MODEL_NAME}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
def init_from_env_params(self) -> tuple[dict, dict, dict]:
|
||||||
@@ -24,7 +26,7 @@ class TestXAIStandard(ChatModelUnitTests):
|
|||||||
"XAI_API_KEY": "api_key",
|
"XAI_API_KEY": "api_key",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "grok-beta",
|
"model": MODEL_NAME,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"xai_api_key": "api_key",
|
"xai_api_key": "api_key",
|
||||||
|
|||||||
5
libs/partners/xai/uv.lock
generated
5
libs/partners/xai/uv.lock
generated
@@ -621,7 +621,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "1.0.0a8"
|
version = "1.0.0rc2"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
@@ -699,6 +699,7 @@ dev = [{ name = "langchain-core", editable = "../../core" }]
|
|||||||
lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }]
|
lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }]
|
||||||
test = [
|
test = [
|
||||||
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
|
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
|
||||||
|
{ name = "langchain", editable = "../../langchain_v1" },
|
||||||
{ name = "langchain-core", editable = "../../core" },
|
{ name = "langchain-core", editable = "../../core" },
|
||||||
{ name = "langchain-tests", editable = "../../standard-tests" },
|
{ name = "langchain-tests", editable = "../../standard-tests" },
|
||||||
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" },
|
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" },
|
||||||
@@ -728,7 +729,7 @@ typing = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-tests"
|
name = "langchain-tests"
|
||||||
version = "1.0.0a2"
|
version = "1.0.0rc1"
|
||||||
source = { editable = "../../standard-tests" }
|
source = { editable = "../../standard-tests" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
|
|||||||
Reference in New Issue
Block a user