mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 19:15:44 +00:00
community[patch]: update sambastudio embeddings (#23133)
Description: update sambastudio embeddings integration, now compatible with generic endpoints and CoE endpoints
This commit is contained in:
parent
db6f46c1a6
commit
e162893d7f
@ -43,12 +43,14 @@
|
|||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
|
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
|
||||||
|
"sambastudio_base_uri = \"<Your SambaStudio environment URI>\"\n",
|
||||||
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
|
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
|
||||||
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
|
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
|
||||||
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
|
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Set the environment variables\n",
|
"# Set the environment variables\n",
|
||||||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n",
|
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n",
|
||||||
|
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URI\"] = sambastudio_base_uri\n",
|
||||||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n",
|
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n",
|
||||||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
|
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
|
||||||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key"
|
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key"
|
||||||
@ -79,6 +81,50 @@
|
|||||||
"results = embeddings.embed_documents(texts)\n",
|
"results = embeddings.embed_documents(texts)\n",
|
||||||
"print(results)"
|
"print(results)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You can manually pass the endpoint parameters and manually set the batch size you have in your SambaStudio embeddings endpoint"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = SambaStudioEmbeddings(\n",
|
||||||
|
" sambastudio_embeddings_base_url=sambastudio_base_url,\n",
|
||||||
|
" sambastudio_embeddings_base_uri=sambastudio_base_uri,\n",
|
||||||
|
" sambastudio_embeddings_project_id=sambastudio_project_id,\n",
|
||||||
|
" sambastudio_embeddings_endpoint_id=sambastudio_endpoint_id,\n",
|
||||||
|
" sambastudio_embeddings_api_key=sambastudio_api_key,\n",
|
||||||
|
" batch_size=32,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Or You can use an embedding model expert included in your deployed CoE"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = SambaStudioEmbeddings(\n",
|
||||||
|
" batch_size=1,\n",
|
||||||
|
" model_kwargs={\n",
|
||||||
|
" \"select_expert\": \"e5-mistral-7b-instruct\",\n",
|
||||||
|
" },\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Dict, Generator, List
|
import json
|
||||||
|
from typing import Dict, Generator, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
@ -10,8 +11,9 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
"""SambaNova embedding models.
|
"""SambaNova embedding models.
|
||||||
|
|
||||||
To use, you should have the environment variables
|
To use, you should have the environment variables
|
||||||
``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``,
|
``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``
|
||||||
``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``,
|
``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
|
||||||
|
``SAMBASTUDIO_EMBEDDINGS_API_KEY``
|
||||||
set with your personal sambastudio variable or pass it as a named parameter
|
set with your personal sambastudio variable or pass it as a named parameter
|
||||||
to the constructor.
|
to the constructor.
|
||||||
|
|
||||||
@ -19,20 +21,34 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_community.embeddings import SambaStudioEmbeddings
|
from langchain_community.embeddings import SambaStudioEmbeddings
|
||||||
|
|
||||||
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
|
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
|
||||||
|
sambastudio_embeddings_base_uri=base_uri,
|
||||||
sambastudio_embeddings_project_id=project_id,
|
sambastudio_embeddings_project_id=project_id,
|
||||||
sambastudio_embeddings_endpoint_id=endpoint_id,
|
sambastudio_embeddings_endpoint_id=endpoint_id,
|
||||||
sambastudio_embeddings_api_key=api_key)
|
sambastudio_embeddings_api_key=api_key,
|
||||||
(or)
|
batch_size=32)
|
||||||
embeddings = SambaStudioEmbeddings()
|
(or)
|
||||||
"""
|
|
||||||
|
|
||||||
API_BASE_PATH = "/api/predict/nlp/"
|
embeddings = SambaStudioEmbeddings(batch_size=32)
|
||||||
"""Base path to use for the API usage"""
|
|
||||||
|
(or)
|
||||||
|
|
||||||
|
# CoE example
|
||||||
|
embeddings = SambaStudioEmbeddings(
|
||||||
|
batch_size=1,
|
||||||
|
model_kwargs={
|
||||||
|
'select_expert':'e5-mistral-7b-instruct'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
sambastudio_embeddings_base_url: str = ""
|
sambastudio_embeddings_base_url: str = ""
|
||||||
"""Base url to use"""
|
"""Base url to use"""
|
||||||
|
|
||||||
|
sambastudio_embeddings_base_uri: str = ""
|
||||||
|
"""endpoint base uri"""
|
||||||
|
|
||||||
sambastudio_embeddings_project_id: str = ""
|
sambastudio_embeddings_project_id: str = ""
|
||||||
"""Project id on sambastudio for model"""
|
"""Project id on sambastudio for model"""
|
||||||
|
|
||||||
@ -42,12 +58,24 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
sambastudio_embeddings_api_key: str = ""
|
sambastudio_embeddings_api_key: str = ""
|
||||||
"""sambastudio api key"""
|
"""sambastudio api key"""
|
||||||
|
|
||||||
|
model_kwargs: dict = {}
|
||||||
|
"""Key word arguments to pass to the model."""
|
||||||
|
|
||||||
|
batch_size: int = 32
|
||||||
|
"""Batch size for the embedding models"""
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
|
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
|
||||||
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
|
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
|
||||||
)
|
)
|
||||||
|
values["sambastudio_embeddings_base_uri"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"sambastudio_embeddings_base_uri",
|
||||||
|
"SAMBASTUDIO_EMBEDDINGS_BASE_URI",
|
||||||
|
default="api/predict/generic",
|
||||||
|
)
|
||||||
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
|
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"sambastudio_embeddings_project_id",
|
"sambastudio_embeddings_project_id",
|
||||||
@ -63,6 +91,20 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
def _get_tuning_params(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the tuning parameters to use when calling the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The tuning parameters as a JSON string.
|
||||||
|
"""
|
||||||
|
tuning_params_dict = {
|
||||||
|
k: {"type": type(v).__name__, "value": str(v)}
|
||||||
|
for k, v in (self.model_kwargs.items())
|
||||||
|
}
|
||||||
|
tuning_params = json.dumps(tuning_params_dict)
|
||||||
|
return tuning_params
|
||||||
|
|
||||||
def _get_full_url(self, path: str) -> str:
|
def _get_full_url(self, path: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return the full API URL for a given path.
|
Return the full API URL for a given path.
|
||||||
@ -71,7 +113,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
:returns: the full API URL for the sub-path
|
:returns: the full API URL for the sub-path
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}"
|
return f"{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}" # noqa: E501
|
||||||
|
|
||||||
def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
|
def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
|
||||||
"""Generator for creating batches in the embed documents method
|
"""Generator for creating batches in the embed documents method
|
||||||
@ -86,7 +128,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
yield texts[i : i + batch_size]
|
yield texts[i : i + batch_size]
|
||||||
|
|
||||||
def embed_documents(
|
def embed_documents(
|
||||||
self, texts: List[str], batch_size: int = 32
|
self, texts: List[str], batch_size: Optional[int] = None
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
"""Returns a list of embeddings for the given sentences.
|
"""Returns a list of embeddings for the given sentences.
|
||||||
Args:
|
Args:
|
||||||
@ -97,22 +139,56 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
`List[np.ndarray]` or `List[tensor]`: List of embeddings
|
`List[np.ndarray]` or `List[tensor]`: List of embeddings
|
||||||
for the given sentences
|
for the given sentences
|
||||||
"""
|
"""
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = self.batch_size
|
||||||
http_session = requests.Session()
|
http_session = requests.Session()
|
||||||
url = self._get_full_url(
|
url = self._get_full_url(
|
||||||
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
|
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
|
||||||
)
|
)
|
||||||
|
params = json.loads(self._get_tuning_params())
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
for batch in self._iterate_over_batches(texts, batch_size):
|
if "nlp" in self.sambastudio_embeddings_base_uri:
|
||||||
data = {"inputs": batch}
|
for batch in self._iterate_over_batches(texts, batch_size):
|
||||||
response = http_session.post(
|
data = {"inputs": batch, "params": params}
|
||||||
url,
|
response = http_session.post(
|
||||||
headers={"key": self.sambastudio_embeddings_api_key},
|
url,
|
||||||
json=data,
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
embedding = response.json()["data"]
|
||||||
|
embeddings.extend(embedding)
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"'data' not found in endpoint response",
|
||||||
|
response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "generic" in self.sambastudio_embeddings_base_uri:
|
||||||
|
for batch in self._iterate_over_batches(texts, batch_size):
|
||||||
|
data = {"instances": batch, "params": params}
|
||||||
|
response = http_session.post(
|
||||||
|
url,
|
||||||
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if params.get("select_expert"):
|
||||||
|
embedding = response.json()["predictions"][0]
|
||||||
|
else:
|
||||||
|
embedding = response.json()["predictions"]
|
||||||
|
embeddings.extend(embedding)
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"'predictions' not found in endpoint response",
|
||||||
|
response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented" # noqa: E501
|
||||||
)
|
)
|
||||||
embedding = response.json()["data"]
|
|
||||||
embeddings.extend(embedding)
|
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -129,14 +205,44 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
url = self._get_full_url(
|
url = self._get_full_url(
|
||||||
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
|
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
|
||||||
)
|
)
|
||||||
|
params = json.loads(self._get_tuning_params())
|
||||||
|
|
||||||
data = {"inputs": [text]}
|
if "nlp" in self.sambastudio_embeddings_base_uri:
|
||||||
|
data = {"inputs": [text], "params": params}
|
||||||
|
response = http_session.post(
|
||||||
|
url,
|
||||||
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
embedding = response.json()["data"][0]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"'data' not found in endpoint response",
|
||||||
|
response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
response = http_session.post(
|
elif "generic" in self.sambastudio_embeddings_base_uri:
|
||||||
url,
|
data = {"instances": [text], "params": params}
|
||||||
headers={"key": self.sambastudio_embeddings_api_key},
|
response = http_session.post(
|
||||||
json=data,
|
url,
|
||||||
)
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
embedding = response.json()["data"][0]
|
json=data,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if params.get("select_expert"):
|
||||||
|
embedding = response.json()["predictions"][0][0]
|
||||||
|
else:
|
||||||
|
embedding = response.json()["predictions"][0]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"'predictions' not found in endpoint response",
|
||||||
|
response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented" # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
Loading…
Reference in New Issue
Block a user