mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 04:28:58 +00:00
partners[lint]: run pyupgrade
to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -3,20 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterator, Mapping
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Literal, Optional, TypeVar, Union
|
||||
|
||||
import openai
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
@@ -50,8 +39,8 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
_BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
|
||||
_DictOrPydantic = Union[Dict, _BM]
|
||||
_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type]
|
||||
_DictOrPydantic = Union[dict, _BM]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -162,14 +151,14 @@ class ChatPerplexity(BaseChatModel):
|
||||
"""Model name."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
pplx_api_key: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key"
|
||||
)
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
||||
request_timeout: Optional[Union[float, tuple[float, float]]] = Field(
|
||||
None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to PerplexityChat completion API. Default is None."""
|
||||
@@ -183,12 +172,12 @@ class ChatPerplexity(BaseChatModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"pplx_api_key": "PPLX_API_KEY"}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -232,7 +221,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling PerplexityChat API."""
|
||||
return {
|
||||
"max_tokens": self.max_tokens,
|
||||
@@ -241,7 +230,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict[str, Any]:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
@@ -255,8 +244,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
return message_dict
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
params = dict(self._invocation_params)
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
@@ -266,11 +255,11 @@ class ChatPerplexity(BaseChatModel):
|
||||
return message_dicts, params
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
self, _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@@ -296,8 +285,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@@ -367,8 +356,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -402,7 +391,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
@property
|
||||
def _invocation_params(self) -> Mapping[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
pplx_creds: Dict[str, Any] = {"model": self.model}
|
||||
pplx_creds: dict[str, Any] = {"model": self.model}
|
||||
return {**pplx_creds, **self._default_params}
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user