huggingface: fix community dep checking (#21628)

Erick Friis 2024-05-13 14:52:18 -07:00 committed by GitHub
parent 91a2ea5cd6
commit 9b51ca08bc
6 changed files with 42 additions and 19 deletions


@@ -14,10 +14,6 @@ from typing import (
     cast,
 )

-from langchain_community.llms.huggingface_hub import HuggingFaceHub
-from langchain_community.llms.huggingface_text_gen_inference import (
-    HuggingFaceTextGenInference,
-)
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
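
Removing these two module-level imports is the substance of the fix: importing langchain_community at the top of the module made it a hard dependency, so langchain_huggingface failed to import whenever langchain-community was not installed. The isinstance checks that relied on them move into the guarded helper functions added in the next hunk.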
@@ -111,6 +107,34 @@ def _convert_TGI_message_to_LC_message(
     return AIMessage(content=content, additional_kwargs=additional_kwargs)


+def _is_huggingface_hub(llm: Any) -> bool:
+    try:
+        from langchain_community.llms.huggingface_hub import (  # type: ignore[import-not-found]
+            HuggingFaceHub,
+        )
+
+        return isinstance(llm, HuggingFaceHub)
+    except ImportError:
+        # if no langchain community, it is not a HuggingFaceHub
+        return False
+
+
+def _is_huggingface_textgen_inference(llm: Any) -> bool:
+    try:
+        from langchain_community.llms.huggingface_text_gen_inference import (  # type: ignore[import-not-found]
+            HuggingFaceTextGenInference,
+        )
+
+        return isinstance(llm, HuggingFaceTextGenInference)
+    except ImportError:
+        # if no langchain community, it is not a HuggingFaceTextGenInference
+        return False
+
+
+def _is_huggingface_endpoint(llm: Any) -> bool:
+    return isinstance(llm, HuggingFaceEndpoint)
+
+
 class ChatHuggingFace(BaseChatModel):
     """
     Wrapper for using Hugging Face LLM's as ChatModels.
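
The first two helpers share one pattern: defer the langchain_community import to call time and treat ImportError as "not an instance of that type"; _is_huggingface_endpoint can check directly because HuggingFaceEndpoint ships in this package. A minimal sketch of the pattern, written against a hypothetical optional package some_extra (the module and class names are placeholders, not part of this commit):

    from typing import Any


    def _is_from_optional_dep(obj: Any) -> bool:
        # The import happens inside the function body, so merely importing
        # this module never requires the optional package.
        try:
            from some_extra import SomeClass  # hypothetical optional dependency
        except ImportError:
            # Package absent: obj cannot be an instance of a class we cannot import.
            return False
        return isinstance(obj, SomeClass)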
@@ -147,9 +171,10 @@ class ChatHuggingFace(BaseChatModel):

     @root_validator()
     def validate_llm(cls, values: dict) -> dict:
-        if not isinstance(
-            values["llm"],
-            (HuggingFaceHub, HuggingFaceTextGenInference, HuggingFaceEndpoint),
+        if (
+            not _is_huggingface_hub(values["llm"])
+            and not _is_huggingface_textgen_inference(values["llm"])
+            and not _is_huggingface_endpoint(values["llm"])
         ):
             raise TypeError(
                 "Expected llm to be one of HuggingFaceTextGenInference, "
@@ -177,11 +202,11 @@ class ChatHuggingFace(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.llm, HuggingFaceTextGenInference):
+        if _is_huggingface_textgen_inference(self.llm):
             message_dicts = self._create_message_dicts(messages, stop)
             answer = self.llm.client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif isinstance(self.llm, HuggingFaceEndpoint):
+        elif _is_huggingface_endpoint(self.llm):
             message_dicts = self._create_message_dicts(messages, stop)
             answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
@@ -199,7 +224,7 @@ class ChatHuggingFace(BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.llm, HuggingFaceTextGenInference):
+        if _is_huggingface_textgen_inference(self.llm):
             message_dicts = self._create_message_dicts(messages, stop)
             answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
@@ -261,12 +286,12 @@ class ChatHuggingFace(BaseChatModel):
         from huggingface_hub import list_inference_endpoints  # type: ignore[import]

         available_endpoints = list_inference_endpoints("*")
-        if isinstance(self.llm, HuggingFaceHub) or (
+        if _is_huggingface_hub(self.llm) or (
             hasattr(self.llm, "repo_id") and self.llm.repo_id
         ):
             self.model_id = self.llm.repo_id
             return
-        elif isinstance(self.llm, HuggingFaceTextGenInference):
+        elif _is_huggingface_textgen_inference(self.llm):
             endpoint_url: Optional[str] = self.llm.inference_server_url
         else:
             endpoint_url = self.llm.endpoint_url


@@ -14,7 +14,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     Example:
         .. code-block:: python

-            from langchain_community.embeddings import HuggingFaceEmbeddings
+            from langchain_huggingface import HuggingFaceEmbeddings

             model_name = "sentence-transformers/all-mpnet-base-v2"
             model_kwargs = {'device': 'cpu'}
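
The docstring now points at this package's own export rather than the langchain_community one. A sketch of how the example continues in use, with the parameters shown above (embed_query comes from the core Embeddings interface):

    from langchain_huggingface import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cpu"},
    )
    vector = embeddings.embed_query("hello world")  # a list of floats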


@@ -33,7 +33,7 @@ class HuggingFacePipeline(BaseLLM):
     Example using from_model_id:
         .. code-block:: python

-            from langchain_community.llms import HuggingFacePipeline
+            from langchain_huggingface import HuggingFacePipeline
             hf = HuggingFacePipeline.from_model_id(
                 model_id="gpt2",
                 task="text-generation",
@@ -42,7 +42,7 @@ class HuggingFacePipeline(BaseLLM):
     Example passing pipeline in directly:
         .. code-block:: python

-            from langchain_community.llms import HuggingFacePipeline
+            from langchain_huggingface import HuggingFacePipeline
             from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

             model_id = "gpt2"


@@ -3343,4 +3343,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "3c254c1b543c7f23605ca0183b42db74fa64900300d802b15ed4c9b979bcea63"
+content-hash = "dd853a4abc0cb93a17319b693279069f9cfd06838d10615217ab00e2c879c789"


@@ -44,14 +44,12 @@ ruff = "^0.1.5"
 [tool.poetry.group.typing.dependencies]
 mypy = "^1"
 langchain-core = { path = "../../core", develop = true }
-langchain-community = { path = "../../community", develop = true }

 [tool.poetry.group.dev]
 optional = true

 [tool.poetry.group.dev.dependencies]
 langchain-core = { path = "../../core", develop = true }
-langchain-community = { path = "../../community", develop = true }
 ipykernel = "^6.29.2"

 [tool.poetry.group.test_integration]
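
Dropping langchain-community from both the typing and dev groups means it is no longer needed to develop or type-check langchain_huggingface; the guarded helpers only import it when an end user has installed it separately (for example via pip install langchain-community).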


@@ -8,7 +8,7 @@ errors=0
 # make sure not importing from langchain or langchain_experimental
 git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
 git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
-# git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
+git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))

 # Decide on an exit status based on the errors
 if [ "$errors" -gt 0 ]; then