From 700b1c7212d565509b9a0b126fd61a852bb1a7f4 Mon Sep 17 00:00:00 2001
From: Jorge Piedrahita Ortiz
Date: Fri, 17 May 2024 12:18:08 -0500
Subject: [PATCH] community: sambaverse api update (#21816)

- **Description:** Fix the Sambaverse integration so that it is compatible
  with the updated Sambaverse API; also makes minor changes to the docs.
---
 docs/docs/integrations/llms/sambanova.ipynb |  18 ++-
 .../langchain_community/llms/sambanova.py   | 145 ++++++++++--------
 2 files changed, 91 insertions(+), 72 deletions(-)

diff --git a/docs/docs/integrations/llms/sambanova.ipynb b/docs/docs/integrations/llms/sambanova.ipynb
index d68d63aac2b..c1926ff2183 100644
--- a/docs/docs/integrations/llms/sambanova.ipynb
+++ b/docs/docs/integrations/llms/sambanova.ipynb
@@ -22,7 +22,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
+    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
+    " **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
] }, { @@ -88,9 +89,10 @@ " \"temperature\": 0.01,\n", " \"process_prompt\": True,\n", " \"select_expert\": \"llama-2-7b-chat-hf\",\n", - " # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n", - " # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n", - " # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n", + " # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n", + " # \"repetition_penalty\": 1.0,\n", + " # \"top_k\": 50,\n", + " # \"top_p\": 1.0\n", " },\n", ")\n", "\n", @@ -177,10 +179,10 @@ " \"do_sample\": True,\n", " \"max_tokens_to_generate\": 1000,\n", " \"temperature\": 0.01,\n", - " # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n", - " # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n", - " # \"top_logprobs\": {\"type\": \"int\", \"value\": \"0\"},\n", - " # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n", + " # \"repetition_penalty\": 1.0,\n", + " # \"top_k\": 50,\n", + " # \"top_logprobs\": 0,\n", + " # \"top_p\": 1.0\n", " },\n", ")\n", "\n", diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py index a24f56f7d73..d67b344de18 100644 --- a/libs/community/langchain_community/llms/sambanova.py +++ b/libs/community/langchain_community/llms/sambanova.py @@ -47,8 +47,19 @@ class SVEndpointHandler: """ result: Dict[str, Any] = {} try: - text_result = response.text.strip().split("\n")[-1] - result = {"data": json.loads("".join(text_result.split("data: ")[1:]))} + lines_result = response.text.strip().split("\n") + text_result = lines_result[-1] + if response.status_code == 200 and json.loads(text_result).get("error"): + completion = "" + for line in lines_result[:-1]: + completion += json.loads(line)["result"]["responses"][0][ + "stream_token" + ] + text_result = lines_result[-2] + result = json.loads(text_result) + result["result"]["responses"][0]["completion"] = completion + else: + result = json.loads(text_result) except Exception as e: result["detail"] = str(e) if "status_code" not in result: @@ -58,25 +69,19 @@ class SVEndpointHandler: @staticmethod def _process_streaming_response( response: requests.Response, - ) -> Generator[GenerationChunk, None, None]: + ) -> Generator[Dict, None, None]: """Process the streaming response""" try: - import sseclient - except ImportError: - raise ImportError( - "could not import sseclient library" - "Please install it with `pip install sseclient-py`." 
- ) - client = sseclient.SSEClient(response) - close_conn = False - for event in client.events(): - if event.event == "error_event": - close_conn = True - text = json.dumps({"event": event.event, "data": event.data}) - chunk = GenerationChunk(text=text) - yield chunk - if close_conn: - client.close() + for line in response.iter_lines(): + chunk = json.loads(line) + if "status_code" not in chunk: + chunk["status_code"] = response.status_code + if chunk["status_code"] == 200 and chunk.get("error"): + chunk["result"] = {"responses": [{"stream_token": ""}]} + return chunk + yield chunk + except Exception as e: + raise RuntimeError(f"Error processing streaming response: {e}") def _get_full_url(self) -> str: """ @@ -105,25 +110,21 @@ class SVEndpointHandler: :returns: Prediction results :rtype: dict """ - if isinstance(input, str): - input = [input] - parsed_input = [] - for element in input: - parsed_element = { - "conversation_id": "sambaverse-conversation-id", - "messages": [ - { - "message_id": 0, - "role": "user", - "content": element, - } - ], - } - parsed_input.append(json.dumps(parsed_element)) + parsed_element = { + "conversation_id": "sambaverse-conversation-id", + "messages": [ + { + "message_id": 0, + "role": "user", + "content": input, + } + ], + } + parsed_input = json.dumps(parsed_element) if params: - data = {"inputs": parsed_input, "params": json.loads(params)} + data = {"instance": parsed_input, "params": json.loads(params)} else: - data = {"inputs": parsed_input} + data = {"instance": parsed_input} response = self.http_session.post( self._get_full_url(), headers={ @@ -141,7 +142,7 @@ class SVEndpointHandler: sambaverse_model_name: Optional[str], input: Union[List[str], str], params: Optional[str] = "", - ) -> Iterator[GenerationChunk]: + ) -> Iterator[Dict]: """ NLP predict using inline input string. @@ -153,25 +154,21 @@ class SVEndpointHandler: :returns: Prediction results :rtype: dict """ - if isinstance(input, str): - input = [input] - parsed_input = [] - for element in input: - parsed_element = { - "conversation_id": "sambaverse-conversation-id", - "messages": [ - { - "message_id": 0, - "role": "user", - "content": element, - } - ], - } - parsed_input.append(json.dumps(parsed_element)) + parsed_element = { + "conversation_id": "sambaverse-conversation-id", + "messages": [ + { + "message_id": 0, + "role": "user", + "content": input, + } + ], + } + parsed_input = json.dumps(parsed_element) if params: - data = {"inputs": parsed_input, "params": json.loads(params)} + data = {"instance": parsed_input, "params": json.loads(params)} else: - data = {"inputs": parsed_input} + data = {"instance": parsed_input} # Streaming output response = self.http_session.post( self._get_full_url(), @@ -213,7 +210,7 @@ class Sambaverse(LLM): "max_tokens_to_generate": 100, "temperature": 0.7, "top_p": 1.0, - "repetition_penalty": 1, + "repetition_penalty": 1.0, "top_k": 50, }, ) @@ -279,13 +276,17 @@ class Sambaverse(LLM): The tuning parameters as a JSON string. 
""" _model_kwargs = self.model_kwargs or {} - _stop_sequences = _model_kwargs.get("stop_sequences", []) - _stop_sequences = stop or _stop_sequences - _model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences) + _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", []) + _stop_sequences = stop or _kwarg_stop_sequences + if not _kwarg_stop_sequences: + _model_kwargs["stop_sequences"] = ",".join( + f'"{x}"' for x in _stop_sequences + ) tuning_params_dict = { k: {"type": type(v).__name__, "value": str(v)} for k, v in (_model_kwargs.items()) } + _model_kwargs["stop_sequences"] = _kwarg_stop_sequences tuning_params = json.dumps(tuning_params_dict) return tuning_params @@ -313,14 +314,17 @@ class Sambaverse(LLM): self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params ) if response["status_code"] != 200: - optional_details = response["details"] - optional_message = response["message"] + optional_code = response["error"].get("code") + optional_details = response["error"].get("details") + optional_message = response["error"].get("message") raise ValueError( f"Sambanova /complete call failed with status code " - f"{response['status_code']}. Details: {optional_details}" - f"{response['status_code']}. Message: {optional_message}" + f"{response['status_code']}." + f"Message: {optional_message}" + f"Details: {optional_details}" + f"Code: {optional_code}" ) - return response["data"]["completion"] + return response["result"]["responses"][0]["completion"] def _handle_completion_requests( self, prompt: Union[List[str], str], stop: Optional[List[str]] @@ -359,7 +363,20 @@ class Sambaverse(LLM): for chunk in sdk.nlp_predict_stream( self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params ): - yield chunk + if chunk["status_code"] != 200: + optional_code = chunk["error"].get("code") + optional_details = chunk["error"].get("details") + optional_message = chunk["error"].get("message") + raise ValueError( + f"Sambanova /complete call failed with status code " + f"{chunk['status_code']}." + f"Message: {optional_message}" + f"Details: {optional_details}" + f"Code: {optional_code}" + ) + text = chunk["result"]["responses"][0]["stream_token"] + generated_chunk = GenerationChunk(text=text) + yield generated_chunk def _stream( self,