diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py index ce5225943ce..8cbb274477e 100644 --- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py @@ -14,10 +14,6 @@ from typing import ( cast, ) -from langchain_community.llms.huggingface_hub import HuggingFaceHub -from langchain_community.llms.huggingface_text_gen_inference import ( - HuggingFaceTextGenInference, -) from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -111,6 +107,34 @@ def _convert_TGI_message_to_LC_message( return AIMessage(content=content, additional_kwargs=additional_kwargs) +def _is_huggingface_hub(llm: Any) -> bool: + try: + from langchain_community.llms.huggingface_hub import ( # type: ignore[import-not-found] + HuggingFaceHub, + ) + + return isinstance(llm, HuggingFaceHub) + except ImportError: + # if no langchain community, it is not a HuggingFaceHub + return False + + +def _is_huggingface_textgen_inference(llm: Any) -> bool: + try: + from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found] + HuggingFaceTextGenInference, + ) + + return isinstance(llm, HuggingFaceTextGenInference) + except ImportError: + # if no langchain community, it is not a HuggingFaceTextGenInference + return False + + +def _is_huggingface_endpoint(llm: Any) -> bool: + return isinstance(llm, HuggingFaceEndpoint) + + class ChatHuggingFace(BaseChatModel): """ Wrapper for using Hugging Face LLM's as ChatModels. @@ -147,9 +171,10 @@ class ChatHuggingFace(BaseChatModel): @root_validator() def validate_llm(cls, values: dict) -> dict: - if not isinstance( - values["llm"], - (HuggingFaceHub, HuggingFaceTextGenInference, HuggingFaceEndpoint), + if ( + not _is_huggingface_hub(values["llm"]) + and not _is_huggingface_textgen_inference(values["llm"]) + and not _is_huggingface_endpoint(values["llm"]) ): raise TypeError( "Expected llm to be one of HuggingFaceTextGenInference, " @@ -177,11 +202,11 @@ class ChatHuggingFace(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if isinstance(self.llm, HuggingFaceTextGenInference): + if _is_huggingface_textgen_inference(self.llm): message_dicts = self._create_message_dicts(messages, stop) answer = self.llm.client.chat(messages=message_dicts, **kwargs) return self._create_chat_result(answer) - elif isinstance(self.llm, HuggingFaceEndpoint): + elif _is_huggingface_endpoint(self.llm): message_dicts = self._create_message_dicts(messages, stop) answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs) return self._create_chat_result(answer) @@ -199,7 +224,7 @@ class ChatHuggingFace(BaseChatModel): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if isinstance(self.llm, HuggingFaceTextGenInference): + if _is_huggingface_textgen_inference(self.llm): message_dicts = self._create_message_dicts(messages, stop) answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs) return self._create_chat_result(answer) @@ -261,12 +286,12 @@ class ChatHuggingFace(BaseChatModel): from huggingface_hub import list_inference_endpoints # type: ignore[import] available_endpoints = list_inference_endpoints("*") - if isinstance(self.llm, HuggingFaceHub) or ( + if _is_huggingface_hub(self.llm) or ( hasattr(self.llm, "repo_id") and self.llm.repo_id ): self.model_id = self.llm.repo_id return - elif isinstance(self.llm, HuggingFaceTextGenInference): + elif _is_huggingface_textgen_inference(self.llm): endpoint_url: Optional[str] = self.llm.inference_server_url else: endpoint_url = self.llm.endpoint_url diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py index 5095a0fb149..65bb5736fed 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py @@ -14,7 +14,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): Example: .. code-block:: python - from langchain_community.embeddings import HuggingFaceEmbeddings + from langchain_huggingface import HuggingFaceEmbeddings model_name = "sentence-transformers/all-mpnet-base-v2" model_kwargs = {'device': 'cpu'} diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py index 8595795aa47..070f1413281 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py @@ -33,7 +33,7 @@ class HuggingFacePipeline(BaseLLM): Example using from_model_id: .. code-block:: python - from langchain_community.llms import HuggingFacePipeline + from langchain_huggingface import HuggingFacePipeline hf = HuggingFacePipeline.from_model_id( model_id="gpt2", task="text-generation", @@ -42,7 +42,7 @@ class HuggingFacePipeline(BaseLLM): Example passing pipeline in directly: .. code-block:: python - from langchain_community.llms import HuggingFacePipeline + from langchain_huggingface import HuggingFacePipeline from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline model_id = "gpt2" diff --git a/libs/partners/huggingface/poetry.lock b/libs/partners/huggingface/poetry.lock index c70b9c6a39d..c37f0d3e932 100644 --- a/libs/partners/huggingface/poetry.lock +++ b/libs/partners/huggingface/poetry.lock @@ -3343,4 +3343,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "3c254c1b543c7f23605ca0183b42db74fa64900300d802b15ed4c9b979bcea63" +content-hash = "dd853a4abc0cb93a17319b693279069f9cfd06838d10615217ab00e2c879c789" diff --git a/libs/partners/huggingface/pyproject.toml b/libs/partners/huggingface/pyproject.toml index c3631189bef..918f660d40c 100644 --- a/libs/partners/huggingface/pyproject.toml +++ b/libs/partners/huggingface/pyproject.toml @@ -44,14 +44,12 @@ ruff = "^0.1.5" [tool.poetry.group.typing.dependencies] mypy = "^1" langchain-core = { path = "../../core", develop = true } -langchain-community = { path = "../../community", develop = true } [tool.poetry.group.dev] optional = true [tool.poetry.group.dev.dependencies] langchain-core = { path = "../../core", develop = true } -langchain-community = { path = "../../community", develop = true } ipykernel = "^6.29.2" [tool.poetry.group.test_integration] diff --git a/libs/partners/huggingface/scripts/lint_imports.sh b/libs/partners/huggingface/scripts/lint_imports.sh index ff69b095f2f..90d21f7295a 100755 --- a/libs/partners/huggingface/scripts/lint_imports.sh +++ b/libs/partners/huggingface/scripts/lint_imports.sh @@ -8,7 +8,7 @@ errors=0 # make sure not importing from langchain or langchain_experimental git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) -# git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1)) +git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1)) # Decide on an exit status based on the errors if [ "$errors" -gt 0 ]; then