docs(xai): update for Grok 4 (#31953)

Mason Daugherty 2025-07-10 11:06:37 -04:00 committed by GitHub
parent 060fc0e3c9
commit 6594eb8cc1
11 changed files with 168 additions and 61 deletions


@ -345,7 +345,7 @@
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all `ChatXAI` features and configurations, head to the API reference: https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
"For detailed documentation of all `ChatXAI` features and configurations, head to the [API reference](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html)."
]
}
],


@ -63,7 +63,7 @@
"\n",
"chat = ChatXAI(\n",
" # xai_api_key=\"YOUR_API_KEY\",\n",
" model=\"grok-beta\",\n",
" model=\"grok-4\",\n",
")\n",
"\n",
"# stream the response back from the model\n",


@ -395,10 +395,10 @@ class ChatGroq(BaseChatModel):
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None
"""Optional httpx.Client."""
"""Optional ``httpx.Client``."""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
"""Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
``http_client`` as well if you'd like a custom client for sync invocations."""
model_config = ConfigDict(
populate_by_name=True,


@ -76,13 +76,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
Max number of retries.
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
var OPENAI_ORG_ID.
var ``OPENAI_ORG_ID``.
model: Optional[str]
The name of the underlying OpenAI model. Used for tracing and token
counting. Does not affect completion. E.g. "gpt-4", "gpt-35-turbo", etc.
counting. Does not affect completion. E.g. ``'gpt-4'``, ``'gpt-35-turbo'``, etc.
model_version: Optional[str]
The version of the underlying OpenAI model. Used for tracing and token
counting. Does not affect completion. E.g., "0125", "0125-preview", etc.
counting. Does not affect completion. E.g., ``'0125'``, ``'0125-preview'``, etc.
See full list of supported init args and their descriptions in the params section.


@ -542,12 +542,13 @@ class BaseChatOpenAI(BaseChatModel):
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = Field(default=None, exclude=True)
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""Optional ``httpx.Client``. Only used for sync invocations. Must specify
``http_async_client`` as well if you'd like a custom client for async
invocations.
"""
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
``http_client`` as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
extra_body: Optional[Mapping[str, Any]] = None
@ -588,8 +589,8 @@ class BaseChatOpenAI(BaseChatModel):
"""
service_tier: Optional[str] = None
"""Latency tier for request. Options are 'auto', 'default', or 'flex'. Relevant
for users of OpenAI's scale tier service.
"""Latency tier for request. Options are ``'auto'``, ``'default'``, or ``'flex'``.
Relevant for users of OpenAI's scale tier service.
"""
store: Optional[bool] = None
@ -600,8 +601,8 @@ class BaseChatOpenAI(BaseChatModel):
"""
truncation: Optional[str] = None
"""Truncation strategy (Responses API). Can be ``"auto"`` or ``"disabled"``
(default). If ``"auto"``, model may drop input items from the middle of the
"""Truncation strategy (Responses API). Can be ``'auto'`` or ``'disabled'``
(default). If ``'auto'``, model may drop input items from the middle of the
message sequence to fit the context window.
.. versionadded:: 0.3.24
@ -1451,7 +1452,7 @@ class BaseChatOpenAI(BaseChatModel):
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
"""Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package.
**Requirements**: You must have ``pillow`` installed if you want to count
image tokens when specifying the image as a base64 string, and you must
@ -1459,14 +1460,13 @@ class BaseChatOpenAI(BaseChatModel):
as a URL. If these aren't installed, image inputs will be ignored in token
counting.
OpenAI reference: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
`OpenAI reference <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb>`__
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
"""
""" # noqa: E501
# TODO: Count bound tools as part of input.
if tools is not None:
warnings.warn(
@ -2036,13 +2036,13 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
max_retries: Optional[int]
Max number of retries.
api_key: Optional[str]
OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``.
base_url: Optional[str]
Base URL for API requests. Only specify if using a proxy or service
emulator.
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
var OPENAI_ORG_ID.
var ``OPENAI_ORG_ID``.
See full list of supported init args and their descriptions in the params section.


@ -93,14 +93,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Name of OpenAI model to use.
dimensions: Optional[int] = None
The number of dimensions the resulting output embeddings should have.
Only supported in `text-embedding-3` and later models.
Only supported in ``'text-embedding-3'`` and later models.
Key init args client params:
api_key: Optional[SecretStr] = None
OpenAI API key.
organization: Optional[str] = None
OpenAI organization ID. If not passed in will be read
from env var OPENAI_ORG_ID.
from env var ``OPENAI_ORG_ID``.
max_retries: int = 2
Maximum number of retries to make when generating.
request_timeout: Optional[Union[float, Tuple[float, float], Any]] = None
@ -194,14 +194,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
"""Automatically inferred from env var ``OPENAI_API_KEY`` if not provided."""
openai_organization: Optional[str] = Field(
alias="organization",
default_factory=from_env(
["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
),
)
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
"""Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
allowed_special: Union[Literal["all"], set[str], None] = None
disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None
chunk_size: int = 1000
@ -211,12 +211,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
"""Timeout for requests to OpenAI completion API. Can be float, ``httpx.Timeout`` or
None."""
headers: Any = None
tiktoken_enabled: bool = True
"""Set this to False for non-OpenAI implementations of the embeddings API, e.g.
the `--extensions openai` extension for `text-generation-webui`"""
the ``--extensions openai`` extension for ``text-generation-webui``"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
@ -243,12 +243,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
retry_max_seconds: int = 20
"""Max number of seconds to wait between retries"""
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""Optional ``httpx.Client``. Only used for sync invocations. Must specify
``http_async_client`` as well if you'd like a custom client for async
invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
"""Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
``http_client`` as well if you'd like a custom client for sync invocations."""
check_embedding_ctx_length: bool = True
"""Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length."""
@ -289,8 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Validate that api key and python package exists in environment."""
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
raise ValueError(
"If you are using Azure, "
"please use the `AzureOpenAIEmbeddings` class."
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
)
client_params: dict = {
"api_key": (


@ -76,7 +76,7 @@ class BaseOpenAI(BaseLLM):
openai_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
"""Automatically inferred from env var ``OPENAI_API_KEY`` if not provided."""
openai_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
)
@ -88,7 +88,7 @@ class BaseOpenAI(BaseLLM):
["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
),
)
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
"""Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = Field(
default_factory=from_env("OPENAI_PROXY", default=None)
@ -130,12 +130,13 @@ class BaseOpenAI(BaseLLM):
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""Optional ``httpx.Client``. Only used for sync invocations. Must specify
``http_async_client`` as well if you'd like a custom client for async
invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
"""Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
``http_client`` as well if you'd like a custom client for sync invocations."""
extra_body: Optional[Mapping[str, Any]] = None
"""Optional additional JSON properties to include in the request parameters when
making requests to OpenAI compatible APIs, such as vLLM."""
@ -606,13 +607,13 @@ class OpenAI(BaseOpenAI):
max_retries: int
Max number of retries.
api_key: Optional[str]
OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``.
base_url: Optional[str]
Base URL for API requests. Only specify if using a proxy or service
emulator.
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
var OPENAI_ORG_ID.
var ``OPENAI_ORG_ID``.
See full list of supported init args and their descriptions in the params section.


@ -29,6 +29,9 @@ _DictOrPydantic = Union[dict, _BM]
class ChatXAI(BaseChatOpenAI): # type: ignore[override]
r"""ChatXAI chat model.
Refer to `xAI's documentation <https://docs.x.ai/docs/api-reference#chat-completions>`__
for more nuanced details on the API's behavior and supported parameters.
Setup:
Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``.
@ -42,9 +45,12 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
model: str
Name of model to use.
temperature: float
Sampling temperature.
Sampling temperature between ``0`` and ``2``. Higher values mean more random completions,
while lower values (like ``0.2``) mean more focused and deterministic completions.
(Default: ``1``.)
max_tokens: Optional[int]
Max number of tokens to generate.
Max number of tokens to generate. Refer to your `model's documentation <https://docs.x.ai/docs/models#model-pricing>`__
for the maximum number of tokens it can generate.
logprobs: Optional[bool]
Whether to return logprobs.
@ -62,7 +68,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
from langchain_xai import ChatXAI
llm = ChatXAI(
model="grok-beta",
model="grok-4",
temperature=0,
max_tokens=None,
timeout=None,
@ -89,7 +95,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
content="J'adore la programmation.",
response_metadata={
'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
'model_name': 'grok-beta',
'model_name': 'grok-4',
'system_fingerprint': None,
'finish_reason': 'stop',
'logprobs': None
@ -113,7 +119,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
content=' programm' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='ation' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='.' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-beta'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-4'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
Async:
@ -133,7 +139,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
content="J'adore la programmation.",
response_metadata={
'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
'model_name': 'grok-beta',
'model_name': 'grok-4',
'system_fingerprint': None,
'finish_reason': 'stop',
'logprobs': None
@ -141,12 +147,39 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
id='run-09371a11-7f72-4c53-8e7c-9de5c238b34c-0',
usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})
Tool calling:
Reasoning:
`Certain xAI models <https://docs.x.ai/docs/models#model-pricing>`__ support reasoning,
which allows the model to provide reasoning content along with the response.
If the model provides it, reasoning content is returned under the ``additional_kwargs`` field
of the ``AIMessage`` or ``AIMessageChunk``.
If supported, reasoning effort can be specified via the model constructor's ``extra_body``
argument, which controls the amount of reasoning the model performs. The value can be one of
``'low'`` or ``'high'``.
.. code-block:: python
model = ChatXAI(
model="grok-3-mini",
extra_body={"reasoning_effort": "high"},
)
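When the model returns reasoning, it can be read off the message's
``additional_kwargs``. A minimal sketch (the ``reasoning_content`` key is only
populated by models that expose it; the prompt is illustrative):

.. code-block:: python

    response = model.invoke("What is 3^3?")
    # Final answer
    print(response.content)
    # Reasoning trace, if the model exposes one
    print(response.additional_kwargs.get("reasoning_content"))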
.. note::
As of 2025-07-10, ``reasoning_content`` is only returned in Grok 3 models, such as
`Grok 3 Mini <https://docs.x.ai/docs/models/grok-3-mini>`__.
.. note::
As of 2025-07-10, `Grok 4 <https://docs.x.ai/docs/models/grok-4-0709>`__ does not expose
reasoning in ``reasoning_content`` (beyond an initial ``'Thinking...'`` text),
reasoning cannot be disabled, and ``reasoning_effort`` cannot be specified.
Tool calling / function calling:
.. code-block:: python
from pydantic import BaseModel, Field
llm = ChatXAI(model="grok-beta")
llm = ChatXAI(model="grok-4")
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
@ -168,7 +201,6 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
)
ai_msg.tool_calls
.. code-block:: python
[
@ -186,6 +218,67 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
}
]
.. note::
With streaming responses, the tool / function call is returned whole in a
single chunk instead of being streamed incrementally across chunks.
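For example, a minimal sketch (reusing the ``GetWeather`` tool defined above;
the model name and prompt are illustrative):

.. code-block:: python

    llm_with_tools = ChatXAI(model="grok-4").bind_tools([GetWeather])
    for chunk in llm_with_tools.stream("What's the weather in San Francisco?"):
        # The complete tool call arrives in one chunk rather than piecewise
        if chunk.tool_calls:
            print(chunk.tool_calls)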
Tool choice can be controlled by setting the ``tool_choice`` parameter in the model
constructor's ``extra_body`` argument. For example, to disable tool / function calling:
.. code-block:: python
llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "none"})
To require that the model always calls a tool / function, set ``tool_choice`` to ``'required'``:
.. code-block:: python
llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "required"})
To specify a tool / function to call, set ``tool_choice`` to the name of the tool / function:
.. code-block:: python
from pydantic import BaseModel, Field
llm = ChatXAI(
model="grok-4",
extra_body={
"tool_choice": {"type": "function", "function": {"name": "GetWeather"}}
},
)
class GetWeather(BaseModel):
\"\"\"Get the current weather in a given location\"\"\"
location: str = Field(..., description='The city and state, e.g. San Francisco, CA')
class GetPopulation(BaseModel):
\"\"\"Get the current population in a given location\"\"\"
location: str = Field(..., description='The city and state, e.g. San Francisco, CA')
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke(
"Which city is bigger: LA or NY?",
)
ai_msg.tool_calls
The resulting tool call would be:
.. code-block:: python
[{'name': 'GetWeather',
'args': {'location': 'Los Angeles, CA'},
'id': 'call_81668711',
'type': 'tool_call'}]
Parallel tool calling / parallel function calling:
By default, parallel tool / function calling is enabled, so you can process
multiple function calls in one request/response cycle. When two or more tool calls
are required, all of the tool call requests will be included in the response body.
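A minimal sketch (reusing the ``GetWeather`` and ``GetPopulation`` tools defined
above; exact tool call IDs and ordering will vary):

.. code-block:: python

    llm_with_tools = ChatXAI(model="grok-4").bind_tools([GetWeather, GetPopulation])
    ai_msg = llm_with_tools.invoke(
        "What is the weather and population of San Francisco?"
    )
    # A single response can carry multiple tool calls
    for tool_call in ai_msg.tool_calls:
        print(tool_call["name"], tool_call["args"])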
Structured output:
.. code-block:: python
@ -222,7 +315,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
from langchain_xai import ChatXAI
llm = ChatXAI(
model="grok-3-latest",
model="grok-4",
search_parameters={
"mode": "auto",
# Example optional parameters below:
@ -234,6 +327,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
llm.invoke("Provide me a digest of world news in the last 24 hours.")
.. note::
`Citations <https://docs.x.ai/docs/guides/live-search#returning-citations>`__
are only available in `Grok 3 <https://docs.x.ai/docs/models/grok-3>`__.
Token usage:
.. code-block:: python
@ -275,7 +372,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
'prompt_tokens': 19,
'total_tokens': 23
},
'model_name': 'grok-beta',
'model_name': 'grok-4',
'system_fingerprint': None,
'finish_reason': 'stop',
'logprobs': None
@ -283,7 +380,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
""" # noqa: E501
model_name: str = Field(alias="model")
model_name: str = Field(default="grok-4", alias="model")
"""Model name to use."""
xai_api_key: Optional[SecretStr] = Field(
alias="api_key",


@ -18,6 +18,10 @@ rate_limiter = InMemoryRateLimiter(
)
# Not using Grok 4 since it doesn't support reasoning params (effort) and doesn't
# return reasoning content.
class TestXAIStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> type[BaseChatModel]:
@ -25,6 +29,7 @@ class TestXAIStandard(ChatModelIntegrationTests):
@property
def chat_model_params(self) -> dict:
# TODO: bump to test new Grok once they implement other features
return {
"model": "grok-3",
"rate_limiter": rate_limiter,
@ -35,7 +40,7 @@ class TestXAIStandard(ChatModelIntegrationTests):
def test_reasoning_content() -> None:
"""Test reasoning content."""
chat_model = ChatXAI(
model="grok-3-mini-beta",
model="grok-3-mini",
reasoning_effort="low",
)
response = chat_model.invoke("What is 3^3?")
@ -52,7 +57,7 @@ def test_reasoning_content() -> None:
def test_web_search() -> None:
llm = ChatXAI(
model="grok-3-latest",
model="grok-3",
search_parameters={"mode": "auto", "max_search_results": 3},
)


@ -15,10 +15,12 @@ from langchain_openai.chat_models.base import (
from langchain_xai import ChatXAI
MODEL_NAME = "grok-4"
def test_initialization() -> None:
"""Test chat model initialization."""
ChatXAI(model="grok-beta")
ChatXAI(model=MODEL_NAME)
def test_xai_model_param() -> None:
@ -34,7 +36,7 @@ def test_chat_xai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError):
ChatXAI(
model="grok-beta",
model=MODEL_NAME,
max_tokens=10,
streaming=True,
temperature=0,
@ -45,17 +47,17 @@ def test_chat_xai_invalid_streaming_params() -> None:
def test_chat_xai_extra_kwargs() -> None:
"""Test extra kwargs to chat xai."""
# Check that foo is saved in extra_kwargs.
llm = ChatXAI(model="grok-beta", foo=3, max_tokens=10) # type: ignore[call-arg]
llm = ChatXAI(model=MODEL_NAME, foo=3, max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = ChatXAI(model="grok-beta", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
llm = ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatXAI(model="grok-beta", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
def test_function_dict_to_message_function_message() -> None:


@ -1,7 +1,9 @@
from langchain_xai import ChatXAI
MODEL_NAME = "grok-4"
def test_chat_xai_secrets() -> None:
o = ChatXAI(model="grok-beta", xai_api_key="foo") # type: ignore[call-arg]
o = ChatXAI(model=MODEL_NAME, xai_api_key="foo") # type: ignore[call-arg]
s = str(o)
assert "foo" not in s