mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(xai): ruff fixes and rules (#32501)
This commit is contained in:
@@ -1,26 +1,24 @@
|
||||
"""Wrapper around xAI's Chat Completions API."""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models.chat_models import (
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils import secret_from_env
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import (
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
_BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
|
||||
_DictOrPydantic = Union[dict, _BM]
|
||||
@@ -450,9 +448,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.n is not None and self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
msg = "n must be at least 1."
|
||||
raise ValueError(msg)
|
||||
if self.n is not None and self.n > 1 and self.streaming:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
msg = "n must be 1 when streaming."
|
||||
raise ValueError(msg)
|
||||
|
||||
client_params: dict = {
|
||||
"api_key": (
|
||||
@@ -467,10 +467,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
client_params["max_retries"] = self.max_retries
|
||||
|
||||
if client_params["api_key"] is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"xAI API key is not set. Please set it in the `xai_api_key` field or "
|
||||
"in the `XAI_API_KEY` environment variable."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if not (self.client or None):
|
||||
sync_specific: dict = {"http_client": self.http_client}
|
||||
@@ -511,9 +512,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
if not isinstance(response, openai.BaseModel):
|
||||
return rtn
|
||||
|
||||
if hasattr(response.choices[0].message, "reasoning_content"): # type: ignore
|
||||
if hasattr(response.choices[0].message, "reasoning_content"): # type: ignore[attr-defined]
|
||||
rtn.generations[0].message.additional_kwargs["reasoning_content"] = (
|
||||
response.choices[0].message.reasoning_content # type: ignore
|
||||
response.choices[0].message.reasoning_content # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if hasattr(response, "citations"):
|
||||
@@ -536,15 +537,19 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
)
|
||||
if (choices := chunk.get("choices")) and generation_chunk:
|
||||
top = choices[0]
|
||||
if isinstance(generation_chunk.message, AIMessageChunk):
|
||||
if reasoning_content := top.get("delta", {}).get("reasoning_content"):
|
||||
generation_chunk.message.additional_kwargs["reasoning_content"] = (
|
||||
reasoning_content
|
||||
)
|
||||
if isinstance(generation_chunk.message, AIMessageChunk) and (
|
||||
reasoning_content := top.get("delta", {}).get("reasoning_content")
|
||||
):
|
||||
generation_chunk.message.additional_kwargs["reasoning_content"] = (
|
||||
reasoning_content
|
||||
)
|
||||
|
||||
if (citations := chunk.get("citations")) and generation_chunk:
|
||||
if isinstance(generation_chunk.message, AIMessageChunk):
|
||||
generation_chunk.message.additional_kwargs["citations"] = citations
|
||||
if (
|
||||
(citations := chunk.get("citations"))
|
||||
and generation_chunk
|
||||
and isinstance(generation_chunk.message, AIMessageChunk)
|
||||
):
|
||||
generation_chunk.message.additional_kwargs["citations"] = citations
|
||||
|
||||
return generation_chunk
|
||||
|
||||
|
||||
Reference in New Issue
Block a user