community[patch]: chat hf typing fix (#18693)

This commit is contained in:
Bagatur 2024-03-07 17:06:38 -08:00 committed by GitHub
parent 1e1cac50d8
commit 3b1eb1f828
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
"""Hugging Face Chat Wrapper."""
from typing import Any, List, Optional, Union
from typing import Any, List, Optional
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -13,11 +13,8 @@ from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.pydantic_v1 import root_validator
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_community.llms.huggingface_hub import HuggingFaceHub
@ -42,7 +39,9 @@ class ChatHuggingFace(BaseChatModel):
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
"""
llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
llm: Any
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
HuggingFaceHub."""
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None
model_id: Optional[str] = None
@ -60,6 +59,18 @@ class ChatHuggingFace(BaseChatModel):
else self.tokenizer
)
@root_validator()
def validate_llm(cls, values: dict) -> dict:
if not isinstance(
values["llm"],
(HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
):
raise TypeError(
"Expected llm to be one of HuggingFaceTextGenInference, "
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
)
return values
def _generate(
self,
messages: List[BaseMessage],