mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -4,20 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
@@ -109,7 +102,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
# Fix for azure
|
||||
# Also Fireworks returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if function_call := _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
tool_calls = []
|
||||
@@ -157,7 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any]
|
||||
message_dict: dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
@@ -205,14 +198,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
choice = chunk["choices"][0]
|
||||
_dict = choice["delta"]
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
tool_call_chunks: List[ToolCallChunk] = []
|
||||
additional_kwargs: dict = {}
|
||||
tool_call_chunks: list[ToolCallChunk] = []
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@@ -290,17 +283,17 @@ class ChatFireworks(BaseChatModel):
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "fireworks"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
attributes: dict[str, Any] = {}
|
||||
if self.fireworks_api_base:
|
||||
attributes["fireworks_api_base"] = self.fireworks_api_base
|
||||
|
||||
@@ -319,9 +312,9 @@ class ChatFireworks(BaseChatModel):
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.0
|
||||
"""What sampling temperature to use."""
|
||||
stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
|
||||
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
fireworks_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
@@ -344,7 +337,7 @@ class ChatFireworks(BaseChatModel):
|
||||
)
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
request_timeout: Union[float, tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
|
||||
@@ -364,7 +357,7 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
@@ -398,7 +391,7 @@ class ChatFireworks(BaseChatModel):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling Fireworks API."""
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
@@ -413,7 +406,7 @@ class ChatFireworks(BaseChatModel):
|
||||
return params
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@@ -429,7 +422,7 @@ class ChatFireworks(BaseChatModel):
|
||||
ls_params["ls_stop"] = ls_stop
|
||||
return ls_params
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
for output in llm_outputs:
|
||||
@@ -452,15 +445,15 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
@@ -487,8 +480,8 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
@@ -509,8 +502,8 @@ class ChatFireworks(BaseChatModel):
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
params["stop"] = stop
|
||||
@@ -547,15 +540,15 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
@@ -584,8 +577,8 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
@@ -607,13 +600,13 @@ class ChatFireworks(BaseChatModel):
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"model_name": self.model_name, **self._default_params}
|
||||
|
||||
def _get_invocation_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
@@ -634,7 +627,7 @@ class ChatFireworks(BaseChatModel):
|
||||
)
|
||||
def bind_functions(
|
||||
self,
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||
] = None,
|
||||
@@ -690,7 +683,7 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "any", "none"], bool]
|
||||
@@ -738,14 +731,14 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
*,
|
||||
method: Literal[
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import secret_from_env
|
||||
from openai import OpenAI
|
||||
@@ -96,13 +94,13 @@ class FireworksEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs."""
|
||||
return [
|
||||
i.embedding
|
||||
for i in self.client.embeddings.create(input=texts, model=self.model).data
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Wrapper around Fireworks AI's Completion API."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
@@ -63,7 +63,7 @@ class Fireworks(LLM):
|
||||
for question answering or summarization. A value greater than 1 introduces more
|
||||
randomness in the output.
|
||||
"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
top_k: Optional[int] = None
|
||||
"""Used to limit the number of choices for the next predicted word or token. It
|
||||
@@ -90,7 +90,7 @@ class Fireworks(LLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
@@ -109,7 +109,7 @@ class Fireworks(LLM):
|
||||
return f"langchain-fireworks/{__version__}"
|
||||
|
||||
@property
|
||||
def default_params(self) -> Dict[str, Any]:
|
||||
def default_params(self) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
@@ -122,7 +122,7 @@ class Fireworks(LLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@@ -139,7 +139,7 @@ class Fireworks(LLM):
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
|
||||
payload: Dict[str, Any] = {
|
||||
payload: dict[str, Any] = {
|
||||
**self.default_params,
|
||||
"prompt": prompt,
|
||||
"stop": stop_to_use,
|
||||
@@ -168,7 +168,7 @@ class Fireworks(LLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@@ -185,7 +185,7 @@ class Fireworks(LLM):
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
|
||||
payload: Dict[str, Any] = {
|
||||
payload: dict[str, Any] = {
|
||||
**self.default_params,
|
||||
"prompt": prompt,
|
||||
"stop": stop_to_use,
|
||||
|
||||
Reference in New Issue
Block a user