mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
partners[lint]: run pyupgrade
to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -1,16 +1,13 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@@ -46,8 +43,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
|
||||
class TGI_RESPONSE:
|
||||
"""Response from the TextGenInference API."""
|
||||
|
||||
choices: List[Any]
|
||||
usage: Dict
|
||||
choices: list[Any]
|
||||
usage: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,12 +53,12 @@ class TGI_MESSAGE:
|
||||
|
||||
role: str
|
||||
content: str
|
||||
tool_calls: List[Dict]
|
||||
tool_calls: list[dict]
|
||||
|
||||
|
||||
def _convert_message_to_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
@@ -104,7 +101,7 @@ def _convert_TGI_message_to_LC_message(
|
||||
content = cast(str, _message.content)
|
||||
if content is None:
|
||||
content = ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if tool_calls := _message.tool_calls:
|
||||
if "arguments" in tool_calls[0]["function"]:
|
||||
functions = tool_calls[0]["function"].pop("arguments")
|
||||
@@ -358,8 +355,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -380,8 +377,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -398,7 +395,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
@@ -472,7 +469,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required"], bool]
|
||||
@@ -529,8 +526,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> List[Dict[Any, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> list[dict[Any, Any]]:
|
||||
message_dicts = [_convert_message_to_chat_message(m) for m in messages]
|
||||
return message_dicts
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -40,16 +40,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
|
||||
`prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the documents of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the query of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
@@ -102,8 +102,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
|
||||
def _embed(
|
||||
self, texts: list[str], encode_kwargs: Dict[str, Any]
|
||||
) -> List[List[float]]:
|
||||
self, texts: list[str], encode_kwargs: dict[str, Any]
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Embed a text using the HuggingFace transformer model.
|
||||
|
||||
@@ -138,7 +138,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
@@ -149,7 +149,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self._embed(texts, self.encode_kwargs)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import from_env
|
||||
@@ -101,7 +101,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@@ -117,7 +117,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
responses = self.client.feature_extraction(text=texts, **_model_kwargs)
|
||||
return responses
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@@ -134,7 +134,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return responses
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
@@ -146,7 +146,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
response = self.embed_documents([text])[0]
|
||||
return response
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
|
@@ -2,7 +2,8 @@ import inspect
|
||||
import json # type: ignore[import-not-found]
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -96,7 +97,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
"""Whether to prepend the prompt to the generated text"""
|
||||
truncate: Optional[int] = None
|
||||
"""Truncate inputs tokens to the given size"""
|
||||
stop_sequences: List[str] = Field(default_factory=list)
|
||||
stop_sequences: list[str] = Field(default_factory=list)
|
||||
"""Stop generating tokens if a member of `stop_sequences` is generated"""
|
||||
seed: Optional[int] = None
|
||||
"""Random sampling seed"""
|
||||
@@ -111,9 +112,9 @@ class HuggingFaceEndpoint(LLM):
|
||||
watermark: bool = False
|
||||
"""Watermarking with [A Watermark for Large Language Models]
|
||||
(https://arxiv.org/abs/2301.10226)"""
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
server_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any text-generation-inference server parameters not explicitly specified"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `call` not explicitly specified"""
|
||||
model: str
|
||||
client: Any = None #: :meta private:
|
||||
@@ -128,7 +129,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -252,7 +253,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling text generation inference API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
@@ -285,8 +286,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _invocation_params(
|
||||
self, runtime_stop: Optional[List[str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
self, runtime_stop: Optional[list[str]], **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
params = {**self._default_params, **kwargs}
|
||||
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
|
||||
return params
|
||||
@@ -294,7 +295,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@@ -326,7 +327,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@@ -357,7 +358,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@@ -394,7 +395,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
|
@@ -2,7 +2,8 @@ from __future__ import annotations # type: ignore[import-not-found]
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
@@ -82,7 +83,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Ensure model_id is set either by pipeline or user input."""
|
||||
if "model_id" not in values:
|
||||
if "pipeline" in values and values["pipeline"]:
|
||||
@@ -297,13 +298,13 @@ class HuggingFacePipeline(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
# List to hold all results
|
||||
text_generations: List[str] = []
|
||||
text_generations: list[str] = []
|
||||
pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
|
||||
skip_prompt = kwargs.get("skip_prompt", False)
|
||||
|
||||
@@ -347,7 +348,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
|
Reference in New Issue
Block a user