mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 06:23:20 +00:00
partners[lint]: run pyupgrade
to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -2,21 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import AsyncIterator, Collection, Iterator, Mapping
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
@@ -35,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _update_token_usage(
|
||||
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
|
||||
keys: set[str], response: dict[str, Any], token_usage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Update token usage."""
|
||||
_keys_to_use = keys.intersection(response["usage"])
|
||||
@@ -47,7 +34,7 @@ def _update_token_usage(
|
||||
|
||||
|
||||
def _stream_response_to_generation_chunk(
|
||||
stream_response: Dict[str, Any],
|
||||
stream_response: dict[str, Any],
|
||||
) -> GenerationChunk:
|
||||
"""Convert a stream response to a generation chunk."""
|
||||
if not stream_response["choices"]:
|
||||
@@ -84,7 +71,7 @@ class BaseOpenAI(BaseLLM):
|
||||
"""How many completions to generate for each prompt."""
|
||||
best_of: int = 1
|
||||
"""Generates best_of completions server-side and returns the "best"."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
openai_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
|
||||
@@ -108,12 +95,12 @@ class BaseOpenAI(BaseLLM):
|
||||
)
|
||||
batch_size: int = 20
|
||||
"""Batch size to use when passing multiple documents to generate."""
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logit_bias: Optional[dict[str, float]] = None
|
||||
"""Adjust the probability of specific tokens being generated."""
|
||||
max_retries: int = 2
|
||||
"""Maximum number of retries to make when generating."""
|
||||
@@ -124,7 +111,7 @@ class BaseOpenAI(BaseLLM):
|
||||
as well the chosen tokens."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||
allowed_special: Union[Literal["all"], set[str]] = set()
|
||||
"""Set of special tokens that are allowed。"""
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||
"""Set of special tokens that are not allowed。"""
|
||||
@@ -157,7 +144,7 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
@@ -197,9 +184,9 @@ class BaseOpenAI(BaseLLM):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
normal_params: Dict[str, Any] = {
|
||||
normal_params: dict[str, Any] = {
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
@@ -228,7 +215,7 @@ class BaseOpenAI(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@@ -255,7 +242,7 @@ class BaseOpenAI(BaseLLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
@@ -283,8 +270,8 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@@ -307,7 +294,7 @@ class BaseOpenAI(BaseLLM):
|
||||
params = {**params, **kwargs}
|
||||
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
||||
choices = []
|
||||
token_usage: Dict[str, int] = {}
|
||||
token_usage: dict[str, int] = {}
|
||||
# Get the token usage from the response.
|
||||
# Includes prompt, completion, and total tokens used.
|
||||
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||
@@ -363,8 +350,8 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@@ -373,7 +360,7 @@ class BaseOpenAI(BaseLLM):
|
||||
params = {**params, **kwargs}
|
||||
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
||||
choices = []
|
||||
token_usage: Dict[str, int] = {}
|
||||
token_usage: dict[str, int] = {}
|
||||
# Get the token usage from the response.
|
||||
# Includes prompt, completion, and total tokens used.
|
||||
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||
@@ -419,10 +406,10 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
def get_sub_prompts(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> List[List[str]]:
|
||||
params: dict[str, Any],
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> list[list[str]]:
|
||||
"""Get the sub prompts for llm call."""
|
||||
if stop is not None:
|
||||
params["stop"] = stop
|
||||
@@ -441,9 +428,9 @@ class BaseOpenAI(BaseLLM):
|
||||
def create_llm_result(
|
||||
self,
|
||||
choices: Any,
|
||||
prompts: List[str],
|
||||
params: Dict[str, Any],
|
||||
token_usage: Dict[str, int],
|
||||
prompts: list[str],
|
||||
params: dict[str, Any],
|
||||
token_usage: dict[str, int],
|
||||
*,
|
||||
system_fingerprint: Optional[str] = None,
|
||||
) -> LLMResult:
|
||||
@@ -470,7 +457,7 @@ class BaseOpenAI(BaseLLM):
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return self._default_params
|
||||
|
||||
@@ -484,7 +471,7 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Return type of llm."""
|
||||
return "openai"
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
def get_token_ids(self, text: str) -> list[int]:
|
||||
"""Get the token IDs using the tiktoken package."""
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
@@ -689,7 +676,7 @@ class OpenAI(BaseOpenAI):
|
||||
""" # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "openai"]
|
||||
|
||||
@@ -699,16 +686,16 @@ class OpenAI(BaseOpenAI):
|
||||
return True
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
def _invocation_params(self) -> dict[str, Any]:
|
||||
return {**{"model": self.model_name}, **super()._invocation_params}
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
attributes: dict[str, Any] = {}
|
||||
if self.openai_api_base:
|
||||
attributes["openai_api_base"] = self.openai_api_base
|
||||
|
||||
|
Reference in New Issue
Block a user