From 3b1eb1f828fe22b7e09a283a621fbac3f8977318 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 7 Mar 2024 17:06:38 -0800 Subject: [PATCH] community[patch]: chat hf typing fix (#18693) --- .../chat_models/huggingface.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py index 143aff07172..f459f58e1dd 100644 --- a/libs/community/langchain_community/chat_models/huggingface.py +++ b/libs/community/langchain_community/chat_models/huggingface.py @@ -1,6 +1,6 @@ """Hugging Face Chat Wrapper.""" -from typing import Any, List, Optional, Union +from typing import Any, List, Optional from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -13,11 +13,8 @@ from langchain_core.messages import ( HumanMessage, SystemMessage, ) -from langchain_core.outputs import ( - ChatGeneration, - ChatResult, - LLMResult, -) +from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult +from langchain_core.pydantic_v1 import root_validator from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from langchain_community.llms.huggingface_hub import HuggingFaceHub @@ -42,7 +39,9 @@ class ChatHuggingFace(BaseChatModel): Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat """ - llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub] + llm: Any + """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or + HuggingFaceHub.""" system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT) tokenizer: Any = None model_id: Optional[str] = None @@ -60,6 +59,18 @@ class ChatHuggingFace(BaseChatModel): else self.tokenizer ) + @root_validator() + def validate_llm(cls, values: dict) -> dict: + if not isinstance( + values["llm"], + (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub), + ): + raise TypeError( + "Expected llm to be one of HuggingFaceTextGenInference, " + f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}" + ) + return values + def _generate( self, messages: List[BaseMessage],