Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-12 00:11:17 +00:00)
huggingface[patch]: Support for HuggingFacePipeline in ChatHuggingFace. (#22194)
- **Description:** Added support for using HuggingFacePipeline in ChatHuggingFace (previously it was only usable with API endpoints, probably by oversight).
- **Issue:** #19997
- **Dependencies:** none
- **Twitter handle:** none

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
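For context, a minimal sketch of the usage this patch enables. The model id and generation kwargs below are illustrative examples, not values taken from the PR:

```python
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

# A locally-run transformers pipeline wrapped as a LangChain LLM.
# "HuggingFaceH4/zephyr-7b-beta" and max_new_tokens=64 are example values.
llm = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 64},
)

# Before this patch, ChatHuggingFace only accepted endpoint-backed LLMs;
# with it, a HuggingFacePipeline passes validation as well.
chat = ChatHuggingFace(llm=llm)
print(chat.invoke("What is the capital of France?").content)
```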
@@ -35,6 +35,7 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
 
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
 
@@ -135,6 +136,10 @@ def _is_huggingface_endpoint(llm: Any) -> bool:
     return isinstance(llm, HuggingFaceEndpoint)
 
 
+def _is_huggingface_pipeline(llm: Any) -> bool:
+    return isinstance(llm, HuggingFacePipeline)
+
+
 class ChatHuggingFace(BaseChatModel):
     """
     Wrapper for using Hugging Face LLM's as ChatModels.
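The new guard mirrors the existing `_is_huggingface_hub` / `_is_huggingface_endpoint` helpers: a plain isinstance check. A quick sketch of its behavior; note the helper is module-private, so the import path below is for illustration only:

```python
from langchain_huggingface import HuggingFacePipeline
from langchain_huggingface.chat_models.huggingface import _is_huggingface_pipeline

llm = HuggingFacePipeline.from_model_id(model_id="gpt2", task="text-generation")
assert _is_huggingface_pipeline(llm)          # True: plain isinstance check
assert not _is_huggingface_pipeline("gpt2")   # any non-pipeline value fails
```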
@@ -150,8 +155,8 @@ class ChatHuggingFace(BaseChatModel):
     """
 
     llm: Any
-    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
-    HuggingFaceHub."""
+    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint,
+    HuggingFaceHub, or HuggingFacePipeline."""
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None
    model_id: Optional[str] = None
@@ -175,10 +180,12 @@ class ChatHuggingFace(BaseChatModel):
             not _is_huggingface_hub(values["llm"])
             and not _is_huggingface_textgen_inference(values["llm"])
             and not _is_huggingface_endpoint(values["llm"])
+            and not _is_huggingface_pipeline(values["llm"])
         ):
             raise TypeError(
                 "Expected llm to be one of HuggingFaceTextGenInference, "
-                f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
+                "HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
+                f"received {type(values['llm'])}"
             )
         return values
 
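With the extra guard in place, only LLMs outside the four supported wrapper types are rejected. A hedged sketch of the failure path; `FakeListLLM` is just a convenient stand-in for "some unsupported LLM":

```python
from langchain_community.llms.fake import FakeListLLM
from langchain_huggingface import ChatHuggingFace

# Any LLM that is not one of the four supported wrappers still raises:
#   TypeError: Expected llm to be one of HuggingFaceTextGenInference,
#   HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline received <class ...>
try:
    ChatHuggingFace(llm=FakeListLLM(responses=["hi"]))
except TypeError as err:
    print(err)
```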
@@ -293,6 +300,9 @@ class ChatHuggingFace(BaseChatModel):
             return
         elif _is_huggingface_textgen_inference(self.llm):
             endpoint_url: Optional[str] = self.llm.inference_server_url
+        elif _is_huggingface_pipeline(self.llm):
+            self.model_id = self.llm.model_id
+            return
         else:
             endpoint_url = self.llm.endpoint_url
 