community: sambastudio llms api v2 support (#25063)

- **Description:** SambaStudio GenericV2 API support
This commit is contained in:
Jorge Piedrahita Ortiz 2024-09-03 09:18:15 -05:00 committed by GitHub
parent 8d784db107
commit c7154a4045
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -514,7 +514,7 @@ class SSEndpointHandler:
response: requests.Response,
) -> Generator[Dict, None, None]:
"""Process the streaming response"""
if "nlp" in self.api_base_uri:
if "api/predict/nlp" in self.api_base_uri:
try:
import sseclient
except ImportError:
@ -535,14 +535,15 @@ class SSEndpointHandler:
yield chunk
if close_conn:
client.close()
elif "generic" in self.api_base_uri:
elif (
"api/v2/predict/generic" in self.api_base_uri
or "api/predict/generic" in self.api_base_uri
):
try:
for line in response.iter_lines():
chunk = json.loads(line)
if "status_code" not in chunk:
chunk["status_code"] = response.status_code
if chunk["status_code"] == 200 and chunk.get("error"):
chunk["result"] = {"responses": [{"stream_token": ""}]}
yield chunk
except Exception as e:
raise RuntimeError(f"Error processing streaming response: {e}")
@ -583,12 +584,18 @@ class SSEndpointHandler:
"""
if isinstance(input, str):
input = [input]
if "nlp" in self.api_base_uri:
if "api/predict/nlp" in self.api_base_uri:
if params:
data = {"inputs": input, "params": json.loads(params)}
else:
data = {"inputs": input}
elif "generic" in self.api_base_uri:
elif "api/v2/predict/generic" in self.api_base_uri:
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
if params:
data = {"items": items, "params": json.loads(params)}
else:
data = {"items": items}
elif "api/predict/generic" in self.api_base_uri:
if params:
data = {"instances": input, "params": json.loads(params)}
else:
@ -623,14 +630,22 @@ class SSEndpointHandler:
:returns: Prediction results
:type: dict
"""
if "nlp" in self.api_base_uri:
if "api/predict/nlp" in self.api_base_uri:
if isinstance(input, str):
input = [input]
if params:
data = {"inputs": input, "params": json.loads(params)}
else:
data = {"inputs": input}
elif "generic" in self.api_base_uri:
elif "api/v2/predict/generic" in self.api_base_uri:
if isinstance(input, str):
input = [input]
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
if params:
data = {"items": items, "params": json.loads(params)}
else:
data = {"items": items}
elif "api/predict/generic" in self.api_base_uri:
if isinstance(input, list):
input = input[0]
if params:
@ -770,10 +785,13 @@ class SambaStudio(LLM):
# _model_kwargs["stop_sequences"] = ",".join(
# f'"{x}"' for x in _stop_sequences
# )
tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (_model_kwargs.items())
}
if "api/v2/predict/generic" in self.sambastudio_base_uri:
tuning_params_dict = _model_kwargs
else:
tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (_model_kwargs.items())
}
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
tuning_params = json.dumps(tuning_params_dict)
return tuning_params
@ -814,9 +832,11 @@ class SambaStudio(LLM):
f"Sambanova /complete call failed with status code "
f"{response['status_code']}.\n response {response}"
)
if "nlp" in self.sambastudio_base_uri:
if "api/predict/nlp" in self.sambastudio_base_uri:
return response["data"][0]["completion"]
elif "generic" in self.sambastudio_base_uri:
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
return response["items"][0]["value"]["completion"]
elif "api/predict/generic" in self.sambastudio_base_uri:
return response["predictions"][0]["completion"]
else:
raise ValueError(
@ -885,10 +905,15 @@ class SambaStudio(LLM):
f"{chunk['status_code']}."
f"{chunk}."
)
if "nlp" in self.sambastudio_base_uri:
if "api/predict/nlp" in self.sambastudio_base_uri:
text = json.loads(chunk["data"])["stream_token"]
elif "generic" in self.sambastudio_base_uri:
text = chunk["result"]["responses"][0]["stream_token"]
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
text = chunk["result"]["items"][0]["value"]["stream_token"]
elif "api/predict/generic" in self.sambastudio_base_uri:
if len(chunk["result"]["responses"]) > 0:
text = chunk["result"]["responses"][0]["stream_token"]
else:
text = ""
else:
raise ValueError(
f"handling of endpoint uri: {self.sambastudio_base_uri}"