community: sambastudio llms api v2 support (#25063)

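- **Description:** Add SambaStudio GenericV2 API support (`api/v2/predict/generic` endpoints) to the SambaStudio LLM integration, alongside the existing `api/predict/nlp` and `api/predict/generic` endpoints.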
This commit is contained in:
Jorge Piedrahita Ortiz 2024-09-03 09:18:15 -05:00 committed by GitHub
parent 8d784db107
commit c7154a4045
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -514,7 +514,7 @@ class SSEndpointHandler:
response: requests.Response, response: requests.Response,
) -> Generator[Dict, None, None]: ) -> Generator[Dict, None, None]:
"""Process the streaming response""" """Process the streaming response"""
if "nlp" in self.api_base_uri: if "api/predict/nlp" in self.api_base_uri:
try: try:
import sseclient import sseclient
except ImportError: except ImportError:
@ -535,14 +535,15 @@ class SSEndpointHandler:
yield chunk yield chunk
if close_conn: if close_conn:
client.close() client.close()
elif "generic" in self.api_base_uri: elif (
"api/v2/predict/generic" in self.api_base_uri
or "api/predict/generic" in self.api_base_uri
):
try: try:
for line in response.iter_lines(): for line in response.iter_lines():
chunk = json.loads(line) chunk = json.loads(line)
if "status_code" not in chunk: if "status_code" not in chunk:
chunk["status_code"] = response.status_code chunk["status_code"] = response.status_code
if chunk["status_code"] == 200 and chunk.get("error"):
chunk["result"] = {"responses": [{"stream_token": ""}]}
yield chunk yield chunk
except Exception as e: except Exception as e:
raise RuntimeError(f"Error processing streaming response: {e}") raise RuntimeError(f"Error processing streaming response: {e}")
@ -583,12 +584,18 @@ class SSEndpointHandler:
""" """
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
if "nlp" in self.api_base_uri: if "api/predict/nlp" in self.api_base_uri:
if params: if params:
data = {"inputs": input, "params": json.loads(params)} data = {"inputs": input, "params": json.loads(params)}
else: else:
data = {"inputs": input} data = {"inputs": input}
elif "generic" in self.api_base_uri: elif "api/v2/predict/generic" in self.api_base_uri:
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
if params:
data = {"items": items, "params": json.loads(params)}
else:
data = {"items": items}
elif "api/predict/generic" in self.api_base_uri:
if params: if params:
data = {"instances": input, "params": json.loads(params)} data = {"instances": input, "params": json.loads(params)}
else: else:
@ -623,14 +630,22 @@ class SSEndpointHandler:
:returns: Prediction results :returns: Prediction results
:type: dict :type: dict
""" """
if "nlp" in self.api_base_uri: if "api/predict/nlp" in self.api_base_uri:
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
if params: if params:
data = {"inputs": input, "params": json.loads(params)} data = {"inputs": input, "params": json.loads(params)}
else: else:
data = {"inputs": input} data = {"inputs": input}
elif "generic" in self.api_base_uri: elif "api/v2/predict/generic" in self.api_base_uri:
if isinstance(input, str):
input = [input]
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
if params:
data = {"items": items, "params": json.loads(params)}
else:
data = {"items": items}
elif "api/predict/generic" in self.api_base_uri:
if isinstance(input, list): if isinstance(input, list):
input = input[0] input = input[0]
if params: if params:
@ -770,10 +785,13 @@ class SambaStudio(LLM):
# _model_kwargs["stop_sequences"] = ",".join( # _model_kwargs["stop_sequences"] = ",".join(
# f'"{x}"' for x in _stop_sequences # f'"{x}"' for x in _stop_sequences
# ) # )
tuning_params_dict = { if "api/v2/predict/generic" in self.sambastudio_base_uri:
k: {"type": type(v).__name__, "value": str(v)} tuning_params_dict = _model_kwargs
for k, v in (_model_kwargs.items()) else:
} tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (_model_kwargs.items())
}
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
tuning_params = json.dumps(tuning_params_dict) tuning_params = json.dumps(tuning_params_dict)
return tuning_params return tuning_params
@ -814,9 +832,11 @@ class SambaStudio(LLM):
f"Sambanova /complete call failed with status code " f"Sambanova /complete call failed with status code "
f"{response['status_code']}.\n response {response}" f"{response['status_code']}.\n response {response}"
) )
if "nlp" in self.sambastudio_base_uri: if "api/predict/nlp" in self.sambastudio_base_uri:
return response["data"][0]["completion"] return response["data"][0]["completion"]
elif "generic" in self.sambastudio_base_uri: elif "api/v2/predict/generic" in self.sambastudio_base_uri:
return response["items"][0]["value"]["completion"]
elif "api/predict/generic" in self.sambastudio_base_uri:
return response["predictions"][0]["completion"] return response["predictions"][0]["completion"]
else: else:
raise ValueError( raise ValueError(
@ -885,10 +905,15 @@ class SambaStudio(LLM):
f"{chunk['status_code']}." f"{chunk['status_code']}."
f"{chunk}." f"{chunk}."
) )
if "nlp" in self.sambastudio_base_uri: if "api/predict/nlp" in self.sambastudio_base_uri:
text = json.loads(chunk["data"])["stream_token"] text = json.loads(chunk["data"])["stream_token"]
elif "generic" in self.sambastudio_base_uri: elif "api/v2/predict/generic" in self.sambastudio_base_uri:
text = chunk["result"]["responses"][0]["stream_token"] text = chunk["result"]["items"][0]["value"]["stream_token"]
elif "api/predict/generic" in self.sambastudio_base_uri:
if len(chunk["result"]["responses"]) > 0:
text = chunk["result"]["responses"][0]["stream_token"]
else:
text = ""
else: else:
raise ValueError( raise ValueError(
f"handling of endpoint uri: {self.sambastudio_base_uri}" f"handling of endpoint uri: {self.sambastudio_base_uri}"