mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
community[patch]: chat hf typing fix (#18693)
This commit is contained in:
parent
1e1cac50d8
commit
3b1eb1f828
@ -1,6 +1,6 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -13,11 +13,8 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
@ -42,7 +39,9 @@ class ChatHuggingFace(BaseChatModel):
|
||||
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
|
||||
"""
|
||||
|
||||
llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
|
||||
llm: Any
|
||||
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
|
||||
HuggingFaceHub."""
|
||||
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
|
||||
tokenizer: Any = None
|
||||
model_id: Optional[str] = None
|
||||
@ -60,6 +59,18 @@ class ChatHuggingFace(BaseChatModel):
|
||||
else self.tokenizer
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
if not isinstance(
|
||||
values["llm"],
|
||||
(HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
|
||||
):
|
||||
raise TypeError(
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
|
||||
)
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
Loading…
Reference in New Issue
Block a user