mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-13 11:07:41 +00:00
huggingface: fix community dep checking (#21628)
This commit is contained in:
parent
91a2ea5cd6
commit
9b51ca08bc
@ -14,10 +14,6 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
|
||||||
from langchain_community.llms.huggingface_text_gen_inference import (
|
|
||||||
HuggingFaceTextGenInference,
|
|
||||||
)
|
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
@ -111,6 +107,34 @@ def _convert_TGI_message_to_LC_message(
|
|||||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_huggingface_hub(llm: Any) -> bool:
|
||||||
|
try:
|
||||||
|
from langchain_community.llms.huggingface_hub import ( # type: ignore[import-not-found]
|
||||||
|
HuggingFaceHub,
|
||||||
|
)
|
||||||
|
|
||||||
|
return isinstance(llm, HuggingFaceHub)
|
||||||
|
except ImportError:
|
||||||
|
# if no langchain community, it is not a HuggingFaceHub
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_huggingface_textgen_inference(llm: Any) -> bool:
|
||||||
|
try:
|
||||||
|
from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found]
|
||||||
|
HuggingFaceTextGenInference,
|
||||||
|
)
|
||||||
|
|
||||||
|
return isinstance(llm, HuggingFaceTextGenInference)
|
||||||
|
except ImportError:
|
||||||
|
# if no langchain community, it is not a HuggingFaceTextGenInference
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_huggingface_endpoint(llm: Any) -> bool:
|
||||||
|
return isinstance(llm, HuggingFaceEndpoint)
|
||||||
|
|
||||||
|
|
||||||
class ChatHuggingFace(BaseChatModel):
|
class ChatHuggingFace(BaseChatModel):
|
||||||
"""
|
"""
|
||||||
Wrapper for using Hugging Face LLM's as ChatModels.
|
Wrapper for using Hugging Face LLM's as ChatModels.
|
||||||
@ -147,9 +171,10 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_llm(cls, values: dict) -> dict:
|
def validate_llm(cls, values: dict) -> dict:
|
||||||
if not isinstance(
|
if (
|
||||||
values["llm"],
|
not _is_huggingface_hub(values["llm"])
|
||||||
(HuggingFaceHub, HuggingFaceTextGenInference, HuggingFaceEndpoint),
|
and not _is_huggingface_textgen_inference(values["llm"])
|
||||||
|
and not _is_huggingface_endpoint(values["llm"])
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||||
@ -177,11 +202,11 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
if _is_huggingface_textgen_inference(self.llm):
|
||||||
message_dicts = self._create_message_dicts(messages, stop)
|
message_dicts = self._create_message_dicts(messages, stop)
|
||||||
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
|
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
|
||||||
return self._create_chat_result(answer)
|
return self._create_chat_result(answer)
|
||||||
elif isinstance(self.llm, HuggingFaceEndpoint):
|
elif _is_huggingface_endpoint(self.llm):
|
||||||
message_dicts = self._create_message_dicts(messages, stop)
|
message_dicts = self._create_message_dicts(messages, stop)
|
||||||
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
|
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
|
||||||
return self._create_chat_result(answer)
|
return self._create_chat_result(answer)
|
||||||
@ -199,7 +224,7 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
if _is_huggingface_textgen_inference(self.llm):
|
||||||
message_dicts = self._create_message_dicts(messages, stop)
|
message_dicts = self._create_message_dicts(messages, stop)
|
||||||
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
|
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
|
||||||
return self._create_chat_result(answer)
|
return self._create_chat_result(answer)
|
||||||
@ -261,12 +286,12 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
from huggingface_hub import list_inference_endpoints # type: ignore[import]
|
from huggingface_hub import list_inference_endpoints # type: ignore[import]
|
||||||
|
|
||||||
available_endpoints = list_inference_endpoints("*")
|
available_endpoints = list_inference_endpoints("*")
|
||||||
if isinstance(self.llm, HuggingFaceHub) or (
|
if _is_huggingface_hub(self.llm) or (
|
||||||
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
||||||
):
|
):
|
||||||
self.model_id = self.llm.repo_id
|
self.model_id = self.llm.repo_id
|
||||||
return
|
return
|
||||||
elif isinstance(self.llm, HuggingFaceTextGenInference):
|
elif _is_huggingface_textgen_inference(self.llm):
|
||||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||||
else:
|
else:
|
||||||
endpoint_url = self.llm.endpoint_url
|
endpoint_url = self.llm.endpoint_url
|
||||||
|
@ -14,7 +14,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||||
model_kwargs = {'device': 'cpu'}
|
model_kwargs = {'device': 'cpu'}
|
||||||
|
@ -33,7 +33,7 @@ class HuggingFacePipeline(BaseLLM):
|
|||||||
Example using from_model_id:
|
Example using from_model_id:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_community.llms import HuggingFacePipeline
|
from langchain_huggingface import HuggingFacePipeline
|
||||||
hf = HuggingFacePipeline.from_model_id(
|
hf = HuggingFacePipeline.from_model_id(
|
||||||
model_id="gpt2",
|
model_id="gpt2",
|
||||||
task="text-generation",
|
task="text-generation",
|
||||||
@ -42,7 +42,7 @@ class HuggingFacePipeline(BaseLLM):
|
|||||||
Example passing pipeline in directly:
|
Example passing pipeline in directly:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_community.llms import HuggingFacePipeline
|
from langchain_huggingface import HuggingFacePipeline
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||||
|
|
||||||
model_id = "gpt2"
|
model_id = "gpt2"
|
||||||
|
2
libs/partners/huggingface/poetry.lock
generated
2
libs/partners/huggingface/poetry.lock
generated
@ -3343,4 +3343,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "3c254c1b543c7f23605ca0183b42db74fa64900300d802b15ed4c9b979bcea63"
|
content-hash = "dd853a4abc0cb93a17319b693279069f9cfd06838d10615217ab00e2c879c789"
|
||||||
|
@ -44,14 +44,12 @@ ruff = "^0.1.5"
|
|||||||
[tool.poetry.group.typing.dependencies]
|
[tool.poetry.group.typing.dependencies]
|
||||||
mypy = "^1"
|
mypy = "^1"
|
||||||
langchain-core = { path = "../../core", develop = true }
|
langchain-core = { path = "../../core", develop = true }
|
||||||
langchain-community = { path = "../../community", develop = true }
|
|
||||||
|
|
||||||
[tool.poetry.group.dev]
|
[tool.poetry.group.dev]
|
||||||
optional = true
|
optional = true
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
langchain-core = { path = "../../core", develop = true }
|
langchain-core = { path = "../../core", develop = true }
|
||||||
langchain-community = { path = "../../community", develop = true }
|
|
||||||
ipykernel = "^6.29.2"
|
ipykernel = "^6.29.2"
|
||||||
|
|
||||||
[tool.poetry.group.test_integration]
|
[tool.poetry.group.test_integration]
|
||||||
|
@ -8,7 +8,7 @@ errors=0
|
|||||||
# make sure not importing from langchain or langchain_experimental
|
# make sure not importing from langchain or langchain_experimental
|
||||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||||
# git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||||
|
|
||||||
# Decide on an exit status based on the errors
|
# Decide on an exit status based on the errors
|
||||||
if [ "$errors" -gt 0 ]; then
|
if [ "$errors" -gt 0 ]; then
|
||||||
|
Loading…
Reference in New Issue
Block a user