mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-09 17:18:31 +00:00
huggingface: fix community dep checking (#21628)
This commit is contained in:
parent
91a2ea5cd6
commit
9b51ca08bc
@ -14,10 +14,6 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain_community.llms.huggingface_text_gen_inference import (
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@ -111,6 +107,34 @@ def _convert_TGI_message_to_LC_message(
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
|
||||
|
||||
def _is_huggingface_hub(llm: Any) -> bool:
|
||||
try:
|
||||
from langchain_community.llms.huggingface_hub import ( # type: ignore[import-not-found]
|
||||
HuggingFaceHub,
|
||||
)
|
||||
|
||||
return isinstance(llm, HuggingFaceHub)
|
||||
except ImportError:
|
||||
# if no langchain community, it is not a HuggingFaceHub
|
||||
return False
|
||||
|
||||
|
||||
def _is_huggingface_textgen_inference(llm: Any) -> bool:
|
||||
try:
|
||||
from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found]
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
|
||||
return isinstance(llm, HuggingFaceTextGenInference)
|
||||
except ImportError:
|
||||
# if no langchain community, it is not a HuggingFaceTextGenInference
|
||||
return False
|
||||
|
||||
|
||||
def _is_huggingface_endpoint(llm: Any) -> bool:
|
||||
return isinstance(llm, HuggingFaceEndpoint)
|
||||
|
||||
|
||||
class ChatHuggingFace(BaseChatModel):
|
||||
"""
|
||||
Wrapper for using Hugging Face LLM's as ChatModels.
|
||||
@ -147,9 +171,10 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
@root_validator()
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
if not isinstance(
|
||||
values["llm"],
|
||||
(HuggingFaceHub, HuggingFaceTextGenInference, HuggingFaceEndpoint),
|
||||
if (
|
||||
not _is_huggingface_hub(values["llm"])
|
||||
and not _is_huggingface_textgen_inference(values["llm"])
|
||||
and not _is_huggingface_endpoint(values["llm"])
|
||||
):
|
||||
raise TypeError(
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
@ -177,11 +202,11 @@ class ChatHuggingFace(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
elif isinstance(self.llm, HuggingFaceEndpoint):
|
||||
elif _is_huggingface_endpoint(self.llm):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
@ -199,7 +224,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
message_dicts = self._create_message_dicts(messages, stop)
|
||||
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
@ -261,12 +286,12 @@ class ChatHuggingFace(BaseChatModel):
|
||||
from huggingface_hub import list_inference_endpoints # type: ignore[import]
|
||||
|
||||
available_endpoints = list_inference_endpoints("*")
|
||||
if isinstance(self.llm, HuggingFaceHub) or (
|
||||
if _is_huggingface_hub(self.llm) or (
|
||||
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
||||
):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
elif isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
elif _is_huggingface_textgen_inference(self.llm):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
else:
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
|
@ -14,7 +14,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
|
@ -33,7 +33,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
Example using from_model_id:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFacePipeline
|
||||
from langchain_huggingface import HuggingFacePipeline
|
||||
hf = HuggingFacePipeline.from_model_id(
|
||||
model_id="gpt2",
|
||||
task="text-generation",
|
||||
@ -42,7 +42,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
Example passing pipeline in directly:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFacePipeline
|
||||
from langchain_huggingface import HuggingFacePipeline
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
model_id = "gpt2"
|
||||
|
2
libs/partners/huggingface/poetry.lock
generated
2
libs/partners/huggingface/poetry.lock
generated
@ -3343,4 +3343,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "3c254c1b543c7f23605ca0183b42db74fa64900300d802b15ed4c9b979bcea63"
|
||||
content-hash = "dd853a4abc0cb93a17319b693279069f9cfd06838d10615217ab00e2c879c789"
|
||||
|
@ -44,14 +44,12 @@ ruff = "^0.1.5"
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
ipykernel = "^6.29.2"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
|
@ -8,7 +8,7 @@ errors=0
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
# git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
|
Loading…
Reference in New Issue
Block a user