docs(xai): update for Grok 4 (#31953)
parent 060fc0e3c9
commit 6594eb8cc1
@@ -345,7 +345,7 @@
    "source": [
     "## API reference\n",
     "\n",
-    "For detailed documentation of all `ChatXAI` features and configurations, head to the API reference: https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
+    "For detailed documentation of all `ChatXAI` features and configurations, head to the [API reference](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html)."
    ]
   }
  ],
@@ -63,7 +63,7 @@
    "\n",
    "chat = ChatXAI(\n",
    "    # xai_api_key=\"YOUR_API_KEY\",\n",
-   "    model=\"grok-beta\",\n",
+   "    model=\"grok-4\",\n",
    ")\n",
    "\n",
    "# stream the response back from the model\n",
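For context, a minimal standalone sketch of what the updated notebook cell does once run (the prompt text is illustrative; `stream()` is the standard Runnable streaming entry point):

    from langchain_xai import ChatXAI

    chat = ChatXAI(
        # xai_api_key="YOUR_API_KEY",
        model="grok-4",
    )

    # Stream the response back from the model, printing chunks as they arrive.
    for chunk in chat.stream("Tell me fun things to do in NYC"):
        print(chunk.content, end="", flush=True)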
@@ -395,10 +395,10 @@ class ChatGroq(BaseChatModel):
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
     http_client: Union[Any, None] = None
-    """Optional httpx.Client."""
+    """Optional ``httpx.Client``."""
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
 
     model_config = ConfigDict(
         populate_by_name=True,
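The ``http_client`` / ``http_async_client`` pair documented above is passed through to the underlying SDK. A hedged sketch of wiring in custom clients (the timeout and model name are illustrative placeholders):

    import httpx
    from langchain_groq import ChatGroq

    # Supply both clients if you want custom transport for both sync and
    # async invocations, as the docstrings above note.
    llm = ChatGroq(
        model="llama-3.1-8b-instant",  # illustrative model name
        http_client=httpx.Client(timeout=30.0),
        http_async_client=httpx.AsyncClient(timeout=30.0),
    )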
@@ -76,13 +76,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
             Max number of retries.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
-            var OPENAI_ORG_ID.
+            var ``OPENAI_ORG_ID``.
         model: Optional[str]
             The name of the underlying OpenAI model. Used for tracing and token
-            counting. Does not affect completion. E.g. "gpt-4", "gpt-35-turbo", etc.
+            counting. Does not affect completion. E.g. ``'gpt-4'``, ``'gpt-35-turbo'``, etc.
         model_version: Optional[str]
             The version of the underlying OpenAI model. Used for tracing and token
-            counting. Does not affect completion. E.g., "0125", "0125-preview", etc.
+            counting. Does not affect completion. E.g., ``'0125'``, ``'0125-preview'``, etc.
 
     See full list of supported init args and their descriptions in the params section.
@@ -542,12 +542,13 @@ class BaseChatOpenAI(BaseChatModel):
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
     http_client: Union[Any, None] = Field(default=None, exclude=True)
-    """Optional httpx.Client. Only used for sync invocations. Must specify
-    http_async_client as well if you'd like a custom client for async invocations.
+    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
+    ``http_async_client`` as well if you'd like a custom client for async
+    invocations.
     """
     http_async_client: Union[Any, None] = Field(default=None, exclude=True)
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
     extra_body: Optional[Mapping[str, Any]] = None
@@ -588,8 +589,8 @@ class BaseChatOpenAI(BaseChatModel):
     """
 
     service_tier: Optional[str] = None
-    """Latency tier for request. Options are 'auto', 'default', or 'flex'. Relevant
-    for users of OpenAI's scale tier service.
+    """Latency tier for request. Options are ``'auto'``, ``'default'``, or ``'flex'``.
+    Relevant for users of OpenAI's scale tier service.
     """
 
     store: Optional[bool] = None
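A sketch of how the ``service_tier`` option above reads in use; the model name is illustrative, and flex processing is assumed to be enabled on the account:

    from langchain_openai import ChatOpenAI

    # 'flex' trades latency for cost on supported models; 'auto' lets OpenAI decide.
    llm = ChatOpenAI(model="o4-mini", service_tier="flex")
    response = llm.invoke("Summarize the plot of Hamlet in one sentence.")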
@@ -600,8 +601,8 @@ class BaseChatOpenAI(BaseChatModel):
     """
 
     truncation: Optional[str] = None
-    """Truncation strategy (Responses API). Can be ``"auto"`` or ``"disabled"``
-    (default). If ``"auto"``, model may drop input items from the middle of the
+    """Truncation strategy (Responses API). Can be ``'auto'`` or ``'disabled'``
+    (default). If ``'auto'``, model may drop input items from the middle of the
     message sequence to fit the context window.
 
     .. versionadded:: 0.3.24
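Since ``truncation`` only applies when requests go through the Responses API, a minimal sketch (``use_responses_api`` is assumed here as the switch that routes through the Responses API; the model name is illustrative):

    from langchain_openai import ChatOpenAI

    # With 'auto', the model may drop middle-of-conversation items rather
    # than fail when the context window is exceeded.
    llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True, truncation="auto")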
@@ -1451,7 +1452,7 @@ class BaseChatOpenAI(BaseChatModel):
             Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
         ] = None,
     ) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+        """Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package.
 
         **Requirements**: You must have the ``pillow`` installed if you want to count
         image tokens if you are specifying the image as a base64 string, and you must
@@ -1459,14 +1460,13 @@ class BaseChatOpenAI(BaseChatModel):
         as a URL. If these aren't installed image inputs will be ignored in token
         counting.
 
-        OpenAI reference: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
+        `OpenAI reference <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb>`__
 
         Args:
             messages: The message inputs to tokenize.
             tools: If provided, sequence of dict, BaseModel, function, or BaseTools
                 to be converted to tool schemas.
-        """
+        """  # noqa: E501
         # TODO: Count bound tools as part of input.
         if tools is not None:
             warnings.warn(
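For context, the method whose docstring is being re-formatted here is ``get_num_tokens_from_messages``; a quick sketch of calling it (the model and message are illustrative):

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")
    # Counting happens locally via tiktoken; no API call is made.
    n = llm.get_num_tokens_from_messages([HumanMessage("Hello, world!")])
    print(n)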
@@ -2036,13 +2036,13 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
-            OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
+            OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``.
         base_url: Optional[str]
             Base URL for API requests. Only specify if using a proxy or service
             emulator.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
-            var OPENAI_ORG_ID.
+            var ``OPENAI_ORG_ID``.
 
     See full list of supported init args and their descriptions in the params section.
@@ -93,14 +93,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             Name of OpenAI model to use.
         dimensions: Optional[int] = None
             The number of dimensions the resulting output embeddings should have.
-            Only supported in `text-embedding-3` and later models.
+            Only supported in ``'text-embedding-3'`` and later models.
 
     Key init args — client params:
         api_key: Optional[SecretStr] = None
             OpenAI API key.
         organization: Optional[str] = None
             OpenAI organization ID. If not passed in will be read
-            from env var OPENAI_ORG_ID.
+            from env var ``OPENAI_ORG_ID``.
         max_retries: int = 2
             Maximum number of retries to make when generating.
         request_timeout: Optional[Union[float, Tuple[float, float], Any]] = None
@@ -194,14 +194,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     openai_api_key: Optional[SecretStr] = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
-    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+    """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided."""
     openai_organization: Optional[str] = Field(
         alias="organization",
         default_factory=from_env(
             ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
         ),
     )
-    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+    """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
     allowed_special: Union[Literal["all"], set[str], None] = None
     disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None
     chunk_size: int = 1000
@@ -211,12 +211,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field(
         default=None, alias="timeout"
     )
-    """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
+    """Timeout for requests to OpenAI completion API. Can be float, ``httpx.Timeout`` or
     None."""
     headers: Any = None
     tiktoken_enabled: bool = True
     """Set this to False for non-OpenAI implementations of the embeddings API, e.g.
-    the `--extensions openai` extension for `text-generation-webui`"""
+    the ``--extensions openai`` extension for ``text-generation-webui``"""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
     Tiktoken is used to count the number of tokens in documents to constrain
@@ -243,12 +243,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     retry_max_seconds: int = 20
     """Max number of seconds to wait between retries"""
     http_client: Union[Any, None] = None
-    """Optional httpx.Client. Only used for sync invocations. Must specify
-    http_async_client as well if you'd like a custom client for async invocations.
+    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
+    ``http_async_client`` as well if you'd like a custom client for async
+    invocations.
     """
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     check_embedding_ctx_length: bool = True
     """Whether to check the token length of inputs and automatically split inputs
     longer than embedding_ctx_length."""
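The same sync/async pairing applies on the embeddings side; a hedged sketch (timeout value and model name are illustrative):

    import httpx
    from langchain_openai import OpenAIEmbeddings

    # As with the chat models, supply both clients if you need custom
    # transport for both sync and async embedding calls.
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-small",
        http_client=httpx.Client(timeout=30.0),
        http_async_client=httpx.AsyncClient(timeout=30.0),
    )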
@@ -289,8 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """Validate that api key and python package exists in environment."""
         if self.openai_api_type in ("azure", "azure_ad", "azuread"):
             raise ValueError(
-                "If you are using Azure, "
-                "please use the `AzureOpenAIEmbeddings` class."
+                "If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
             )
         client_params: dict = {
             "api_key": (
@@ -76,7 +76,7 @@ class BaseOpenAI(BaseLLM):
     openai_api_key: Optional[SecretStr] = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
-    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
+    """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided."""
     openai_api_base: Optional[str] = Field(
         alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
     )
@@ -88,7 +88,7 @@ class BaseOpenAI(BaseLLM):
             ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
         ),
     )
-    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
+    """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
     # to support explicit proxy for OpenAI
     openai_proxy: Optional[str] = Field(
         default_factory=from_env("OPENAI_PROXY", default=None)
@@ -130,12 +130,13 @@ class BaseOpenAI(BaseLLM):
     # Configure a custom httpx client. See the
     # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
     http_client: Union[Any, None] = None
-    """Optional httpx.Client. Only used for sync invocations. Must specify
-    http_async_client as well if you'd like a custom client for async invocations.
+    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
+    ``http_async_client`` as well if you'd like a custom client for async
+    invocations.
     """
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
-    http_client as well if you'd like a custom client for sync invocations."""
+    """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify
+    ``http_client`` as well if you'd like a custom client for sync invocations."""
     extra_body: Optional[Mapping[str, Any]] = None
     """Optional additional JSON properties to include in the request parameters when
     making requests to OpenAI compatible APIs, such as vLLM."""
@@ -606,13 +607,13 @@ class OpenAI(BaseOpenAI):
         max_retries: int
             Max number of retries.
         api_key: Optional[str]
-            OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
+            OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``.
         base_url: Optional[str]
             Base URL for API requests. Only specify if using a proxy or service
             emulator.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
-            var OPENAI_ORG_ID.
+            var ``OPENAI_ORG_ID``.
 
     See full list of supported init args and their descriptions in the params section.
@@ -29,6 +29,9 @@ _DictOrPydantic = Union[dict, _BM]
 class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     r"""ChatXAI chat model.
 
+    Refer to `xAI's documentation <https://docs.x.ai/docs/api-reference#chat-completions>`__
+    for more nuanced details on the API's behavior and supported parameters.
+
     Setup:
         Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``.
 
@@ -42,9 +45,12 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         model: str
             Name of model to use.
         temperature: float
-            Sampling temperature.
+            Sampling temperature between ``0`` and ``2``. Higher values mean more random completions,
+            while lower values (like ``0.2``) mean more focused and deterministic completions.
+            (Default: ``1``.)
         max_tokens: Optional[int]
-            Max number of tokens to generate.
+            Max number of tokens to generate. Refer to your `model's documentation <https://docs.x.ai/docs/models#model-pricing>`__
+            for the maximum number of tokens it can generate.
         logprobs: Optional[bool]
             Whether to return logprobs.
 
@@ -62,7 +68,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             from langchain_xai import ChatXAI
 
             llm = ChatXAI(
-                model="grok-beta",
+                model="grok-4",
                 temperature=0,
                 max_tokens=None,
                 timeout=None,
@@ -89,7 +95,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             content="J'adore la programmation.",
             response_metadata={
                 'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
-                'model_name': 'grok-beta',
+                'model_name': 'grok-4',
                 'system_fingerprint': None,
                 'finish_reason': 'stop',
                 'logprobs': None
@@ -113,7 +119,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
         content=' programm' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
         content='ation' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
         content='.' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
-        content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-beta'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
+        content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-4'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
 
     Async:
@@ -133,7 +139,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             content="J'adore la programmation.",
             response_metadata={
                 'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
-                'model_name': 'grok-beta',
+                'model_name': 'grok-4',
                 'system_fingerprint': None,
                 'finish_reason': 'stop',
                 'logprobs': None
@@ -141,12 +147,39 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             id='run-09371a11-7f72-4c53-8e7c-9de5c238b34c-0',
             usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})
 
-    Tool calling:
+    Reasoning:
+        `Certain xAI models <https://docs.x.ai/docs/models#model-pricing>`__ support reasoning,
+        which allows the model to provide reasoning content along with the response.
+
+        If provided, reasoning content is returned under the ``additional_kwargs`` field of the
+        AIMessage or AIMessageChunk.
+
+        If supported, reasoning effort can be specified in the model constructor's ``extra_body``
+        argument, which will control the amount of reasoning the model does. The value can be one of
+        ``'low'`` or ``'high'``.
+
+        .. code-block:: python
+
+            model = ChatXAI(
+                model="grok-3-mini",
+                extra_body={"reasoning_effort": "high"},
+            )
+
+        .. note::
+            As of 2025-07-10, ``reasoning_content`` is only returned in Grok 3 models, such as
+            `Grok 3 Mini <https://docs.x.ai/docs/models/grok-3-mini>`__.
+
+        .. note::
+            Note that in `Grok 4 <https://docs.x.ai/docs/models/grok-4-0709>`__, as of 2025-07-10,
+            reasoning is not exposed in ``reasoning_content`` (other than initial ``'Thinking...'`` text),
+            reasoning cannot be disabled, and the ``reasoning_effort`` cannot be specified.
+
+    Tool calling / function calling:
         .. code-block:: python
 
             from pydantic import BaseModel, Field
 
-            llm = ChatXAI(model="grok-beta")
+            llm = ChatXAI(model="grok-4")
 
             class GetWeather(BaseModel):
                 '''Get the current weather in a given location'''
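To make the new Reasoning section concrete, a sketch of reading reasoning content back out; per the notes in the hunk above this assumes a Grok 3 model, and the prompt is illustrative:

    from langchain_xai import ChatXAI

    model = ChatXAI(
        model="grok-3-mini",
        extra_body={"reasoning_effort": "high"},
    )
    response = model.invoke("What is 3^3?")
    # Per the docstring above, reasoning (when returned) is surfaced under
    # additional_kwargs; 'reasoning_content' is the key the notes refer to.
    print(response.additional_kwargs.get("reasoning_content"))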
@@ -168,7 +201,6 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             )
             ai_msg.tool_calls
 
-
         .. code-block:: python
 
             [
@@ -186,6 +218,67 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             }
         ]
 
+        .. note::
+            With stream response, the tool / function call will be returned in whole in a
+            single chunk, instead of being streamed across chunks.
+
+        Tool choice can be controlled by setting the ``tool_choice`` parameter in the model
+        constructor's ``extra_body`` argument. For example, to disable tool / function calling:
+
+        .. code-block:: python
+
+            llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "none"})
+
+        To require that the model always calls a tool / function, set ``tool_choice`` to ``'required'``:
+
+        .. code-block:: python
+
+            llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "required"})
+
+        To specify a tool / function to call, set ``tool_choice`` to the name of the tool / function:
+
+        .. code-block:: python
+
+            from pydantic import BaseModel, Field
+
+            llm = ChatXAI(
+                model="grok-4",
+                extra_body={
+                    "tool_choice": {"type": "function", "function": {"name": "GetWeather"}}
+                },
+            )
+
+            class GetWeather(BaseModel):
+                \"\"\"Get the current weather in a given location\"\"\"
+
+                location: str = Field(..., description='The city and state, e.g. San Francisco, CA')
+
+
+            class GetPopulation(BaseModel):
+                \"\"\"Get the current population in a given location\"\"\"
+
+                location: str = Field(..., description='The city and state, e.g. San Francisco, CA')
+
+
+            llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
+            ai_msg = llm_with_tools.invoke(
+                "Which city is bigger: LA or NY?",
+            )
+            ai_msg.tool_calls
+
+        The resulting tool call would be:
+
+        .. code-block:: python
+
+            [{'name': 'GetWeather',
+            'args': {'location': 'Los Angeles, CA'},
+            'id': 'call_81668711',
+            'type': 'tool_call'}]
+
+    Parallel tool calling / parallel function calling:
+        By default, parallel tool / function calling is enabled, so you can process
+        multiple function calls in one request/response cycle. When two or more tool calls
+        are required, all of the tool call requests will be included in the response body.
+
     Structured output:
         .. code-block:: python
@@ -222,7 +315,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             from langchain_xai import ChatXAI
 
             llm = ChatXAI(
-                model="grok-3-latest",
+                model="grok-4",
                 search_parameters={
                     "mode": "auto",
                     # Example optional parameters below:
@@ -234,6 +327,10 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
 
         llm.invoke("Provide me a digest of world news in the last 24 hours.")
 
+        .. note::
+            `Citations <https://docs.x.ai/docs/guides/live-search#returning-citations>`__
+            are only available in `Grok 3 <https://docs.x.ai/docs/models/grok-3>`__.
+
     Token usage:
         .. code-block:: python
@@ -275,7 +372,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
                 'prompt_tokens': 19,
                 'total_tokens': 23
             },
-            'model_name': 'grok-beta',
+            'model_name': 'grok-4',
             'system_fingerprint': None,
             'finish_reason': 'stop',
             'logprobs': None
@@ -283,7 +380,7 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
 
     """  # noqa: E501
 
-    model_name: str = Field(alias="model")
+    model_name: str = Field(default="grok-4", alias="model")
     """Model name to use."""
     xai_api_key: Optional[SecretStr] = Field(
         alias="api_key",
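With the new default above, constructing ``ChatXAI`` without a ``model`` argument resolves to ``grok-4``; a quick check (assumes ``XAI_API_KEY`` is set in the environment):

    from langchain_xai import ChatXAI

    llm = ChatXAI()  # model defaults to "grok-4" per the Field above
    assert llm.model_name == "grok-4"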
@@ -18,6 +18,10 @@ rate_limiter = InMemoryRateLimiter(
 )
 
 
+# Not using Grok 4 since it doesn't support reasoning params (effort) or return
+# reasoning content.
+
+
 class TestXAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_class(self) -> type[BaseChatModel]:
@@ -25,6 +29,7 @@ class TestXAIStandard(ChatModelIntegrationTests):
 
     @property
     def chat_model_params(self) -> dict:
+        # TODO: bump to test new Grok once they implement other features
         return {
             "model": "grok-3",
             "rate_limiter": rate_limiter,
@@ -35,7 +40,7 @@
 def test_reasoning_content() -> None:
     """Test reasoning content."""
     chat_model = ChatXAI(
-        model="grok-3-mini-beta",
+        model="grok-3-mini",
         reasoning_effort="low",
     )
     response = chat_model.invoke("What is 3^3?")
@@ -52,7 +57,7 @@ def test_reasoning_content() -> None:
 
 def test_web_search() -> None:
     llm = ChatXAI(
-        model="grok-3-latest",
+        model="grok-3",
         search_parameters={"mode": "auto", "max_search_results": 3},
     )
@@ -15,10 +15,12 @@ from langchain_openai.chat_models.base import (
 
 from langchain_xai import ChatXAI
 
+MODEL_NAME = "grok-4"
+
 
 def test_initialization() -> None:
     """Test chat model initialization."""
-    ChatXAI(model="grok-beta")
+    ChatXAI(model=MODEL_NAME)
 
 
 def test_xai_model_param() -> None:
@@ -34,7 +36,7 @@ def test_chat_xai_invalid_streaming_params() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     with pytest.raises(ValueError):
         ChatXAI(
-            model="grok-beta",
+            model=MODEL_NAME,
             max_tokens=10,
             streaming=True,
             temperature=0,
@@ -45,17 +47,17 @@ def test_chat_xai_invalid_streaming_params() -> None:
 def test_chat_xai_extra_kwargs() -> None:
     """Test extra kwargs to chat xai."""
     # Check that foo is saved in extra_kwargs.
-    llm = ChatXAI(model="grok-beta", foo=3, max_tokens=10)  # type: ignore[call-arg]
+    llm = ChatXAI(model=MODEL_NAME, foo=3, max_tokens=10)  # type: ignore[call-arg]
     assert llm.max_tokens == 10
     assert llm.model_kwargs == {"foo": 3}
 
     # Test that if extra_kwargs are provided, they are added to it.
-    llm = ChatXAI(model="grok-beta", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
+    llm = ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": 3, "bar": 2}
 
     # Test that if provided twice it errors
     with pytest.raises(ValueError):
-        ChatXAI(model="grok-beta", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+        ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
 
 
 def test_function_dict_to_message_function_message() -> None:
@@ -1,7 +1,9 @@
 from langchain_xai import ChatXAI
 
+MODEL_NAME = "grok-4"
+
 
 def test_chat_xai_secrets() -> None:
-    o = ChatXAI(model="grok-beta", xai_api_key="foo")  # type: ignore[call-arg]
+    o = ChatXAI(model=MODEL_NAME, xai_api_key="foo")  # type: ignore[call-arg]
     s = str(o)
     assert "foo" not in s