mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 18:48:50 +00:00
fix(embeddings): huggingface hub embeddings and TEI (#14489)
**Description:** This PR fixes `HuggingFaceHubEmbeddings` by making the API token optional (as in the client beneath). Most models don't require one. I also updated the notebook for TEI (text-embeddings-inference) accordingly as requested here #14288. In addition, I fixed a mistake in the POST call parameters. **Tag maintainers:** @baskaryan
This commit is contained in:
parent
5da79e150b
commit
6080c98108
@ -14,7 +14,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "579f0677-aa06-4ad8-a816-3520c8d6923c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -50,7 +50,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "22b09777-5ba3-4fbe-81cf-a702a55df9c4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -62,45 +62,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "c26fca9f-cfdb-45e5-a0bd-f677ff8b9d92",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Enter your HF API Key:\n",
|
||||
"\n",
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"huggingfacehub_api_token = getpass(\"Enter your HF API Key:\\n\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 3,
|
||||
"id": "f9a92970-16f4-458c-b186-2a83e9f7d840",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = HuggingFaceHubEmbeddings(\n",
|
||||
" model=\"http://localhost:8080\", huggingfacehub_api_token=huggingfacehub_api_token\n",
|
||||
")"
|
||||
"embeddings = HuggingFaceHubEmbeddings(model=\"http://localhost:8080\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 4,
|
||||
"id": "42105438-9fee-460a-9c52-b7c595722758",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -112,7 +86,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"id": "20167762-0988-4205-bbd4-1f20fd9dd247",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -124,7 +98,7 @@
|
||||
"[0.018113142, 0.00302585, -0.049911194]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -136,7 +110,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 6,
|
||||
"id": "54b87cf6-86ad-46f5-b2cd-17eb43cb4d0b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -145,6 +119,29 @@
|
||||
"source": [
|
||||
"doc_result = embeddings.embed_documents([text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "6fba8be9-fabf-4972-8334-aa56ed9893e1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[0.018113142, 0.00302585, -0.049911194]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"doc_result[0][:3]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||
VALID_TASKS = ("feature-extraction",)
|
||||
@ -48,9 +48,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
|
||||
"HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
@ -92,7 +93,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
responses = self.client.post(
|
||||
json={"inputs": texts, "parameters": _model_kwargs, "task": self.task}
|
||||
json={"inputs": texts, "parameters": _model_kwargs}, task=self.task
|
||||
)
|
||||
return json.loads(responses.decode())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user