mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 18:38:48 +00:00
Community: sambastudio embeddings GenericV2 API support (#25064)
- **Description:** Adds SambaStudio GenericV2 API support, plus minor changes to request error handling.
This commit is contained in:
committed by
GitHub
parent
bdce9a47d0
commit
9ac953a948
@@ -98,10 +98,13 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
The tuning parameters as a JSON string.
|
||||
"""
|
||||
tuning_params_dict = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (self.model_kwargs.items())
|
||||
}
|
||||
if "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||
tuning_params_dict = self.model_kwargs
|
||||
else:
|
||||
tuning_params_dict = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (self.model_kwargs.items())
|
||||
}
|
||||
tuning_params = json.dumps(tuning_params_dict)
|
||||
return tuning_params
|
||||
|
||||
@@ -148,7 +151,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
params = json.loads(self._get_tuning_params())
|
||||
embeddings = []
|
||||
|
||||
if "nlp" in self.sambastudio_embeddings_base_uri:
|
||||
if "api/predict/nlp" in self.sambastudio_embeddings_base_uri:
|
||||
for batch in self._iterate_over_batches(texts, batch_size):
|
||||
data = {"inputs": batch, "params": params}
|
||||
response = http_session.post(
|
||||
@@ -156,6 +159,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
headers={"key": self.sambastudio_embeddings_api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}.\n Details: {response.text}"
|
||||
)
|
||||
try:
|
||||
embedding = response.json()["data"]
|
||||
embeddings.extend(embedding)
|
||||
@@ -165,7 +173,32 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
response.json(),
|
||||
)
|
||||
|
||||
elif "generic" in self.sambastudio_embeddings_base_uri:
|
||||
elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||
for batch in self._iterate_over_batches(texts, batch_size):
|
||||
items = [
|
||||
{"id": f"item{i}", "value": item} for i, item in enumerate(batch)
|
||||
]
|
||||
data = {"items": items, "params": params}
|
||||
response = http_session.post(
|
||||
url,
|
||||
headers={"key": self.sambastudio_embeddings_api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}.\n Details: {response.text}"
|
||||
)
|
||||
try:
|
||||
embedding = [item["value"] for item in response.json()["items"]]
|
||||
embeddings.extend(embedding)
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
"'items' not found in endpoint response",
|
||||
response.json(),
|
||||
)
|
||||
|
||||
elif "api/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||
for batch in self._iterate_over_batches(texts, batch_size):
|
||||
data = {"instances": batch, "params": params}
|
||||
response = http_session.post(
|
||||
@@ -173,6 +206,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
headers={"key": self.sambastudio_embeddings_api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}.\n Details: {response.text}"
|
||||
)
|
||||
try:
|
||||
if params.get("select_expert"):
|
||||
embedding = response.json()["predictions"][0]
|
||||
@@ -207,13 +245,18 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
params = json.loads(self._get_tuning_params())
|
||||
|
||||
if "nlp" in self.sambastudio_embeddings_base_uri:
|
||||
if "api/predict/nlp" in self.sambastudio_embeddings_base_uri:
|
||||
data = {"inputs": [text], "params": params}
|
||||
response = http_session.post(
|
||||
url,
|
||||
headers={"key": self.sambastudio_embeddings_api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}.\n Details: {response.text}"
|
||||
)
|
||||
try:
|
||||
embedding = response.json()["data"][0]
|
||||
except KeyError:
|
||||
@@ -222,13 +265,38 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
response.json(),
|
||||
)
|
||||
|
||||
elif "generic" in self.sambastudio_embeddings_base_uri:
|
||||
elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||
data = {"items": [{"id": "item0", "value": text}], "params": params}
|
||||
response = http_session.post(
|
||||
url,
|
||||
headers={"key": self.sambastudio_embeddings_api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}.\n Details: {response.text}"
|
||||
)
|
||||
try:
|
||||
embedding = response.json()["items"][0]["value"]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
"'items' not found in endpoint response",
|
||||
response.json(),
|
||||
)
|
||||
|
||||
elif "api/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||
data = {"instances": [text], "params": params}
|
||||
response = http_session.post(
|
||||
url,
|
||||
headers={"key": self.sambastudio_embeddings_api_key},
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response.status_code}.\n Details: {response.text}"
|
||||
)
|
||||
try:
|
||||
if params.get("select_expert"):
|
||||
embedding = response.json()["predictions"][0][0]
|
||||
|
Reference in New Issue
Block a user