mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 22:53:30 +00:00
fix(embeddings): huggingface hub embeddings and TEI (#14489)
**Description:** This PR fixes `HuggingFaceHubEmbeddings` by making the API token optional, matching the underlying `InferenceClient`; most models don't require one. I also updated the notebook for TEI (text-embeddings-inference) accordingly, as requested in #14288. In addition, I fixed a mistake in the parameters of the POST call: `task` is now passed as a keyword argument instead of inside the JSON payload. **Tag maintainers:** @baskaryan
This commit is contained in:
parent
5da79e150b
commit
6080c98108
@ -14,7 +14,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 1,
|
||||||
"id": "579f0677-aa06-4ad8-a816-3520c8d6923c",
|
"id": "579f0677-aa06-4ad8-a816-3520c8d6923c",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -50,7 +50,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"id": "22b09777-5ba3-4fbe-81cf-a702a55df9c4",
|
"id": "22b09777-5ba3-4fbe-81cf-a702a55df9c4",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -62,45 +62,19 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 3,
|
||||||
"id": "c26fca9f-cfdb-45e5-a0bd-f677ff8b9d92",
|
|
||||||
"metadata": {
|
|
||||||
"tags": []
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdin",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Enter your HF API Key:\n",
|
|
||||||
"\n",
|
|
||||||
" ········\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from getpass import getpass\n",
|
|
||||||
"\n",
|
|
||||||
"huggingfacehub_api_token = getpass(\"Enter your HF API Key:\\n\\n\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "f9a92970-16f4-458c-b186-2a83e9f7d840",
|
"id": "f9a92970-16f4-458c-b186-2a83e9f7d840",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"embeddings = HuggingFaceHubEmbeddings(\n",
|
"embeddings = HuggingFaceHubEmbeddings(model=\"http://localhost:8080\")"
|
||||||
" model=\"http://localhost:8080\", huggingfacehub_api_token=huggingfacehub_api_token\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 4,
|
||||||
"id": "42105438-9fee-460a-9c52-b7c595722758",
|
"id": "42105438-9fee-460a-9c52-b7c595722758",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -112,7 +86,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 5,
|
||||||
"id": "20167762-0988-4205-bbd4-1f20fd9dd247",
|
"id": "20167762-0988-4205-bbd4-1f20fd9dd247",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -124,7 +98,7 @@
|
|||||||
"[0.018113142, 0.00302585, -0.049911194]"
|
"[0.018113142, 0.00302585, -0.049911194]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -136,7 +110,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 6,
|
||||||
"id": "54b87cf6-86ad-46f5-b2cd-17eb43cb4d0b",
|
"id": "54b87cf6-86ad-46f5-b2cd-17eb43cb4d0b",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -145,6 +119,29 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"doc_result = embeddings.embed_documents([text])"
|
"doc_result = embeddings.embed_documents([text])"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "6fba8be9-fabf-4972-8334-aa56ed9893e1",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[0.018113142, 0.00302585, -0.049911194]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"doc_result[0][:3]"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
|
||||||
|
|
||||||
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||||
VALID_TASKS = ("feature-extraction",)
|
VALID_TASKS = ("feature-extraction",)
|
||||||
@ -48,9 +48,10 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
huggingfacehub_api_token = get_from_dict_or_env(
|
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
|
||||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
"HUGGINGFACEHUB_API_TOKEN"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
|
||||||
@ -92,7 +93,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
|||||||
texts = [text.replace("\n", " ") for text in texts]
|
texts = [text.replace("\n", " ") for text in texts]
|
||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
responses = self.client.post(
|
responses = self.client.post(
|
||||||
json={"inputs": texts, "parameters": _model_kwargs, "task": self.task}
|
json={"inputs": texts, "parameters": _model_kwargs}, task=self.task
|
||||||
)
|
)
|
||||||
return json.loads(responses.decode())
|
return json.loads(responses.decode())
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user