mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 19:47:13 +00:00
huggingface: fix community dep checking (#21628)
This commit is contained in:
@@ -14,10 +14,6 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain_community.llms.huggingface_text_gen_inference import (
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -111,6 +107,34 @@ def _convert_TGI_message_to_LC_message(
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
|
||||
|
||||
def _is_huggingface_hub(llm: Any) -> bool:
|
||||
try:
|
||||
from langchain_community.llms.huggingface_hub import ( # type: ignore[import-not-found]
|
||||
HuggingFaceHub,
|
||||
)
|
||||
|
||||
return isinstance(llm, HuggingFaceHub)
|
||||
except ImportError:
|
||||
# if no langchain community, it is not a HuggingFaceHub
|
||||
return False
|
||||
|
||||
|
||||
def _is_huggingface_textgen_inference(llm: Any) -> bool:
|
||||
try:
|
||||
from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found]
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
|
||||
return isinstance(llm, HuggingFaceTextGenInference)
|
||||
except ImportError:
|
||||
# if no langchain community, it is not a HuggingFaceTextGenInference
|
||||
return False
|
||||
|
||||
|
||||
def _is_huggingface_endpoint(llm: Any) -> bool:
|
||||
return isinstance(llm, HuggingFaceEndpoint)
|
||||
|
||||
|
||||
class ChatHuggingFace(BaseChatModel):
|
||||
"""
|
||||
Wrapper for using Hugging Face LLM's as ChatModels.
|
||||
@@ -147,9 +171,10 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
@root_validator()
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
if not isinstance(
|
||||
values["llm"],
|
||||
(HuggingFaceHub, HuggingFaceTextGenInference, HuggingFaceEndpoint),
|
||||
if (
|
||||
not _is_huggingface_hub(values["llm"])
|
||||
and not _is_huggingface_textgen_inference(values["llm"])
|
||||
and not _is_huggingface_endpoint(values["llm"])
|
||||
):
|
||||
raise TypeError(
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
@@ -177,11 +202,11 @@ class ChatHuggingFace(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
elif isinstance(self.llm, HuggingFaceEndpoint):
|
||||
elif _is_huggingface_endpoint(self.llm):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
@@ -199,7 +224,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
@@ -261,12 +286,12 @@ class ChatHuggingFace(BaseChatModel):
|
||||
from huggingface_hub import list_inference_endpoints # type: ignore[import]
|
||||
|
||||
available_endpoints = list_inference_endpoints("*")
|
||||
if isinstance(self.llm, HuggingFaceHub) or (
|
||||
if _is_huggingface_hub(self.llm) or (
|
||||
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
||||
):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
elif isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
elif _is_huggingface_textgen_inference(self.llm):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
else:
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
|
@@ -14,7 +14,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
|
@@ -33,7 +33,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
Example using from_model_id:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFacePipeline
|
||||
from langchain_huggingface import HuggingFacePipeline
|
||||
hf = HuggingFacePipeline.from_model_id(
|
||||
model_id="gpt2",
|
||||
task="text-generation",
|
||||
@@ -42,7 +42,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
Example passing pipeline in directly:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFacePipeline
|
||||
from langchain_huggingface import HuggingFacePipeline
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
model_id = "gpt2"
|
||||
|
Reference in New Issue
Block a user