diff --git a/docs/docs/integrations/chat/xai.ipynb b/docs/docs/integrations/chat/xai.ipynb index 22cc06d256f..7e9d37eb426 100644 --- a/docs/docs/integrations/chat/xai.ipynb +++ b/docs/docs/integrations/chat/xai.ipynb @@ -345,7 +345,7 @@ "source": [ "## API reference\n", "\n", - "For detailed documentation of all `ChatXAI` features and configurations, head to the API reference: https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html" + "For detailed documentation of all `ChatXAI` features and configurations, head to the [API reference](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html)." ] } ], diff --git a/docs/docs/integrations/providers/xai.ipynb b/docs/docs/integrations/providers/xai.ipynb index 4840db40a53..8cf4de3104b 100644 --- a/docs/docs/integrations/providers/xai.ipynb +++ b/docs/docs/integrations/providers/xai.ipynb @@ -63,7 +63,7 @@ "\n", "chat = ChatXAI(\n", " # xai_api_key=\"YOUR_API_KEY\",\n", - " model=\"grok-beta\",\n", + " model=\"grok-4\",\n", ")\n", "\n", "# stream the response back from the model\n", diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index b4541a648ff..92c7b25c45f 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -395,10 +395,10 @@ class ChatGroq(BaseChatModel): # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: Union[Any, None] = None - """Optional httpx.Client.""" + """Optional ``httpx.Client``.""" http_async_client: Union[Any, None] = None - """Optional httpx.AsyncClient. Only used for async invocations. Must specify - http_client as well if you'd like a custom client for sync invocations.""" + """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify + ``http_client`` as well if you'd like a custom client for sync invocations.""" model_config = ConfigDict( populate_by_name=True, diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index 2c85c892955..9f62b875675 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -76,13 +76,13 @@ class AzureChatOpenAI(BaseChatOpenAI): Max number of retries. organization: Optional[str] OpenAI organization ID. If not passed in will be read from env - var OPENAI_ORG_ID. + var ``OPENAI_ORG_ID``. model: Optional[str] The name of the underlying OpenAI model. Used for tracing and token - counting. Does not affect completion. E.g. "gpt-4", "gpt-35-turbo", etc. + counting. Does not affect completion. E.g. ``'gpt-4'``, ``'gpt-35-turbo'``, etc. model_version: Optional[str] The version of the underlying OpenAI model. Used for tracing and token - counting. Does not affect completion. E.g., "0125", "0125-preview", etc. + counting. Does not affect completion. E.g., ``'0125'``, ``'0125-preview'``, etc. See full list of supported init args and their descriptions in the params section. diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index b2488031931..bd436c36eea 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -542,12 +542,13 @@ class BaseChatOpenAI(BaseChatModel): # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: Union[Any, None] = Field(default=None, exclude=True) - """Optional httpx.Client. Only used for sync invocations. Must specify - http_async_client as well if you'd like a custom client for async invocations. + """Optional ``httpx.Client``. Only used for sync invocations. Must specify + ``http_async_client`` as well if you'd like a custom client for async + invocations. """ http_async_client: Union[Any, None] = Field(default=None, exclude=True) """Optional httpx.AsyncClient. Only used for async invocations. Must specify - http_client as well if you'd like a custom client for sync invocations.""" + ``http_client`` as well if you'd like a custom client for sync invocations.""" stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences") """Default stop sequences.""" extra_body: Optional[Mapping[str, Any]] = None @@ -588,8 +589,8 @@ class BaseChatOpenAI(BaseChatModel): """ service_tier: Optional[str] = None - """Latency tier for request. Options are 'auto', 'default', or 'flex'. Relevant - for users of OpenAI's scale tier service. + """Latency tier for request. Options are ``'auto'``, ``'default'``, or ``'flex'``. + Relevant for users of OpenAI's scale tier service. """ store: Optional[bool] = None @@ -600,8 +601,8 @@ class BaseChatOpenAI(BaseChatModel): """ truncation: Optional[str] = None - """Truncation strategy (Responses API). Can be ``"auto"`` or ``"disabled"`` - (default). If ``"auto"``, model may drop input items from the middle of the + """Truncation strategy (Responses API). Can be ``'auto'`` or ``'disabled'`` + (default). If ``'auto'``, model may drop input items from the middle of the message sequence to fit the context window. .. versionadded:: 0.3.24 @@ -1451,7 +1452,7 @@ class BaseChatOpenAI(BaseChatModel): Sequence[Union[dict[str, Any], type, Callable, BaseTool]] ] = None, ) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. + """Calculate num tokens for ``gpt-3.5-turbo`` and ``gpt-4`` with ``tiktoken`` package. **Requirements**: You must have the ``pillow`` installed if you want to count image tokens if you are specifying the image as a base64 string, and you must @@ -1459,14 +1460,13 @@ class BaseChatOpenAI(BaseChatModel): as a URL. If these aren't installed image inputs will be ignored in token counting. - OpenAI reference: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb + `OpenAI reference `__ Args: messages: The message inputs to tokenize. tools: If provided, sequence of dict, BaseModel, function, or BaseTools to be converted to tool schemas. - """ + """ # noqa: E501 # TODO: Count bound tools as part of input. if tools is not None: warnings.warn( @@ -2036,13 +2036,13 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] max_retries: Optional[int] Max number of retries. api_key: Optional[str] - OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY. + OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``. base_url: Optional[str] Base URL for API requests. Only specify if using a proxy or service emulator. organization: Optional[str] OpenAI organization ID. If not passed in will be read from env - var OPENAI_ORG_ID. + var ``OPENAI_ORG_ID``. See full list of supported init args and their descriptions in the params section. diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index c23a466cff9..212c6385b04 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -93,14 +93,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Name of OpenAI model to use. dimensions: Optional[int] = None The number of dimensions the resulting output embeddings should have. - Only supported in `text-embedding-3` and later models. + Only supported in ``'text-embedding-3'`` and later models. Key init args — client params: api_key: Optional[SecretStr] = None OpenAI API key. organization: Optional[str] = None OpenAI organization ID. If not passed in will be read - from env var OPENAI_ORG_ID. + from env var ``OPENAI_ORG_ID``. max_retries: int = 2 Maximum number of retries to make when generating. request_timeout: Optional[Union[float, Tuple[float, float], Any]] = None @@ -194,14 +194,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_key: Optional[SecretStr] = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) - """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" + """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided.""" openai_organization: Optional[str] = Field( alias="organization", default_factory=from_env( ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None ), ) - """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" + """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided.""" allowed_special: Union[Literal["all"], set[str], None] = None disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None chunk_size: int = 1000 @@ -211,12 +211,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings): request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field( default=None, alias="timeout" ) - """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or + """Timeout for requests to OpenAI completion API. Can be float, ``httpx.Timeout`` or None.""" headers: Any = None tiktoken_enabled: bool = True """Set this to False for non-OpenAI implementations of the embeddings API, e.g. - the `--extensions openai` extension for `text-generation-webui`""" + the ``--extensions openai`` extension for ``text-generation-webui``""" tiktoken_model_name: Optional[str] = None """The model name to pass to tiktoken when using this class. Tiktoken is used to count the number of tokens in documents to constrain @@ -243,12 +243,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings): retry_max_seconds: int = 20 """Max number of seconds to wait between retries""" http_client: Union[Any, None] = None - """Optional httpx.Client. Only used for sync invocations. Must specify - http_async_client as well if you'd like a custom client for async invocations. + """Optional ``httpx.Client``. Only used for sync invocations. Must specify + ``http_async_client`` as well if you'd like a custom client for async + invocations. """ http_async_client: Union[Any, None] = None - """Optional httpx.AsyncClient. Only used for async invocations. Must specify - http_client as well if you'd like a custom client for sync invocations.""" + """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify + ``http_client`` as well if you'd like a custom client for sync invocations.""" check_embedding_ctx_length: bool = True """Whether to check the token length of inputs and automatically split inputs longer than embedding_ctx_length.""" @@ -289,8 +290,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """Validate that api key and python package exists in environment.""" if self.openai_api_type in ("azure", "azure_ad", "azuread"): raise ValueError( - "If you are using Azure, " - "please use the `AzureOpenAIEmbeddings` class." + "If you are using Azure, please use the `AzureOpenAIEmbeddings` class." ) client_params: dict = { "api_key": ( diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index e22b7b01712..b4ee88ee327 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -76,7 +76,7 @@ class BaseOpenAI(BaseLLM): openai_api_key: Optional[SecretStr] = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) - """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" + """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided.""" openai_api_base: Optional[str] = Field( alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None) ) @@ -88,7 +88,7 @@ class BaseOpenAI(BaseLLM): ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None ), ) - """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" + """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided.""" # to support explicit proxy for OpenAI openai_proxy: Optional[str] = Field( default_factory=from_env("OPENAI_PROXY", default=None) @@ -130,12 +130,13 @@ class BaseOpenAI(BaseLLM): # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: Union[Any, None] = None - """Optional httpx.Client. Only used for sync invocations. Must specify - http_async_client as well if you'd like a custom client for async invocations. + """Optional ``httpx.Client``. Only used for sync invocations. Must specify + ``http_async_client`` as well if you'd like a custom client for async + invocations. """ http_async_client: Union[Any, None] = None - """Optional httpx.AsyncClient. Only used for async invocations. Must specify - http_client as well if you'd like a custom client for sync invocations.""" + """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify + ``http_client`` as well if you'd like a custom client for sync invocations.""" extra_body: Optional[Mapping[str, Any]] = None """Optional additional JSON properties to include in the request parameters when making requests to OpenAI compatible APIs, such as vLLM.""" @@ -606,13 +607,13 @@ class OpenAI(BaseOpenAI): max_retries: int Max number of retries. api_key: Optional[str] - OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY. + OpenAI API key. If not passed in will be read from env var ``OPENAI_API_KEY``. base_url: Optional[str] Base URL for API requests. Only specify if using a proxy or service emulator. organization: Optional[str] OpenAI organization ID. If not passed in will be read from env - var OPENAI_ORG_ID. + var ``OPENAI_ORG_ID``. See full list of supported init args and their descriptions in the params section. diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 03235f1906b..4c47dba9151 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -29,6 +29,9 @@ _DictOrPydantic = Union[dict, _BM] class ChatXAI(BaseChatOpenAI): # type: ignore[override] r"""ChatXAI chat model. + Refer to `xAI's documentation `__ + for more nuanced details on the API's behavior and supported parameters. + Setup: Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``. @@ -42,9 +45,12 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] model: str Name of model to use. temperature: float - Sampling temperature. + Sampling temperature between ``0`` and ``2``. Higher values mean more random completions, + while lower values (like ``0.2``) mean more focused and deterministic completions. + (Default: ``1``.) max_tokens: Optional[int] - Max number of tokens to generate. + Max number of tokens to generate. Refer to your `model's documentation `__ + for the maximum number of tokens it can generate. logprobs: Optional[bool] Whether to return logprobs. @@ -62,7 +68,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] from langchain_xai import ChatXAI llm = ChatXAI( - model="grok-beta", + model="grok-4", temperature=0, max_tokens=None, timeout=None, @@ -89,7 +95,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] content="J'adore la programmation.", response_metadata={ 'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41}, - 'model_name': 'grok-beta', + 'model_name': 'grok-4', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None @@ -113,7 +119,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] content=' programm' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9' content='ation' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9' content='.' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9' - content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-beta'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9' + content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-4'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9' Async: @@ -133,7 +139,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] content="J'adore la programmation.", response_metadata={ 'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41}, - 'model_name': 'grok-beta', + 'model_name': 'grok-4', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None @@ -141,12 +147,39 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] id='run-09371a11-7f72-4c53-8e7c-9de5c238b34c-0', usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41}) - Tool calling: + Reasoning: + `Certain xAI models `__ support reasoning, + which allows the model to provide reasoning content along with the response. + + If provided, reasoning content is returned under the ``additional_kwargs`` field of the + AIMessage or AIMessageChunk. + + If supported, reasoning effort can be specified in the model constructor's ``extra_body`` + argument, which will control the amount of reasoning the model does. The value can be one of + ``'low'`` or ``'high'``. + + .. code-block:: python + + model = ChatXAI( + model="grok-3-mini", + extra_body={"reasoning_effort": "high"}, + ) + + .. note:: + As of 2025-07-10, ``reasoning_content`` is only returned in Grok 3 models, such as + `Grok 3 Mini `__. + + .. note:: + Note that in `Grok 4 `__, as of 2025-07-10, + reasoning is not exposed in ``reasoning_content`` (other than initial ``'Thinking...'`` text), + reasoning cannot be disabled, and the ``reasoning_effort`` cannot be specified. + + Tool calling / function calling: .. code-block:: python from pydantic import BaseModel, Field - llm = ChatXAI(model="grok-beta") + llm = ChatXAI(model="grok-4") class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -168,7 +201,6 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] ) ai_msg.tool_calls - .. code-block:: python [ @@ -186,6 +218,67 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] } ] + .. note:: + With stream response, the tool / function call will be returned in whole in a + single chunk, instead of being streamed across chunks. + + Tool choice can be controlled by setting the ``tool_choice`` parameter in the model + constructor's ``extra_body`` argument. For example, to disable tool / function calling: + .. code-block:: python + + llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "none"}) + + To require that the model always calls a tool / function, set ``tool_choice`` to ``'required'``: + + .. code-block:: python + + llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "required"}) + + To specify a tool / function to call, set ``tool_choice`` to the name of the tool / function: + + .. code-block:: python + + from pydantic import BaseModel, Field + + llm = ChatXAI( + model="grok-4", + extra_body={ + "tool_choice": {"type": "function", "function": {"name": "GetWeather"}} + }, + ) + + class GetWeather(BaseModel): + \"\"\"Get the current weather in a given location\"\"\" + + location: str = Field(..., description='The city and state, e.g. San Francisco, CA') + + + class GetPopulation(BaseModel): + \"\"\"Get the current population in a given location\"\"\" + + location: str = Field(..., description='The city and state, e.g. San Francisco, CA') + + + llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) + ai_msg = llm_with_tools.invoke( + "Which city is bigger: LA or NY?", + ) + ai_msg.tool_calls + + The resulting tool call would be: + + .. code-block:: python + + [{'name': 'GetWeather', + 'args': {'location': 'Los Angeles, CA'}, + 'id': 'call_81668711', + 'type': 'tool_call'}] + + Parallel tool calling / parallel function calling: + By default, parallel tool / function calling is enabled, so you can process + multiple function calls in one request/response cycle. When two or more tool calls + are required, all of the tool call requests will be included in the response body. + Structured output: .. code-block:: python @@ -222,7 +315,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] from langchain_xai import ChatXAI llm = ChatXAI( - model="grok-3-latest", + model="grok-4", search_parameters={ "mode": "auto", # Example optional parameters below: @@ -234,6 +327,10 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] llm.invoke("Provide me a digest of world news in the last 24 hours.") + .. note:: + `Citations `__ + are only available in `Grok 3 `__. + Token usage: .. code-block:: python @@ -275,7 +372,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] 'prompt_tokens': 19, 'total_tokens': 23 }, - 'model_name': 'grok-beta', + 'model_name': 'grok-4', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None @@ -283,7 +380,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] """ # noqa: E501 - model_name: str = Field(alias="model") + model_name: str = Field(default="grok-4", alias="model") """Model name to use.""" xai_api_key: Optional[SecretStr] = Field( alias="api_key", diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py index 67e5bd1707a..7004936a578 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py @@ -18,6 +18,10 @@ rate_limiter = InMemoryRateLimiter( ) +# Not using Grok 4 since it doesn't support reasoning params (effort) or returns +# reasoning content. + + class TestXAIStandard(ChatModelIntegrationTests): @property def chat_model_class(self) -> type[BaseChatModel]: @@ -25,6 +29,7 @@ class TestXAIStandard(ChatModelIntegrationTests): @property def chat_model_params(self) -> dict: + # TODO: bump to test new Grok once they implement other features return { "model": "grok-3", "rate_limiter": rate_limiter, @@ -35,7 +40,7 @@ class TestXAIStandard(ChatModelIntegrationTests): def test_reasoning_content() -> None: """Test reasoning content.""" chat_model = ChatXAI( - model="grok-3-mini-beta", + model="grok-3-mini", reasoning_effort="low", ) response = chat_model.invoke("What is 3^3?") @@ -52,7 +57,7 @@ def test_reasoning_content() -> None: def test_web_search() -> None: llm = ChatXAI( - model="grok-3-latest", + model="grok-3", search_parameters={"mode": "auto", "max_search_results": 3}, ) diff --git a/libs/partners/xai/tests/unit_tests/test_chat_models.py b/libs/partners/xai/tests/unit_tests/test_chat_models.py index e595d449893..78c7a495153 100644 --- a/libs/partners/xai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/xai/tests/unit_tests/test_chat_models.py @@ -15,10 +15,12 @@ from langchain_openai.chat_models.base import ( from langchain_xai import ChatXAI +MODEL_NAME = "grok-4" + def test_initialization() -> None: """Test chat model initialization.""" - ChatXAI(model="grok-beta") + ChatXAI(model=MODEL_NAME) def test_xai_model_param() -> None: @@ -34,7 +36,7 @@ def test_chat_xai_invalid_streaming_params() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" with pytest.raises(ValueError): ChatXAI( - model="grok-beta", + model=MODEL_NAME, max_tokens=10, streaming=True, temperature=0, @@ -45,17 +47,17 @@ def test_chat_xai_invalid_streaming_params() -> None: def test_chat_xai_extra_kwargs() -> None: """Test extra kwargs to chat xai.""" # Check that foo is saved in extra_kwargs. - llm = ChatXAI(model="grok-beta", foo=3, max_tokens=10) # type: ignore[call-arg] + llm = ChatXAI(model=MODEL_NAME, foo=3, max_tokens=10) # type: ignore[call-arg] assert llm.max_tokens == 10 assert llm.model_kwargs == {"foo": 3} # Test that if extra_kwargs are provided, they are added to it. - llm = ChatXAI(model="grok-beta", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg] + llm = ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg] assert llm.model_kwargs == {"foo": 3, "bar": 2} # Test that if provided twice it errors with pytest.raises(ValueError): - ChatXAI(model="grok-beta", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg] + ChatXAI(model=MODEL_NAME, foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg] def test_function_dict_to_message_function_message() -> None: diff --git a/libs/partners/xai/tests/unit_tests/test_secrets.py b/libs/partners/xai/tests/unit_tests/test_secrets.py index bcb83e774ff..7156c693f9e 100644 --- a/libs/partners/xai/tests/unit_tests/test_secrets.py +++ b/libs/partners/xai/tests/unit_tests/test_secrets.py @@ -1,7 +1,9 @@ from langchain_xai import ChatXAI +MODEL_NAME = "grok-4" + def test_chat_xai_secrets() -> None: - o = ChatXAI(model="grok-beta", xai_api_key="foo") # type: ignore[call-arg] + o = ChatXAI(model=MODEL_NAME, xai_api_key="foo") # type: ignore[call-arg] s = str(o) assert "foo" not in s