community: sambaverse api update (#21816)

- **Description:** Fix the Sambaverse integration to make it compatible with the Sambaverse API update, plus minor changes in the docs.
Jorge Piedrahita Ortiz 2024-05-17 12:18:08 -05:00 committed by GitHub
parent 7976fb1663
commit 700b1c7212
2 changed files with 91 additions and 72 deletions
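
For orientation, here is a minimal usage sketch of the integration after this change, assembled from the notebook hunks below. The model name, API key value, and prompt are illustrative, not part of this commit:

```python
import os

from langchain_community.llms.sambanova import Sambaverse

# Illustrative setup: the env var name follows the integration's convention,
# but the key value and model name here are placeholders.
os.environ["SAMBAVERSE_API_KEY"] = "your-api-key"

llm = Sambaverse(
    sambaverse_model_name="Meta/llama-2-7b-chat-hf",  # illustrative
    model_kwargs={
        "do_sample": True,
        "max_tokens_to_generate": 1000,
        "temperature": 0.01,
        "process_prompt": True,
        "select_expert": "llama-2-7b-chat-hf",
        # After this change, optional tuning values are plain numbers; the
        # client wraps them into the typed envelope the API expects.
        # "repetition_penalty": 1.0,
        # "top_k": 50,
        # "top_p": 1.0,
    },
)

print(llm.invoke("Why should I use open source models?"))
```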

Changed file 1 of 2: Sambaverse integration docs notebook

@@ -22,7 +22,8 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
+    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
+    " **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
   ]
  },
 {
@@ -88,9 +89,10 @@
     "        \"temperature\": 0.01,\n",
     "        \"process_prompt\": True,\n",
     "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
-    "        # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
-    "        # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
-    "        # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
+    "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
+    "        # \"repetition_penalty\": 1.0,\n",
+    "        # \"top_k\": 50,\n",
+    "        # \"top_p\": 1.0\n",
     "    },\n",
     ")\n",
     "\n",
@@ -177,10 +179,10 @@
     "        \"do_sample\": True,\n",
     "        \"max_tokens_to_generate\": 1000,\n",
     "        \"temperature\": 0.01,\n",
-    "        # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
-    "        # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
-    "        # \"top_logprobs\": {\"type\": \"int\", \"value\": \"0\"},\n",
-    "        # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
+    "        # \"repetition_penalty\": 1.0,\n",
+    "        # \"top_k\": 50,\n",
+    "        # \"top_logprobs\": 0,\n",
+    "        # \"top_p\": 1.0\n",
     "    },\n",
     ")\n",
     "\n",

Changed file 2 of 2: Sambaverse LLM integration module (SVEndpointHandler / Sambaverse)

@@ -47,8 +47,19 @@ class SVEndpointHandler:
         """
         result: Dict[str, Any] = {}
         try:
-            text_result = response.text.strip().split("\n")[-1]
-            result = {"data": json.loads("".join(text_result.split("data: ")[1:]))}
+            lines_result = response.text.strip().split("\n")
+            text_result = lines_result[-1]
+            if response.status_code == 200 and json.loads(text_result).get("error"):
+                completion = ""
+                for line in lines_result[:-1]:
+                    completion += json.loads(line)["result"]["responses"][0][
+                        "stream_token"
+                    ]
+                text_result = lines_result[-2]
+                result = json.loads(text_result)
+                result["result"]["responses"][0]["completion"] = completion
+            else:
+                result = json.loads(text_result)
         except Exception as e:
             result["detail"] = str(e)
         if "status_code" not in result:
@@ -58,25 +69,19 @@ class SVEndpointHandler:
     @staticmethod
     def _process_streaming_response(
         response: requests.Response,
-    ) -> Generator[GenerationChunk, None, None]:
+    ) -> Generator[Dict, None, None]:
         """Process the streaming response"""
         try:
-            import sseclient
-        except ImportError:
-            raise ImportError(
-                "could not import sseclient library"
-                "Please install it with `pip install sseclient-py`."
-            )
-        client = sseclient.SSEClient(response)
-        close_conn = False
-        for event in client.events():
-            if event.event == "error_event":
-                close_conn = True
-            text = json.dumps({"event": event.event, "data": event.data})
-            chunk = GenerationChunk(text=text)
-            yield chunk
-        if close_conn:
-            client.close()
+            for line in response.iter_lines():
+                chunk = json.loads(line)
+                if "status_code" not in chunk:
+                    chunk["status_code"] = response.status_code
+                if chunk["status_code"] == 200 and chunk.get("error"):
+                    chunk["result"] = {"responses": [{"stream_token": ""}]}
+                    return chunk
+                yield chunk
+        except Exception as e:
+            raise RuntimeError(f"Error processing streaming response: {e}")

     def _get_full_url(self) -> str:
         """
@@ -105,25 +110,21 @@ class SVEndpointHandler:
         :returns: Prediction results
         :rtype: dict
         """
-        if isinstance(input, str):
-            input = [input]
-        parsed_input = []
-        for element in input:
-            parsed_element = {
-                "conversation_id": "sambaverse-conversation-id",
-                "messages": [
-                    {
-                        "message_id": 0,
-                        "role": "user",
-                        "content": element,
-                    }
-                ],
-            }
-            parsed_input.append(json.dumps(parsed_element))
+        parsed_element = {
+            "conversation_id": "sambaverse-conversation-id",
+            "messages": [
+                {
+                    "message_id": 0,
+                    "role": "user",
+                    "content": input,
+                }
+            ],
+        }
+        parsed_input = json.dumps(parsed_element)
         if params:
-            data = {"inputs": parsed_input, "params": json.loads(params)}
+            data = {"instance": parsed_input, "params": json.loads(params)}
         else:
-            data = {"inputs": parsed_input}
+            data = {"instance": parsed_input}
         response = self.http_session.post(
             self._get_full_url(),
             headers={
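
Request bodies change shape here: the old API took a list of serialized conversations under `inputs`, while the updated API takes a single serialized conversation under `instance`. A sketch of building the new payload (the prompt is illustrative):

```python
import json

prompt = "Why should I use open source models?"  # illustrative

parsed_input = json.dumps(
    {
        "conversation_id": "sambaverse-conversation-id",
        "messages": [{"message_id": 0, "role": "user", "content": prompt}],
    }
)

# Old shape: {"inputs": [parsed_input, ...], "params": {...}}
# New shape: one serialized conversation under "instance".
data = {"instance": parsed_input}
print(data)
```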
@@ -141,7 +142,7 @@ class SVEndpointHandler:
         sambaverse_model_name: Optional[str],
         input: Union[List[str], str],
         params: Optional[str] = "",
-    ) -> Iterator[GenerationChunk]:
+    ) -> Iterator[Dict]:
         """
         NLP predict using inline input string.
@@ -153,25 +154,21 @@
         :returns: Prediction results
         :rtype: dict
         """
-        if isinstance(input, str):
-            input = [input]
-        parsed_input = []
-        for element in input:
-            parsed_element = {
-                "conversation_id": "sambaverse-conversation-id",
-                "messages": [
-                    {
-                        "message_id": 0,
-                        "role": "user",
-                        "content": element,
-                    }
-                ],
-            }
-            parsed_input.append(json.dumps(parsed_element))
+        parsed_element = {
+            "conversation_id": "sambaverse-conversation-id",
+            "messages": [
+                {
+                    "message_id": 0,
+                    "role": "user",
+                    "content": input,
+                }
+            ],
+        }
+        parsed_input = json.dumps(parsed_element)
         if params:
-            data = {"inputs": parsed_input, "params": json.loads(params)}
+            data = {"instance": parsed_input, "params": json.loads(params)}
         else:
-            data = {"inputs": parsed_input}
+            data = {"instance": parsed_input}
         # Streaming output
         response = self.http_session.post(
             self._get_full_url(),
@@ -213,7 +210,7 @@ class Sambaverse(LLM):
             "max_tokens_to_generate": 100,
             "temperature": 0.7,
             "top_p": 1.0,
-            "repetition_penalty": 1,
+            "repetition_penalty": 1.0,
             "top_k": 50,
         },
     )
@@ -279,13 +276,17 @@ class Sambaverse(LLM):
             The tuning parameters as a JSON string.
         """
         _model_kwargs = self.model_kwargs or {}
-        _stop_sequences = _model_kwargs.get("stop_sequences", [])
-        _stop_sequences = stop or _stop_sequences
-        _model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences)
+        _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
+        _stop_sequences = stop or _kwarg_stop_sequences
+        if not _kwarg_stop_sequences:
+            _model_kwargs["stop_sequences"] = ",".join(
+                f'"{x}"' for x in _stop_sequences
+            )

         tuning_params_dict = {
             k: {"type": type(v).__name__, "value": str(v)}
             for k, v in (_model_kwargs.items())
         }
+        _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
         tuning_params = json.dumps(tuning_params_dict)
         return tuning_params
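
This method still wraps every kwarg into the `{"type": ..., "value": ...}` envelope before sending, which is why the notebook examples above now pass plain Python values. It also restores `stop_sequences` afterward so the serialized join does not accumulate in `model_kwargs` across calls. A sketch of the transformation, with invented kwargs:

```python
import json

model_kwargs = {"temperature": 0.01, "top_k": 50}  # illustrative
stop = ["###"]  # illustrative stop sequence from the caller

_kwarg_stop_sequences = model_kwargs.get("stop_sequences", [])
_stop_sequences = stop or _kwarg_stop_sequences
if not _kwarg_stop_sequences:
    model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences)

tuning_params = json.dumps(
    {k: {"type": type(v).__name__, "value": str(v)} for k, v in model_kwargs.items()}
)
model_kwargs["stop_sequences"] = _kwarg_stop_sequences  # undo the temporary join

print(tuning_params)
# (wrapped for readability)
# {"temperature": {"type": "float", "value": "0.01"},
#  "top_k": {"type": "int", "value": "50"},
#  "stop_sequences": {"type": "str", "value": "\"###\""}}
```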
@@ -313,14 +314,17 @@ class Sambaverse(LLM):
             self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
         )
         if response["status_code"] != 200:
-            optional_details = response["details"]
-            optional_message = response["message"]
+            optional_code = response["error"].get("code")
+            optional_details = response["error"].get("details")
+            optional_message = response["error"].get("message")
             raise ValueError(
                 f"Sambanova /complete call failed with status code "
-                f"{response['status_code']}. Details: {optional_details}"
-                f"{response['status_code']}. Message: {optional_message}"
+                f"{response['status_code']}."
+                f"Message: {optional_message}"
+                f"Details: {optional_details}"
+                f"Code: {optional_code}"
             )
-        return response["data"]["completion"]
+        return response["result"]["responses"][0]["completion"]

     def _handle_completion_requests(
         self, prompt: Union[List[str], str], stop: Optional[List[str]]
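
Error details now live under a nested `error` object instead of top-level `details`/`message` keys. A sketch with an invented error payload; note that the committed f-string pieces concatenate without separating spaces (the message reads `...429.Message: ...`), whereas this sketch adds them:

```python
# Hypothetical error response in the updated API's shape.
response = {
    "status_code": 429,
    "error": {
        "code": "rate_limit_exceeded",
        "message": "Too many requests",
        "details": "Retry after 60 seconds",
    },
}

if response["status_code"] != 200:
    error = response["error"]
    raise ValueError(
        f"Sambanova /complete call failed with status code "
        f"{response['status_code']}. "
        f"Message: {error.get('message')} "
        f"Details: {error.get('details')} "
        f"Code: {error.get('code')}"
    )
```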
@@ -359,7 +363,20 @@ class Sambaverse(LLM):
         for chunk in sdk.nlp_predict_stream(
             self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
         ):
-            yield chunk
+            if chunk["status_code"] != 200:
+                optional_code = chunk["error"].get("code")
+                optional_details = chunk["error"].get("details")
+                optional_message = chunk["error"].get("message")
+                raise ValueError(
+                    f"Sambanova /complete call failed with status code "
+                    f"{chunk['status_code']}."
+                    f"Message: {optional_message}"
+                    f"Details: {optional_details}"
+                    f"Code: {optional_code}"
+                )
+            text = chunk["result"]["responses"][0]["stream_token"]
+            generated_chunk = GenerationChunk(text=text)
+            yield generated_chunk

     def _stream(
         self,
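
Putting the streaming pieces together: each chunk is now validated before its token is wrapped in a `GenerationChunk`. A consumer-side sketch with invented success chunks, using langchain-core's public import path:

```python
from langchain_core.outputs import GenerationChunk

# Hypothetical chunks, shaped like SVEndpointHandler.nlp_predict_stream output.
chunks = [
    {"status_code": 200, "result": {"responses": [{"stream_token": "Open "}]}},
    {"status_code": 200, "result": {"responses": [{"stream_token": "source!"}]}},
]

for chunk in chunks:
    if chunk["status_code"] != 200:
        raise ValueError(f"Call failed with status code {chunk['status_code']}")
    text = chunk["result"]["responses"][0]["stream_token"]
    generated_chunk = GenerationChunk(text=text)
    print(generated_chunk.text, end="")
# prints: Open source!
```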