diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py
index 143aff07172..f459f58e1dd 100644
--- a/libs/community/langchain_community/chat_models/huggingface.py
+++ b/libs/community/langchain_community/chat_models/huggingface.py
@@ -1,6 +1,6 @@
 """Hugging Face Chat Wrapper."""
 
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -13,11 +13,8 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatResult,
-    LLMResult,
-)
+from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from langchain_core.pydantic_v1 import root_validator
 
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain_community.llms.huggingface_hub import HuggingFaceHub
@@ -42,7 +39,9 @@ class ChatHuggingFace(BaseChatModel):
     Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
     """
 
-    llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
+    llm: Any
+    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
+    HuggingFaceHub."""
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
     model_id: Optional[str] = None
@@ -60,6 +59,18 @@ class ChatHuggingFace(BaseChatModel):
             else self.tokenizer
         )
 
+    @root_validator()
+    def validate_llm(cls, values: dict) -> dict:
+        if not isinstance(
+            values["llm"],
+            (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
+        ):
+            raise TypeError(
+                "Expected llm to be one of HuggingFaceTextGenInference, "
+                f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
+            )
+        return values
+
     def _generate(
         self,
         messages: List[BaseMessage],
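
Note on the pattern this diff introduces: widening `llm` from a `Union` to `Any` means pydantic no longer validates or coerces the field itself, and the `@root_validator()` reinstates the type check at construction time, raising a `TypeError` for unsupported wrappers. Below is a minimal, self-contained sketch of the same pattern, assuming the pydantic v1 API that the diff imports via `langchain_core.pydantic_v1`; the `Widget`, `Gadget`, and `Holder` names are hypothetical stand-ins, not part of the change.

from typing import Any

from langchain_core.pydantic_v1 import BaseModel, root_validator


class Widget:
    """Hypothetical supported type."""


class Gadget:
    """Hypothetical supported type."""


class Holder(BaseModel):
    # Typed as Any so pydantic performs no validation or coercion of its own;
    # the root validator below enforces the concrete types instead.
    item: Any

    @root_validator()
    def validate_item(cls, values: dict) -> dict:
        if not isinstance(values["item"], (Widget, Gadget)):
            raise TypeError(
                f"Expected item to be Widget or Gadget, received {type(values['item'])}"
            )
        return values


Holder(item=Widget())  # accepted by the validator

try:
    Holder(item="oops")  # rejected at construction time
except Exception as err:
    # pydantic v1 surfaces the TypeError from a post root validator
    # wrapped in a ValidationError
    print(type(err).__name__, err)

One practical upshot of moving the check from the annotation into a validator is that the error message stays under the author's control, instead of pydantic's generic report of a failed match against each member of the `Union`.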