mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 22:42:05 +00:00
partners[lint]: run pyupgrade
to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -1,16 +1,13 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@@ -46,8 +43,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
|
||||
class TGI_RESPONSE:
|
||||
"""Response from the TextGenInference API."""
|
||||
|
||||
choices: List[Any]
|
||||
usage: Dict
|
||||
choices: list[Any]
|
||||
usage: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,12 +53,12 @@ class TGI_MESSAGE:
|
||||
|
||||
role: str
|
||||
content: str
|
||||
tool_calls: List[Dict]
|
||||
tool_calls: list[dict]
|
||||
|
||||
|
||||
def _convert_message_to_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
@@ -104,7 +101,7 @@ def _convert_TGI_message_to_LC_message(
|
||||
content = cast(str, _message.content)
|
||||
if content is None:
|
||||
content = ""
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
if tool_calls := _message.tool_calls:
|
||||
if "arguments" in tool_calls[0]["function"]:
|
||||
functions = tool_calls[0]["function"].pop("arguments")
|
||||
@@ -358,8 +355,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -380,8 +377,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -398,7 +395,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
@@ -472,7 +469,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required"], bool]
|
||||
@@ -529,8 +526,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> List[Dict[Any, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> list[dict[Any, Any]]:
|
||||
message_dicts = [_convert_message_to_chat_message(m) for m in messages]
|
||||
return message_dicts
|
||||
|
||||
|
Reference in New Issue
Block a user