feat(xai): ruff fixes and rules (#32501)

This commit is contained in:
Mason Daugherty
2025-08-11 13:03:07 -04:00
committed by GitHub
parent f55186b38f
commit 27b6b53f20
6 changed files with 118 additions and 57 deletions

View File

@@ -1,26 +1,24 @@
"""Wrapper around xAI's Chat Completions API."""
from typing import (
Any,
Literal,
Optional,
TypeVar,
Union,
)
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
import openai
from langchain_core.language_models.chat_models import (
LangSmithParams,
LanguageModelInput,
)
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import (
LangSmithParams,
LanguageModelInput,
)
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
_DictOrPydantic = Union[dict, _BM]
@@ -450,9 +448,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
msg = "n must be at least 1."
raise ValueError(msg)
if self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
msg = "n must be 1 when streaming."
raise ValueError(msg)
client_params: dict = {
"api_key": (
@@ -467,10 +467,11 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
client_params["max_retries"] = self.max_retries
if client_params["api_key"] is None:
raise ValueError(
msg = (
"xAI API key is not set. Please set it in the `xai_api_key` field or "
"in the `XAI_API_KEY` environment variable."
)
raise ValueError(msg)
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
@@ -511,9 +512,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
if not isinstance(response, openai.BaseModel):
return rtn
if hasattr(response.choices[0].message, "reasoning_content"): # type: ignore
if hasattr(response.choices[0].message, "reasoning_content"): # type: ignore[attr-defined]
rtn.generations[0].message.additional_kwargs["reasoning_content"] = (
response.choices[0].message.reasoning_content # type: ignore
response.choices[0].message.reasoning_content # type: ignore[attr-defined]
)
if hasattr(response, "citations"):
@@ -536,15 +537,19 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
)
if (choices := chunk.get("choices")) and generation_chunk:
top = choices[0]
if isinstance(generation_chunk.message, AIMessageChunk):
if reasoning_content := top.get("delta", {}).get("reasoning_content"):
generation_chunk.message.additional_kwargs["reasoning_content"] = (
reasoning_content
)
if isinstance(generation_chunk.message, AIMessageChunk) and (
reasoning_content := top.get("delta", {}).get("reasoning_content")
):
generation_chunk.message.additional_kwargs["reasoning_content"] = (
reasoning_content
)
if (citations := chunk.get("citations")) and generation_chunk:
if isinstance(generation_chunk.message, AIMessageChunk):
generation_chunk.message.additional_kwargs["citations"] = citations
if (
(citations := chunk.get("citations"))
and generation_chunk
and isinstance(generation_chunk.message, AIMessageChunk)
):
generation_chunk.message.additional_kwargs["citations"] = citations
return generation_chunk