docs(xai): update for Grok 4 (#31953)

Mason Daugherty 2025-07-10 11:06:37 -04:00 committed by GitHub
parent 060fc0e3c9
commit 6594eb8cc1
11 changed files with 168 additions and 61 deletions

View File

@@ -345,7 +345,7 @@
 "source": [
 "## API reference\n",
 "\n",
-"For detailed documentation of all `ChatXAI` features and configurations, head to the API reference: https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
+"For detailed documentation of all `ChatXAI` features and configurations, head to the [API reference](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html)."
 ]
 }
 ],

View File

@@ -63,7 +63,7 @@
 "\n",
 "chat = ChatXAI(\n",
 "    # xai_api_key=\"YOUR_API_KEY\",\n",
-"    model=\"grok-beta\",\n",
+"    model=\"grok-4\",\n",
 ")\n",
 "\n",
 "# stream the response back from the model\n",

View File

@@ -395,10 +395,10 @@ class ChatGroq(BaseChatModel):
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
     http_client: Union[Any, None] = None
-    """Optional httpx.Client."""
+    """Optional ``httpx.Client``."""
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     model_config = ConfigDict(
         populate_by_name=True,
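For context, a minimal sketch of wiring up both clients (assumes ``langchain-groq`` and ``httpx`` are installed; the proxy URL and model name are hypothetical):

.. code-block:: python

    import httpx
    from langchain_groq import ChatGroq

    # Supply both clients so sync and async invocations share the same
    # custom transport settings.
    llm = ChatGroq(
        model="llama-3.1-8b-instant",
        http_client=httpx.Client(proxy="http://localhost:8899"),
        http_async_client=httpx.AsyncClient(proxy="http://localhost:8899"),
    )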

View File

@@ -76,13 +76,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
         Max number of retries.
     organization: Optional[str]
         OpenAI organization ID. If not passed in will be read from env
-        var OPENAI_ORG_ID.
+        var ``OPENAI_ORG_ID``.
     model: Optional[str]
         The name of the underlying OpenAI model. Used for tracing and token
-        counting. Does not affect completion. E.g. "gpt-4", "gpt-35-turbo", etc.
+        counting. Does not affect completion. E.g. ``'gpt-4'``, ``'gpt-35-turbo'``, etc.
     model_version: Optional[str]
         The version of the underlying OpenAI model. Used for tracing and token
-        counting. Does not affect completion. E.g., "0125", "0125-preview", etc.
+        counting. Does not affect completion. E.g., ``'0125'``, ``'0125-preview'``, etc.

     See full list of supported init args and their descriptions in the params section.
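For reference, a hedged sketch of how these tracing-only fields are set (the deployment name is hypothetical, and ``AZURE_OPENAI_ENDPOINT`` / ``AZURE_OPENAI_API_KEY`` are assumed to be in the environment):

.. code-block:: python

    from langchain_openai import AzureChatOpenAI

    # model and model_version only affect tracing and token counting;
    # the deployment decides which model actually serves the request.
    llm = AzureChatOpenAI(
        azure_deployment="my-gpt4-deployment",  # hypothetical deployment name
        api_version="2024-06-01",
        model="gpt-4",
        model_version="0125-preview",
    )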

View File

@@ -542,12 +542,13 @@ class BaseChatOpenAI(BaseChatModel):
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
     http_client: Union[Any, None] = Field(default=None, exclude=True)
-    """Optional httpx.Client. Only used for sync invocations. Must specify
-    http_async_client as well if you'd like a custom client for async invocations.
+    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
+    ``http_async_client`` as well if you'd like a custom client for async
+    invocations.
     """
     http_async_client: Union[Any, None] = Field(default=None, exclude=True)
     """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
     extra_body: Optional[Mapping[str, Any]] = None
@@ -588,8 +589,8 @@ class BaseChatOpenAI(BaseChatModel):
     """
     service_tier: Optional[str] = None
-    """Latency tier for request. Options are 'auto', 'default', or 'flex'. Relevant
-    for users of OpenAI's scale tier service.
+    """Latency tier for request. Options are ``'auto'``, ``'default'``, or ``'flex'``.
+    Relevant for users of OpenAI's scale tier service.
     """
     store: Optional[bool] = None
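As an illustrative sketch (flex processing is only available for certain models and account tiers, so treat the model choice here as an assumption):

.. code-block:: python

    from langchain_openai import ChatOpenAI

    # 'flex' trades higher latency for lower cost on supported models.
    llm = ChatOpenAI(model="o4-mini", service_tier="flex")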
@@ -600,8 +601,8 @@ class BaseChatOpenAI(BaseChatModel):
     """
     truncation: Optional[str] = None
-    """Truncation strategy (Responses API). Can be ``"auto"`` or ``"disabled"``
-    (default). If ``"auto"``, model may drop input items from the middle of the
+    """Truncation strategy (Responses API). Can be ``'auto'`` or ``'disabled'``
+    (default). If ``'auto'``, model may drop input items from the middle of the
     message sequence to fit the context window.

     .. versionadded:: 0.3.24
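A brief sketch of opting in (assumes a ``langchain-openai`` version that exposes the Responses API via ``use_responses_api``):

.. code-block:: python

    from langchain_openai import ChatOpenAI

    # With truncation='auto', overlong input is trimmed from the middle of
    # the message sequence rather than raising a context-window error.
    llm = ChatOpenAI(model="gpt-4.1-mini", use_responses_api=True, truncation="auto")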
@@ -1451,7 +1452,7 @@ class BaseChatOpenAI(BaseChatModel):
             Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
         ] = None,
     ) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+        """Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package.

         **Requirements**: You must have the ``pillow`` installed if you want to count
         image tokens if you are specifying the image as a base64 string, and you must
@@ -1459,14 +1460,13 @@ class BaseChatOpenAI(BaseChatModel):
         as a URL. If these aren't installed image inputs will be ignored in token
         counting.

-        OpenAI reference: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
+        `OpenAI reference <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb>`__

         Args:
             messages: The message inputs to tokenize.
             tools: If provided, sequence of dict, BaseModel, function, or BaseTools
                 to be converted to tool schemas.
-        """
+        """  # noqa: E501
         # TODO: Count bound tools as part of input.
         if tools is not None:
             warnings.warn(
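Usage is a one-liner once the model is constructed; a sketch:

.. code-block:: python

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4")
    # Counts prompt-side tokens via tiktoken, including message formatting
    # overhead; bound tools are not yet counted (see the TODO above).
    n_tokens = llm.get_num_tokens_from_messages([HumanMessage("Hello, world!")])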
@@ -2036,13 +2036,13 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
-            OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
+            OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``.
         base_url: Optional[str]
             Base URL for API requests. Only specify if using a proxy or service
             emulator.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
-            var OPENAI_ORG_ID.
+            var ``OPENAI_ORG_ID``.

         See full list of supported init args and their descriptions in the params section.

View File

@@ -93,14 +93,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             Name of OpenAI model to use.
         dimensions: Optional[int] = None
             The number of dimensions the resulting output embeddings should have.
-            Only supported in `text-embedding-3` and later models.
+            Only supported in ``'text-embedding-3'`` and later models.

     Key init args client params:
         api_key: Optional[SecretStr] = None
             OpenAI API key.
         organization: Optional[str] = None
             OpenAI organization ID. If not passed in will be read
-            from env var OPENAI_ORG_ID.
+            from env var ``OPENAI_ORG_ID``.
         max_retries: int = 2
             Maximum number of retries to make when generating.
         request_timeout: Optional[Union[float, Tuple[float, float], Any]] = None
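For example, a short sketch of requesting reduced dimensionality (supported by the ``text-embedding-3`` family):

.. code-block:: python

    from langchain_openai import OpenAIEmbeddings

    # Older embedding models ignore or reject the dimensions parameter.
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=256)
    vector = embeddings.embed_query("hello world")  # len(vector) == 256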
@@ -194,14 +194,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     openai_api_key: Optional[SecretStr] = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
-    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+    """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided."""
     openai_organization: Optional[str] = Field(
         alias="organization",
         default_factory=from_env(
             ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
         ),
     )
-    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+    """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
     allowed_special: Union[Literal["all"], set[str], None] = None
     disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None
     chunk_size: int = 1000
@@ -211,12 +211,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field(
         default=None, alias="timeout"
     )
-    """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
-    None."""
+    """Timeout for requests to OpenAI completion API. Can be float, ``httpx.Timeout`` or
+    None."""
     headers: Any = None
     tiktoken_enabled: bool = True
-    """Set this to False for non-OpenAI implementations of the embeddings API, e.g.
-    the `--extensions openai` extension for `text-generation-webui`"""
+    """Set this to False for non-OpenAI implementations of the embeddings API, e.g.
+    the ``--extensions openai`` extension for ``text-generation-webui``"""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.

     Tiktoken is used to count the number of tokens in documents to constrain
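A hedged sketch of the non-OpenAI case (the endpoint, key, and model name here are hypothetical placeholders for an OpenAI-compatible server):

.. code-block:: python

    from langchain_openai import OpenAIEmbeddings

    # A local OpenAI-compatible backend uses its own tokenizer, so skip
    # tiktoken-based preprocessing on the client side.
    embeddings = OpenAIEmbeddings(
        model="local-embedding-model",        # hypothetical
        base_url="http://localhost:5000/v1",  # hypothetical
        api_key="not-needed",
        tiktoken_enabled=False,
    )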
@@ -243,12 +243,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     retry_max_seconds: int = 20
     """Max number of seconds to wait between retries"""
     http_client: Union[Any, None] = None
-    """Optional httpx.Client. Only used for sync invocations. Must specify
-    http_async_client as well if you'd like a custom client for async invocations.
+    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
+    ``http_async_client`` as well if you'd like a custom client for async
+    invocations.
     """
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     check_embedding_ctx_length: bool = True
     """Whether to check the token length of inputs and automatically split inputs
     longer than embedding_ctx_length."""
@@ -289,8 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """Validate that api key and python package exists in environment."""
         if self.openai_api_type in ("azure", "azure_ad", "azuread"):
             raise ValueError(
-                "If you are using Azure, "
-                "please use the `AzureOpenAIEmbeddings` class."
+                "If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
             )
         client_params: dict = {
             "api_key": (

View File

@@ -76,7 +76,7 @@ class BaseOpenAI(BaseLLM):
     openai_api_key: Optional[SecretStr] = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
-    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+    """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided."""
     openai_api_base: Optional[str] = Field(
         alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
     )
@@ -88,7 +88,7 @@ class BaseOpenAI(BaseLLM):
             ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
         ),
     )
-    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+    """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
     # to support explicit proxy for OpenAI
     openai_proxy: Optional[str] = Field(
         default_factory=from_env("OPENAI_PROXY", default=None)
@@ -130,12 +130,13 @@ class BaseOpenAI(BaseLLM):
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
     http_client: Union[Any, None] = None
-    """Optional httpx.Client. Only used for sync invocations. Must specify
-    http_async_client as well if you'd like a custom client for async invocations.
+    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
+    ``http_async_client`` as well if you'd like a custom client for async
+    invocations.
     """
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     extra_body: Optional[Mapping[str, Any]] = None
     """Optional additional JSON properties to include in the request parameters when
     making requests to OpenAI compatible APIs, such as vLLM."""
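For instance, a sketch of passing a vLLM-specific sampling parameter through ``extra_body`` (the server URL and model are assumptions about a local vLLM deployment):

.. code-block:: python

    from langchain_openai import OpenAI

    # top_k is a vLLM extension, not an OpenAI parameter, so it travels in
    # extra_body rather than as a first-class argument.
    llm = OpenAI(
        model="meta-llama/Llama-3.1-8B-Instruct",  # assumes a vLLM server
        base_url="http://localhost:8000/v1",
        api_key="EMPTY",
        extra_body={"top_k": 40},
    )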
@@ -606,13 +607,13 @@ class OpenAI(BaseOpenAI):
         max_retries: int
             Max number of retries.
         api_key: Optional[str]
-            OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
+            OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``.
         base_url: Optional[str]
             Base URL for API requests. Only specify if using a proxy or service
             emulator.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
-            var OPENAI_ORG_ID.
+            var ``OPENAI_ORG_ID``.

         See full list of supported init args and their descriptions in the params section.

View File

@@ -29,6 +29,9 @@ _DictOrPydantic = Union[dict, _BM]
 class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     r"""ChatXAI chat model.

+    Refer to `xAI's documentation <https://docs.x.ai/docs/api-reference#chat-completions>`__
+    for more nuanced details on the API's behavior and supported parameters.
+
     Setup:
         Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``.
@@ -42,9 +45,12 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         model: str
             Name of model to use.
         temperature: float
-            Sampling temperature.
+            Sampling temperature between ``0`` and ``2``. Higher values mean more random completions,
+            while lower values (like ``0.2``) mean more focused and deterministic completions.
+            (Default: ``1``.)
         max_tokens: Optional[int]
-            Max number of tokens to generate.
+            Max number of tokens to generate. Refer to your `model's documentation <https://docs.x.ai/docs/models#model-pricing>`__
+            for the maximum number of tokens it can generate.
         logprobs: Optional[bool]
             Whether to return logprobs.
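For example, a minimal sketch of the parameters described above:

.. code-block:: python

    from langchain_xai import ChatXAI

    # Temperature ranges 0-2 (default 1); lower values give more focused,
    # deterministic completions.
    llm = ChatXAI(model="grok-4", temperature=0.2, max_tokens=512)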
@@ -62,7 +68,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             from langchain_xai import ChatXAI

             llm = ChatXAI(
-                model="grok-beta",
+                model="grok-4",
                 temperature=0,
                 max_tokens=None,
                 timeout=None,
@@ -89,7 +95,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             content="J'adore la programmation.",
             response_metadata={
                 'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
-                'model_name': 'grok-beta',
+                'model_name': 'grok-4',
                 'system_fingerprint': None,
                 'finish_reason': 'stop',
                 'logprobs': None
@@ -113,7 +119,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             content=' programm' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
             content='ation' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
             content='.' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
-            content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-beta'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
+            content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-4'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'

     Async:
@@ -133,7 +139,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             content="J'adore la programmation.",
             response_metadata={
                 'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
-                'model_name': 'grok-beta',
+                'model_name': 'grok-4',
                 'system_fingerprint': None,
                 'finish_reason': 'stop',
                 'logprobs': None
@@ -141,12 +147,39 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             id='run-09371a11-7f72-4c53-8e7c-9de5c238b34c-0',
             usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})

-    Tool calling:
+    Reasoning:
+        `Certain xAI models <https://docs.x.ai/docs/models#model-pricing>`__ support reasoning,
+        which allows the model to provide reasoning content along with the response.
+
+        If provided, reasoning content is returned under the ``additional_kwargs`` field of the
+        AIMessage or AIMessageChunk.
+
+        If supported, reasoning effort can be specified in the model constructor's ``extra_body``
+        argument, which will control the amount of reasoning the model does. The value can be one
+        of ``'low'`` or ``'high'``.
+
+        .. code-block:: python
+
+            model = ChatXAI(
+                model="grok-3-mini",
+                extra_body={"reasoning_effort": "high"},
+            )
+
+        .. note::
+            As of 2025-07-10, ``reasoning_content`` is only returned in Grok 3 models, such as
+            `Grok 3 Mini <https://docs.x.ai/docs/models/grok-3-mini>`__.
+
+        .. note::
+            In `Grok 4 <https://docs.x.ai/docs/models/grok-4-0709>`__, as of 2025-07-10,
+            reasoning is not exposed in ``reasoning_content`` (other than initial ``'Thinking...'``
+            text), reasoning cannot be disabled, and ``reasoning_effort`` cannot be specified.
+
+    Tool calling / function calling:
         .. code-block:: python

             from pydantic import BaseModel, Field

-            llm = ChatXAI(model="grok-beta")
+            llm = ChatXAI(model="grok-4")

             class GetWeather(BaseModel):
                 '''Get the current weather in a given location'''
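To complement the reasoning snippet above, a sketch of reading the reasoning back out (per the docstring, it arrives under ``additional_kwargs`` on Grok 3 models):

.. code-block:: python

    from langchain_xai import ChatXAI

    llm = ChatXAI(model="grok-3-mini", extra_body={"reasoning_effort": "high"})
    msg = llm.invoke("What is 101 * 3?")
    # None if the model returned no reasoning content.
    print(msg.additional_kwargs.get("reasoning_content"))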
@@ -168,7 +201,6 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             )
             ai_msg.tool_calls
-
         .. code-block:: python

             [
@@ -186,6 +218,67 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
                 }
             ]

+        .. note::
+            With stream response, the tool / function call will be returned in whole in a
+            single chunk, instead of being streamed across chunks.
+
+        Tool choice can be controlled by setting the ``tool_choice`` parameter in the model
+        constructor's ``extra_body`` argument. For example, to disable tool / function calling:
+
+        .. code-block:: python
+
+            llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "none"})
+
+        To require that the model always calls a tool / function, set ``tool_choice`` to
+        ``'required'``:
+
+        .. code-block:: python
+
+            llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "required"})
+
+        To specify a tool / function to call, set ``tool_choice`` to the name of the tool /
+        function:
+
+        .. code-block:: python
+
+            from pydantic import BaseModel, Field
+
+            llm = ChatXAI(
+                model="grok-4",
+                extra_body={
+                    "tool_choice": {"type": "function", "function": {"name": "GetWeather"}}
+                },
+            )
+
+            class GetWeather(BaseModel):
+                \"\"\"Get the current weather in a given location\"\"\"
+
+                location: str = Field(..., description='The city and state, e.g. San Francisco, CA')
+
+            class GetPopulation(BaseModel):
+                \"\"\"Get the current population in a given location\"\"\"
+
+                location: str = Field(..., description='The city and state, e.g. San Francisco, CA')
+
+            llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
+
+            ai_msg = llm_with_tools.invoke(
+                "Which city is bigger: LA or NY?",
+            )
+
+            ai_msg.tool_calls
+
+        The resulting tool call would be:
+
+        .. code-block:: python
+
+            [{'name': 'GetWeather',
+            'args': {'location': 'Los Angeles, CA'},
+            'id': 'call_81668711',
+            'type': 'tool_call'}]
+
+    Parallel tool calling / parallel function calling:
+        By default, parallel tool / function calling is enabled, so you can process
+        multiple function calls in one request/response cycle. When two or more tool calls
+        are required, all of the tool call requests will be included in the response body.
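A short sketch of what that looks like in practice, reusing the ``GetWeather`` / ``GetPopulation`` schemas from the snippet above:

.. code-block:: python

    ai_msg = llm_with_tools.invoke(
        "What is the weather in LA and the population of NY?"
    )
    # With parallel calling enabled, both requests arrive in one response.
    for tool_call in ai_msg.tool_calls:
        print(tool_call["name"], tool_call["args"])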
     Structured output:
         .. code-block:: python
@@ -222,7 +315,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             from langchain_xai import ChatXAI

             llm = ChatXAI(
-                model="grok-3-latest",
+                model="grok-4",
                 search_parameters={
                     "mode": "auto",
                     # Example optional parameters below:
@@ -234,6 +327,10 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             llm.invoke("Provide me a digest of world news in the last 24 hours.")

+        .. note::
+            `Citations <https://docs.x.ai/docs/guides/live-search#returning-citations>`__
+            are only available in `Grok 3 <https://docs.x.ai/docs/models/grok-3>`__.
+
     Token usage:
         .. code-block:: python
@@ -275,7 +372,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
                 'prompt_tokens': 19,
                 'total_tokens': 23
             },
-            'model_name': 'grok-beta',
+            'model_name': 'grok-4',
             'system_fingerprint': None,
             'finish_reason': 'stop',
             'logprobs': None
@@ -283,7 +380,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     """  # noqa: E501

-    model_name: str = Field(alias="model")
+    model_name: str = Field(default="grok-4", alias="model")
     """Model name to use."""
     xai_api_key: Optional[SecretStr] = Field(
         alias="api_key",

View File

@@ -18,6 +18,10 @@ rate_limiter = InMemoryRateLimiter(
 )

+
+# Not using Grok 4 since it doesn't support reasoning params (effort) or return
+# reasoning content.
 class TestXAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_class(self) -> type[BaseChatModel]:
@@ -25,6 +29,7 @@ class TestXAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_params(self) -> dict:
+        # TODO: bump to test new Grok once they implement other features
         return {
             "model": "grok-3",
             "rate_limiter": rate_limiter,
@@ -35,7 +40,7 @@ class TestXAIStandard(ChatModelIntegrationTests):
 def test_reasoning_content() -> None:
     """Test reasoning content."""
     chat_model = ChatXAI(
-        model="grok-3-mini-beta",
+        model="grok-3-mini",
         reasoning_effort="low",
     )
     response = chat_model.invoke("What is 3^3?")
@@ -52,7 +57,7 @@ def test_reasoning_content() -> None:
 def test_web_search() -> None:
     llm = ChatXAI(
-        model="grok-3-latest",
+        model="grok-3",
         search_parameters={"mode": "auto", "max_search_results": 3},
     )

View File

@@ -15,10 +15,12 @@ from langchain_openai.chat_models.base import (
 from langchain_xai import ChatXAI

+MODEL_NAME = "grok-4"
+

 def test_initialization() -> None:
     """Test chat model initialization."""
-    ChatXAI(model="grok-beta")
+    ChatXAI(model=MODEL_NAME)


 def test_xai_model_param() -> None:
@@ -34,7 +36,7 @@ def test_chat_xai_invalid_streaming_params() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     with pytest.raises(ValueError):
         ChatXAI(
-            model="grok-beta",
+            model=MODEL_NAME,
             max_tokens=10,
             streaming=True,
             temperature=0,
@@ -45,17 +47,17 @@ def test_chat_xai_invalid_streaming_params() -> None:
 def test_chat_xai_extra_kwargs() -> None:
     """Test extra kwargs to chat xai."""
     # Check that foo is saved in extra_kwargs.
-    llm = ChatXAI(model="grok-beta", foo=3, max_tokens=10)  # type: ignore[call-arg]
+    llm = ChatXAI(model=MODEL_NAME, foo=3, max_tokens=10)  # type: ignore[call-arg]
     assert llm.max_tokens == 10
     assert llm.model_kwargs == {"foo": 3}

     # Test that if extra_kwargs are provided, they are added to it.
-    llm = ChatXAI(model="grok-beta", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+    llm = ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": 3, "bar": 2}

     # Test that if provided twice it errors
     with pytest.raises(ValueError):
-        ChatXAI(model="grok-beta", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+        ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]


 def test_function_dict_to_message_function_message() -> None:

View File

@@ -1,7 +1,9 @@
 from langchain_xai import ChatXAI

+MODEL_NAME = "grok-4"
+

 def test_chat_xai_secrets() -> None:
-    o = ChatXAI(model="grok-beta", xai_api_key="foo")  # type: ignore[call-arg]
+    o = ChatXAI(model=MODEL_NAME, xai_api_key="foo")  # type: ignore[call-arg]
     s = str(o)
     assert "foo" not in s