Community: Fuse HuggingFace Endpoint-related classes into one (#17254)

## Description
Fuse the HuggingFace Endpoint-related classes into one:

- [HuggingFaceHub](5ceaf784f3/libs/community/langchain_community/llms/huggingface_hub.py)
- [HuggingFaceTextGenInference](5ceaf784f3/libs/community/langchain_community/llms/huggingface_text_gen_inference.py)
- [HuggingFaceEndpoint](5ceaf784f3/libs/community/langchain_community/llms/huggingface_endpoint.py)

are all fused into a single class, `HuggingFaceEndpoint`.
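After the fusion, both hosted hub models and dedicated inference endpoints go through the one class. A minimal usage sketch of the fused API (the model id and endpoint URL below are placeholders, not part of this PR):

```python
from langchain_community.llms import HuggingFaceEndpoint

# Hosted-model usage (formerly HuggingFaceHub): pass a repo_id.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
    max_new_tokens=512,
    temperature=0.01,
)

# Dedicated-endpoint usage (formerly HuggingFaceTextGenInference and the
# old HuggingFaceEndpoint): pass an endpoint_url instead of a repo_id.
llm = HuggingFaceEndpoint(
    endpoint_url="https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder URL
    max_new_tokens=512,
)

print(llm.invoke("What is Deep Learning?"))
```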

## Issue
The duplication of these classes created a lack of clarity, and the extra effort of maintaining three parallel classes leads to issues like [this
hack](5ceaf784f3/libs/community/langchain_community/llms/huggingface_endpoint.py (L159)).

## Dependencies

None; this change removes dependencies.

## Twitter handle

If you want to post about this: @AymericRoucher

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit 0d294760e7 (parent 8009be862e), authored by Aymeric Roucher on 2024-02-19 19:33:15 +01:00, committed by GitHub.
10 changed files with 632 additions and 745 deletions.

The changes to the `ChatHuggingFace` wrapper:

```diff
@@ -1,4 +1,5 @@
 """Hugging Face Chat Wrapper."""
 from typing import Any, List, Optional, Union

 from langchain_core.callbacks.manager import (
@@ -52,6 +53,7 @@ class ChatHuggingFace(BaseChatModel):
         from transformers import AutoTokenizer

         self._resolve_model_id()
         self.tokenizer = (
             AutoTokenizer.from_pretrained(self.model_id)
             if self.tokenizer is None
@@ -90,10 +92,10 @@ class ChatHuggingFace(BaseChatModel):
     ) -> str:
         """Convert a list of messages into a prompt format expected by wrapped LLM."""
         if not messages:
-            raise ValueError("at least one HumanMessage must be provided")
+            raise ValueError("At least one HumanMessage must be provided!")
         if not isinstance(messages[-1], HumanMessage):
-            raise ValueError("last message must be a HumanMessage")
+            raise ValueError("Last message must be a HumanMessage!")

         messages_dicts = [self._to_chatml_format(m) for m in messages]
@@ -135,20 +137,15 @@ class ChatHuggingFace(BaseChatModel):
         from huggingface_hub import list_inference_endpoints

         available_endpoints = list_inference_endpoints("*")
-        if isinstance(self.llm, HuggingFaceTextGenInference):
-            endpoint_url = self.llm.inference_server_url
-        elif isinstance(self.llm, HuggingFaceEndpoint):
-            endpoint_url = self.llm.endpoint_url
-        elif isinstance(self.llm, HuggingFaceHub):
-            # no need to look up model_id for HuggingFaceHub LLM
+        if isinstance(self.llm, HuggingFaceHub) or (
+            hasattr(self.llm, "repo_id") and self.llm.repo_id
+        ):
             self.model_id = self.llm.repo_id
             return
+        elif isinstance(self.llm, HuggingFaceTextGenInference):
+            endpoint_url: Optional[str] = self.llm.inference_server_url
         else:
-            raise ValueError(f"Unknown LLM type: {type(self.llm)}")
+            endpoint_url = self.llm.endpoint_url

         for endpoint in available_endpoints:
             if endpoint.url == endpoint_url:
@@ -156,8 +153,8 @@ class ChatHuggingFace(BaseChatModel):
         if not self.model_id:
             raise ValueError(
-                "Failed to resolve model_id"
-                f"Could not find model id for inference server provided: {endpoint_url}"
+                "Failed to resolve model_id:"
+                f"Could not find model id for inference server: {endpoint_url}"
                 "Make sure that your Hugging Face token has access to the endpoint."
             )
```
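With the fused class, `_resolve_model_id` no longer branches over three wrapper types; anything exposing `repo_id` or `endpoint_url` works. A hedged sketch of how the wrapper is used after this PR (model id is a placeholder):

```python
from langchain_community.chat_models import ChatHuggingFace
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.messages import HumanMessage

# One class now covers both hub models and dedicated endpoints.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
    max_new_tokens=256,
)

# ChatHuggingFace reads repo_id/endpoint_url off the wrapped LLM (see
# _resolve_model_id in the diff above) and loads the matching tokenizer.
chat = ChatHuggingFace(llm=llm)
print(chat.invoke([HumanMessage(content="What is Deep Learning?")]))
```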