mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
Community: SambaStudio embeddings GenericV2 API support (#25064)
- **Description:** Adds SambaStudio GenericV2 API support, plus minor changes to request error handling.
This commit is contained in:
committed by
GitHub
parent
bdce9a47d0
commit
9ac953a948
@@ -98,10 +98,13 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
The tuning parameters as a JSON string.
|
The tuning parameters as a JSON string.
|
||||||
"""
|
"""
|
||||||
tuning_params_dict = {
|
if "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||||
k: {"type": type(v).__name__, "value": str(v)}
|
tuning_params_dict = self.model_kwargs
|
||||||
for k, v in (self.model_kwargs.items())
|
else:
|
||||||
}
|
tuning_params_dict = {
|
||||||
|
k: {"type": type(v).__name__, "value": str(v)}
|
||||||
|
for k, v in (self.model_kwargs.items())
|
||||||
|
}
|
||||||
tuning_params = json.dumps(tuning_params_dict)
|
tuning_params = json.dumps(tuning_params_dict)
|
||||||
return tuning_params
|
return tuning_params
|
||||||
|
|
||||||
@@ -148,7 +151,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
params = json.loads(self._get_tuning_params())
|
params = json.loads(self._get_tuning_params())
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
if "nlp" in self.sambastudio_embeddings_base_uri:
|
if "api/predict/nlp" in self.sambastudio_embeddings_base_uri:
|
||||||
for batch in self._iterate_over_batches(texts, batch_size):
|
for batch in self._iterate_over_batches(texts, batch_size):
|
||||||
data = {"inputs": batch, "params": params}
|
data = {"inputs": batch, "params": params}
|
||||||
response = http_session.post(
|
response = http_session.post(
|
||||||
@@ -156,6 +159,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
headers={"key": self.sambastudio_embeddings_api_key},
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sambanova /complete call failed with status code "
|
||||||
|
f"{response.status_code}.\n Details: {response.text}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
embedding = response.json()["data"]
|
embedding = response.json()["data"]
|
||||||
embeddings.extend(embedding)
|
embeddings.extend(embedding)
|
||||||
@@ -165,7 +173,32 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
response.json(),
|
response.json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif "generic" in self.sambastudio_embeddings_base_uri:
|
elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||||
|
for batch in self._iterate_over_batches(texts, batch_size):
|
||||||
|
items = [
|
||||||
|
{"id": f"item{i}", "value": item} for i, item in enumerate(batch)
|
||||||
|
]
|
||||||
|
data = {"items": items, "params": params}
|
||||||
|
response = http_session.post(
|
||||||
|
url,
|
||||||
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sambanova /complete call failed with status code "
|
||||||
|
f"{response.status_code}.\n Details: {response.text}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
embedding = [item["value"] for item in response.json()["items"]]
|
||||||
|
embeddings.extend(embedding)
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"'items' not found in endpoint response",
|
||||||
|
response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "api/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||||
for batch in self._iterate_over_batches(texts, batch_size):
|
for batch in self._iterate_over_batches(texts, batch_size):
|
||||||
data = {"instances": batch, "params": params}
|
data = {"instances": batch, "params": params}
|
||||||
response = http_session.post(
|
response = http_session.post(
|
||||||
@@ -173,6 +206,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
headers={"key": self.sambastudio_embeddings_api_key},
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sambanova /complete call failed with status code "
|
||||||
|
f"{response.status_code}.\n Details: {response.text}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if params.get("select_expert"):
|
if params.get("select_expert"):
|
||||||
embedding = response.json()["predictions"][0]
|
embedding = response.json()["predictions"][0]
|
||||||
@@ -207,13 +245,18 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
params = json.loads(self._get_tuning_params())
|
params = json.loads(self._get_tuning_params())
|
||||||
|
|
||||||
if "nlp" in self.sambastudio_embeddings_base_uri:
|
if "api/predict/nlp" in self.sambastudio_embeddings_base_uri:
|
||||||
data = {"inputs": [text], "params": params}
|
data = {"inputs": [text], "params": params}
|
||||||
response = http_session.post(
|
response = http_session.post(
|
||||||
url,
|
url,
|
||||||
headers={"key": self.sambastudio_embeddings_api_key},
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sambanova /complete call failed with status code "
|
||||||
|
f"{response.status_code}.\n Details: {response.text}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
embedding = response.json()["data"][0]
|
embedding = response.json()["data"][0]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -222,13 +265,38 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
|||||||
response.json(),
|
response.json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif "generic" in self.sambastudio_embeddings_base_uri:
|
elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||||
|
data = {"items": [{"id": "item0", "value": text}], "params": params}
|
||||||
|
response = http_session.post(
|
||||||
|
url,
|
||||||
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sambanova /complete call failed with status code "
|
||||||
|
f"{response.status_code}.\n Details: {response.text}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
embedding = response.json()["items"][0]["value"]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"'items' not found in endpoint response",
|
||||||
|
response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "api/predict/generic" in self.sambastudio_embeddings_base_uri:
|
||||||
data = {"instances": [text], "params": params}
|
data = {"instances": [text], "params": params}
|
||||||
response = http_session.post(
|
response = http_session.post(
|
||||||
url,
|
url,
|
||||||
headers={"key": self.sambastudio_embeddings_api_key},
|
headers={"key": self.sambastudio_embeddings_api_key},
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sambanova /complete call failed with status code "
|
||||||
|
f"{response.status_code}.\n Details: {response.text}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if params.get("select_expert"):
|
if params.get("select_expert"):
|
||||||
embedding = response.json()["predictions"][0][0]
|
embedding = response.json()["predictions"][0][0]
|
||||||
|
Reference in New Issue
Block a user