community[patch]: update sambastudio embeddings (#23133)

Description: Update the SambaStudio embeddings integration so that it is
compatible with both generic endpoints and CoE endpoints.
This commit is contained in:
Jorge Piedrahita Ortiz 2024-06-19 12:26:56 -05:00 committed by GitHub
parent db6f46c1a6
commit e162893d7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 179 additions and 27 deletions

View File

@ -43,12 +43,14 @@
"import os\n", "import os\n",
"\n", "\n",
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n", "sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
"sambastudio_base_uri = \"<Your SambaStudio environment URI>\"\n",
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n", "sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n", "sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n", "sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
"\n", "\n",
"# Set the environment variables\n", "# Set the environment variables\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URI\"] = sambastudio_base_uri\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key" "os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key"
@ -79,6 +81,50 @@
"results = embeddings.embed_documents(texts)\n", "results = embeddings.embed_documents(texts)\n",
"print(results)" "print(results)"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can manually pass the endpoint parameters and set the batch size configured in your SambaStudio embeddings endpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embeddings = SambaStudioEmbeddings(\n",
" sambastudio_embeddings_base_url=sambastudio_base_url,\n",
" sambastudio_embeddings_base_uri=sambastudio_base_uri,\n",
" sambastudio_embeddings_project_id=sambastudio_project_id,\n",
" sambastudio_embeddings_endpoint_id=sambastudio_endpoint_id,\n",
" sambastudio_embeddings_api_key=sambastudio_api_key,\n",
" batch_size=32,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or you can use an embedding model expert included in your deployed CoE."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embeddings = SambaStudioEmbeddings(\n",
" batch_size=1,\n",
" model_kwargs={\n",
" \"select_expert\": \"e5-mistral-7b-instruct\",\n",
" },\n",
")"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -1,4 +1,5 @@
from typing import Dict, Generator, List import json
from typing import Dict, Generator, List, Optional
import requests import requests
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -10,8 +11,9 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
"""SambaNova embedding models. """SambaNova embedding models.
To use, you should have the environment variables To use, you should have the environment variables
``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``
``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
``SAMBASTUDIO_EMBEDDINGS_API_KEY``
set with your personal sambastudio variable or pass it as a named parameter set with your personal sambastudio variable or pass it as a named parameter
to the constructor. to the constructor.
@ -19,20 +21,34 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
.. code-block:: python .. code-block:: python
from langchain_community.embeddings import SambaStudioEmbeddings from langchain_community.embeddings import SambaStudioEmbeddings
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url, embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
sambastudio_embeddings_base_uri=base_uri,
sambastudio_embeddings_project_id=project_id, sambastudio_embeddings_project_id=project_id,
sambastudio_embeddings_endpoint_id=endpoint_id, sambastudio_embeddings_endpoint_id=endpoint_id,
sambastudio_embeddings_api_key=api_key) sambastudio_embeddings_api_key=api_key,
(or) batch_size=32)
embeddings = SambaStudioEmbeddings() (or)
"""
API_BASE_PATH = "/api/predict/nlp/" embeddings = SambaStudioEmbeddings(batch_size=32)
"""Base path to use for the API usage"""
(or)
# CoE example
embeddings = SambaStudioEmbeddings(
batch_size=1,
model_kwargs={
'select_expert':'e5-mistral-7b-instruct'
}
)
"""
sambastudio_embeddings_base_url: str = "" sambastudio_embeddings_base_url: str = ""
"""Base url to use""" """Base url to use"""
sambastudio_embeddings_base_uri: str = ""
"""endpoint base uri"""
sambastudio_embeddings_project_id: str = "" sambastudio_embeddings_project_id: str = ""
"""Project id on sambastudio for model""" """Project id on sambastudio for model"""
@ -42,12 +58,24 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
sambastudio_embeddings_api_key: str = "" sambastudio_embeddings_api_key: str = ""
"""sambastudio api key""" """sambastudio api key"""
model_kwargs: dict = {}
"""Key word arguments to pass to the model."""
batch_size: int = 32
"""Batch size for the embedding models"""
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env( values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL" values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
) )
values["sambastudio_embeddings_base_uri"] = get_from_dict_or_env(
values,
"sambastudio_embeddings_base_uri",
"SAMBASTUDIO_EMBEDDINGS_BASE_URI",
default="api/predict/generic",
)
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env( values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
values, values,
"sambastudio_embeddings_project_id", "sambastudio_embeddings_project_id",
@ -63,6 +91,20 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
) )
return values return values
def _get_tuning_params(self) -> str:
"""
Get the tuning parameters to use when calling the model
Returns:
The tuning parameters as a JSON string.
"""
tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (self.model_kwargs.items())
}
tuning_params = json.dumps(tuning_params_dict)
return tuning_params
def _get_full_url(self, path: str) -> str: def _get_full_url(self, path: str) -> str:
""" """
Return the full API URL for a given path. Return the full API URL for a given path.
@ -71,7 +113,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
:returns: the full API URL for the sub-path :returns: the full API URL for the sub-path
:rtype: str :rtype: str
""" """
return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}" return f"{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}" # noqa: E501
def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator: def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
"""Generator for creating batches in the embed documents method """Generator for creating batches in the embed documents method
@ -86,7 +128,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
yield texts[i : i + batch_size] yield texts[i : i + batch_size]
def embed_documents( def embed_documents(
self, texts: List[str], batch_size: int = 32 self, texts: List[str], batch_size: Optional[int] = None
) -> List[List[float]]: ) -> List[List[float]]:
"""Returns a list of embeddings for the given sentences. """Returns a list of embeddings for the given sentences.
Args: Args:
@ -97,22 +139,56 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
`List[np.ndarray]` or `List[tensor]`: List of embeddings `List[np.ndarray]` or `List[tensor]`: List of embeddings
for the given sentences for the given sentences
""" """
if batch_size is None:
batch_size = self.batch_size
http_session = requests.Session() http_session = requests.Session()
url = self._get_full_url( url = self._get_full_url(
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
) )
params = json.loads(self._get_tuning_params())
embeddings = [] embeddings = []
for batch in self._iterate_over_batches(texts, batch_size): if "nlp" in self.sambastudio_embeddings_base_uri:
data = {"inputs": batch} for batch in self._iterate_over_batches(texts, batch_size):
response = http_session.post( data = {"inputs": batch, "params": params}
url, response = http_session.post(
headers={"key": self.sambastudio_embeddings_api_key}, url,
json=data, headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
try:
embedding = response.json()["data"]
embeddings.extend(embedding)
except KeyError:
raise KeyError(
"'data' not found in endpoint response",
response.json(),
)
elif "generic" in self.sambastudio_embeddings_base_uri:
for batch in self._iterate_over_batches(texts, batch_size):
data = {"instances": batch, "params": params}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
try:
if params.get("select_expert"):
embedding = response.json()["predictions"][0]
else:
embedding = response.json()["predictions"]
embeddings.extend(embedding)
except KeyError:
raise KeyError(
"'predictions' not found in endpoint response",
response.json(),
)
else:
raise ValueError(
f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented" # noqa: E501
) )
embedding = response.json()["data"]
embeddings.extend(embedding)
return embeddings return embeddings
@ -129,14 +205,44 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
url = self._get_full_url( url = self._get_full_url(
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
) )
params = json.loads(self._get_tuning_params())
data = {"inputs": [text]} if "nlp" in self.sambastudio_embeddings_base_uri:
data = {"inputs": [text], "params": params}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
try:
embedding = response.json()["data"][0]
except KeyError:
raise KeyError(
"'data' not found in endpoint response",
response.json(),
)
response = http_session.post( elif "generic" in self.sambastudio_embeddings_base_uri:
url, data = {"instances": [text], "params": params}
headers={"key": self.sambastudio_embeddings_api_key}, response = http_session.post(
json=data, url,
) headers={"key": self.sambastudio_embeddings_api_key},
embedding = response.json()["data"][0] json=data,
)
try:
if params.get("select_expert"):
embedding = response.json()["predictions"][0][0]
else:
embedding = response.json()["predictions"][0]
except KeyError:
raise KeyError(
"'predictions' not found in endpoint response",
response.json(),
)
else:
raise ValueError(
f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented" # noqa: E501
)
return embedding return embedding