mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
|
||||
import openai
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
@@ -19,9 +19,8 @@ if TYPE_CHECKING:
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
_BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
|
||||
_DictOrPydantic = Union[dict, _BM]
|
||||
_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel] | type
|
||||
_DictOrPydantic: TypeAlias = dict | BaseModel
|
||||
|
||||
|
||||
class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
@@ -31,7 +30,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
for more nuanced details on the API's behavior and supported parameters.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``.
|
||||
Install `langchain-xai` and set environment variable `XAI_API_KEY`.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -43,9 +42,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
model: str
|
||||
Name of model to use.
|
||||
temperature: float
|
||||
Sampling temperature between `0` and ``2``. Higher values mean more random completions,
|
||||
while lower values (like ``0.2``) mean more focused and deterministic completions.
|
||||
(Default: ``1``.)
|
||||
Sampling temperature between `0` and `2`. Higher values mean more random completions,
|
||||
while lower values (like `0.2`) mean more focused and deterministic completions.
|
||||
(Default: `1`.)
|
||||
max_tokens: Optional[int]
|
||||
Max number of tokens to generate. Refer to your `model's documentation <https://docs.x.ai/docs/models#model-pricing>`__
|
||||
for the maximum number of tokens it can generate.
|
||||
@@ -58,7 +57,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
max_retries: int
|
||||
Max number of retries.
|
||||
api_key: Optional[str]
|
||||
xAI API key. If not passed in will be read from env var ``XAI_API_KEY``.
|
||||
xAI API key. If not passed in will be read from env var `XAI_API_KEY`.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
@@ -170,9 +169,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
If provided, reasoning content is returned under the ``additional_kwargs`` field of the
|
||||
AIMessage or AIMessageChunk.
|
||||
|
||||
If supported, reasoning effort can be specified in the model constructor's ``extra_body``
|
||||
If supported, reasoning effort can be specified in the model constructor's `extra_body`
|
||||
argument, which will control the amount of reasoning the model does. The value can be one of
|
||||
``'low'`` or ``'high'``.
|
||||
`'low'` or `'high'`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -236,7 +235,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
single chunk, instead of being streamed across chunks.
|
||||
|
||||
Tool choice can be controlled by setting the `tool_choice` parameter in the model
|
||||
constructor's ``extra_body`` argument. For example, to disable tool / function calling:
|
||||
constructor's `extra_body` argument. For example, to disable tool / function calling:
|
||||
.. code-block:: python
|
||||
|
||||
llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "none"})
|
||||
@@ -399,21 +398,21 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
model_name: str = Field(default="grok-4", alias="model")
|
||||
"""Model name to use."""
|
||||
xai_api_key: Optional[SecretStr] = Field(
|
||||
xai_api_key: SecretStr | None = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("XAI_API_KEY", default=None),
|
||||
)
|
||||
"""xAI API key.
|
||||
|
||||
Automatically read from env variable ``XAI_API_KEY`` if not provided.
|
||||
Automatically read from env variable `XAI_API_KEY` if not provided.
|
||||
"""
|
||||
xai_api_base: str = Field(default="https://api.x.ai/v1/")
|
||||
"""Base URL path for API requests."""
|
||||
search_parameters: Optional[dict[str, Any]] = None
|
||||
search_parameters: dict[str, Any] | None = None
|
||||
"""Parameters for search requests. Example: ``{"mode": "auto"}``."""
|
||||
|
||||
openai_api_key: Optional[SecretStr] = None
|
||||
openai_api_base: Optional[str] = None
|
||||
openai_api_key: SecretStr | None = None
|
||||
openai_api_base: str | None = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
@@ -457,7 +456,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
def _get_ls_params(
|
||||
self,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> LangSmithParams:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
@@ -525,8 +524,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: Union[dict, openai.BaseModel],
|
||||
generation_info: Optional[dict] = None,
|
||||
response: dict | openai.BaseModel,
|
||||
generation_info: dict | None = None,
|
||||
) -> ChatResult:
|
||||
rtn = super()._create_chat_result(response, generation_info)
|
||||
|
||||
@@ -549,8 +548,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: Optional[dict],
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
base_generation_info: dict | None,
|
||||
) -> ChatGenerationChunk | None:
|
||||
generation_chunk = super()._convert_chunk_to_generation_chunk(
|
||||
chunk,
|
||||
default_chunk_class,
|
||||
@@ -576,13 +575,13 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[_DictOrPydanticClass] = None,
|
||||
schema: _DictOrPydanticClass | None = None,
|
||||
*,
|
||||
method: Literal[
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
strict: Optional[bool] = None,
|
||||
strict: bool | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Reference in New Issue
Block a user