mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
community[patch]: chat hf typing fix (#18693)
This commit is contained in:
parent
1e1cac50d8
commit
3b1eb1f828
@ -1,6 +1,6 @@
|
|||||||
"""Hugging Face Chat Wrapper."""
|
"""Hugging Face Chat Wrapper."""
|
||||||
|
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -13,11 +13,8 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
|
||||||
ChatGeneration,
|
from langchain_core.pydantic_v1 import root_validator
|
||||||
ChatResult,
|
|
||||||
LLMResult,
|
|
||||||
)
|
|
||||||
|
|
||||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||||
@ -42,7 +39,9 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
|
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
|
llm: Any
|
||||||
|
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
|
||||||
|
HuggingFaceHub."""
|
||||||
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
|
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
|
||||||
tokenizer: Any = None
|
tokenizer: Any = None
|
||||||
model_id: Optional[str] = None
|
model_id: Optional[str] = None
|
||||||
@ -60,6 +59,18 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
else self.tokenizer
|
else self.tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_llm(cls, values: dict) -> dict:
|
||||||
|
if not isinstance(
|
||||||
|
values["llm"],
|
||||||
|
(HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||||
|
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
|
Loading…
Reference in New Issue
Block a user