Community: sambastudio embeddings GenericV2 API support (#25064)

- **Description:** 
        Add SambaStudio GenericV2 API support.
        Minor changes to request error handling (raise on non-200 responses).
This commit is contained in:
Jorge Piedrahita Ortiz
2024-08-29 08:52:49 -05:00
committed by GitHub
parent bdce9a47d0
commit 9ac953a948

View File

@@ -98,10 +98,13 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
Returns:
The tuning parameters as a JSON string.
"""
tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (self.model_kwargs.items())
}
if "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
tuning_params_dict = self.model_kwargs
else:
tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (self.model_kwargs.items())
}
tuning_params = json.dumps(tuning_params_dict)
return tuning_params
@@ -148,7 +151,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
params = json.loads(self._get_tuning_params())
embeddings = []
if "nlp" in self.sambastudio_embeddings_base_uri:
if "api/predict/nlp" in self.sambastudio_embeddings_base_uri:
for batch in self._iterate_over_batches(texts, batch_size):
data = {"inputs": batch, "params": params}
response = http_session.post(
@@ -156,6 +159,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}.\n Details: {response.text}"
)
try:
embedding = response.json()["data"]
embeddings.extend(embedding)
@@ -165,7 +173,32 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
response.json(),
)
elif "generic" in self.sambastudio_embeddings_base_uri:
elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
for batch in self._iterate_over_batches(texts, batch_size):
items = [
{"id": f"item{i}", "value": item} for i, item in enumerate(batch)
]
data = {"items": items, "params": params}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}.\n Details: {response.text}"
)
try:
embedding = [item["value"] for item in response.json()["items"]]
embeddings.extend(embedding)
except KeyError:
raise KeyError(
"'items' not found in endpoint response",
response.json(),
)
elif "api/predict/generic" in self.sambastudio_embeddings_base_uri:
for batch in self._iterate_over_batches(texts, batch_size):
data = {"instances": batch, "params": params}
response = http_session.post(
@@ -173,6 +206,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}.\n Details: {response.text}"
)
try:
if params.get("select_expert"):
embedding = response.json()["predictions"][0]
@@ -207,13 +245,18 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
)
params = json.loads(self._get_tuning_params())
if "nlp" in self.sambastudio_embeddings_base_uri:
if "api/predict/nlp" in self.sambastudio_embeddings_base_uri:
data = {"inputs": [text], "params": params}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}.\n Details: {response.text}"
)
try:
embedding = response.json()["data"][0]
except KeyError:
@@ -222,13 +265,38 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
response.json(),
)
elif "generic" in self.sambastudio_embeddings_base_uri:
elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
data = {"items": [{"id": "item0", "value": text}], "params": params}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}.\n Details: {response.text}"
)
try:
embedding = response.json()["items"][0]["value"]
except KeyError:
raise KeyError(
"'items' not found in endpoint response",
response.json(),
)
elif "api/predict/generic" in self.sambastudio_embeddings_base_uri:
data = {"instances": [text], "params": params}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
if response.status_code != 200:
raise RuntimeError(
f"Sambanova /complete call failed with status code "
f"{response.status_code}.\n Details: {response.text}"
)
try:
if params.get("select_expert"):
embedding = response.json()["predictions"][0][0]