From e6207ad4f3cba828fce2623701ed41228c4a272b Mon Sep 17 00:00:00 2001
From: Jorge Piedrahita Ortiz
Date: Mon, 20 May 2024 17:29:59 -0500
Subject: [PATCH] community[patch]: Sambanova integration api update (#21848)

- **Description:** SambaStudio generic endpoint compatibility added; improved error descriptions and handling; streaming examples added.
---
 docs/docs/integrations/llms/sambanova.ipynb |  59 ++++
 .../langchain_community/llms/sambanova.py   | 259 +++++++++++++-----
 2 files changed, 248 insertions(+), 70 deletions(-)

diff --git a/docs/docs/integrations/llms/sambanova.ipynb b/docs/docs/integrations/llms/sambanova.ipynb
index c1926ff2183..522b9bb959c 100644
--- a/docs/docs/integrations/llms/sambanova.ipynb
+++ b/docs/docs/integrations/llms/sambanova.ipynb
@@ -99,6 +99,36 @@
     "print(llm.invoke(\"Why should I use open source models?\"))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Streaming response\n",
+    "\n",
+    "from langchain_community.llms.sambanova import Sambaverse\n",
+    "\n",
+    "llm = Sambaverse(\n",
+    "    sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
+    "    streaming=True,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "        \"process_prompt\": True,\n",
+    "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
+    "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
+    "        # \"repetition_penalty\": 1.0,\n",
+    "        # \"top_k\": 50,\n",
+    "        # \"top_p\": 1.0\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
+    "    print(chunk, end=\"\", flush=True)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -147,12 +177,14 @@
     "import os\n",
     "\n",
     "sambastudio_base_url = \"\"\n",
+    "# sambastudio_base_uri = \"\" # optional, \"api/predict/nlp\" set as default\n",
     "sambastudio_project_id = \"\"\n",
     "sambastudio_endpoint_id = \"\"\n",
     "sambastudio_api_key = \"\"\n",
     "\n",
     "# Set the environment variables\n",
     "os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
+    "# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
     "os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
     "os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
     "os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
@@ -188,6 +220,33 @@
     "\n",
     "print(llm.invoke(\"Why should I use open source models?\"))"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Streaming response\n",
+    "\n",
+    "from langchain_community.llms.sambanova import SambaStudio\n",
+    "\n",
+    "llm = SambaStudio(\n",
+    "    streaming=True,\n",
+    "    model_kwargs={\n",
+    "        \"do_sample\": True,\n",
+    "        \"max_tokens_to_generate\": 1000,\n",
+    "        \"temperature\": 0.01,\n",
+    "        # \"repetition_penalty\": 1.0,\n",
+    "        # \"top_k\": 50,\n",
+    "        # \"top_logprobs\": 0,\n",
+    "        # \"top_p\": 1.0\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
+    "    print(chunk, end=\"\", flush=True)"
+   ]
   }
  ],
  "metadata": {
diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py
index d67b344de18..e7e89e1522b 100644
--- a/libs/community/langchain_community/llms/sambanova.py
+++ b/libs/community/langchain_community/llms/sambanova.py
@@ -216,7 +216,7 @@ class Sambaverse(LLM):
         )
     """
 
-    sambaverse_url: str = "https://sambaverse.sambanova.ai"
+    sambaverse_url: str = ""
"" """Sambaverse url to use""" sambaverse_api_key: str = "" @@ -244,7 +244,10 @@ class Sambaverse(LLM): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["sambaverse_url"] = get_from_dict_or_env( - values, "sambaverse_url", "SAMBAVERSE_URL" + values, + "sambaverse_url", + "SAMBAVERSE_URL", + default="https://sambaverse.sambanova.ai", ) values["sambaverse_api_key"] = get_from_dict_or_env( values, "sambaverse_api_key", "SAMBAVERSE_API_KEY" @@ -314,16 +317,24 @@ class Sambaverse(LLM): self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params ) if response["status_code"] != 200: - optional_code = response["error"].get("code") - optional_details = response["error"].get("details") - optional_message = response["error"].get("message") - raise ValueError( - f"Sambanova /complete call failed with status code " - f"{response['status_code']}." - f"Message: {optional_message}" - f"Details: {optional_details}" - f"Code: {optional_code}" - ) + error = response.get("error") + if error: + optional_code = error.get("code") + optional_details = error.get("details") + optional_message = error.get("message") + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response['status_code']}.\n" + f"Message: {optional_message}\n" + f"Details: {optional_details}\n" + f"Code: {optional_code}\n" + ) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response['status_code']}." + f"{response}." + ) return response["result"]["responses"][0]["completion"] def _handle_completion_requests( @@ -364,16 +375,24 @@ class Sambaverse(LLM): self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params ): if chunk["status_code"] != 200: - optional_code = chunk["error"].get("code") - optional_details = chunk["error"].get("details") - optional_message = chunk["error"].get("message") - raise ValueError( - f"Sambanova /complete call failed with status code " - f"{chunk['status_code']}." - f"Message: {optional_message}" - f"Details: {optional_details}" - f"Code: {optional_code}" - ) + error = chunk.get("error") + if error: + optional_code = error.get("code") + optional_details = error.get("details") + optional_message = error.get("message") + raise ValueError( + f"Sambanova /complete call failed with status code " + f"{chunk['status_code']}.\n" + f"Message: {optional_message}\n" + f"Details: {optional_details}\n" + f"Code: {optional_code}\n" + ) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{chunk['status_code']}." + f"{chunk}." + ) text = chunk["result"]["responses"][0]["stream_token"] generated_chunk = GenerationChunk(text=text) yield generated_chunk @@ -477,19 +496,18 @@ class SSEndpointHandler: :param str host_url: Base URL of the DaaS API service """ - API_BASE_PATH = "/api" - - def __init__(self, host_url: str): + def __init__(self, host_url: str, api_base_uri: str): """ Initialize the SSEndpointHandler. :param str host_url: Base URL of the DaaS API service + :param str api_base_uri: Base URI of the DaaS API service """ self.host_url = host_url + self.api_base_uri = api_base_uri self.http_session = requests.Session() - @staticmethod - def _process_response(response: requests.Response) -> Dict: + def _process_response(self, response: requests.Response) -> Dict: """ Processes the API response and returns the resulting dict. 
@@ -515,28 +533,47 @@ class SSEndpointHandler:
         result["status_code"] = response.status_code
         return result
 
-    @staticmethod
     def _process_streaming_response(
+        self,
         response: requests.Response,
-    ) -> Generator[GenerationChunk, None, None]:
+    ) -> Generator[Dict, None, None]:
         """Process the streaming response"""
-        try:
-            import sseclient
-        except ImportError:
-            raise ImportError(
-                "could not import sseclient library"
-                "Please install it with `pip install sseclient-py`."
+        if "nlp" in self.api_base_uri:
+            try:
+                import sseclient
+            except ImportError:
+                raise ImportError(
+                    "could not import sseclient library. "
+                    "Please install it with `pip install sseclient-py`."
+                )
+            client = sseclient.SSEClient(response)
+            close_conn = False
+            for event in client.events():
+                if event.event == "error_event":
+                    close_conn = True
+                chunk = {
+                    "event": event.event,
+                    "data": event.data,
+                    "status_code": response.status_code,
+                }
+                yield chunk
+            if close_conn:
+                client.close()
+        elif "generic" in self.api_base_uri:
+            try:
+                for line in response.iter_lines():
+                    chunk = json.loads(line)
+                    if "status_code" not in chunk:
+                        chunk["status_code"] = response.status_code
+                    if chunk["status_code"] == 200 and chunk.get("error"):
+                        chunk["result"] = {"responses": [{"stream_token": ""}]}
+                    yield chunk
+            except Exception as e:
+                raise RuntimeError(f"Error processing streaming response: {e}")
+        else:
+            raise ValueError(
+                f"handling of endpoint uri: {self.api_base_uri} not implemented"
             )
-        client = sseclient.SSEClient(response)
-        close_conn = False
-        for event in client.events():
-            if event.event == "error_event":
-                close_conn = True
-            text = json.dumps({"event": event.event, "data": event.data})
-            chunk = GenerationChunk(text=text)
-            yield chunk
-        if close_conn:
-            client.close()
 
     def _get_full_url(self, path: str) -> str:
         """
@@ -546,7 +583,7 @@ class SSEndpointHandler:
         :returns: the full API URL for the sub-path
         :rtype: str
         """
-        return f"{self.host_url}{self.API_BASE_PATH}{path}"
+        return f"{self.host_url}/{self.api_base_uri}/{path}"
 
     def nlp_predict(
         self,
@@ -570,16 +607,26 @@ class SSEndpointHandler:
         """
         if isinstance(input, str):
             input = [input]
-        if params:
-            data = {"inputs": input, "params": json.loads(params)}
+        if "nlp" in self.api_base_uri:
+            if params:
+                data = {"inputs": input, "params": json.loads(params)}
+            else:
+                data = {"inputs": input}
+        elif "generic" in self.api_base_uri:
+            if params:
+                data = {"instances": input, "params": json.loads(params)}
+            else:
+                data = {"instances": input}
         else:
-            data = {"inputs": input}
+            raise ValueError(
+                f"handling of endpoint uri: {self.api_base_uri} not implemented"
+            )
        response = self.http_session.post(
-            self._get_full_url(f"/predict/nlp/{project}/{endpoint}"),
+            self._get_full_url(f"{project}/{endpoint}"),
             headers={"key": key},
             json=data,
         )
-        return SSEndpointHandler._process_response(response)
+        return self._process_response(response)
 
     def nlp_predict_stream(
         self,
@@ -588,7 +635,7 @@ class SSEndpointHandler:
         key: str,
         input: Union[List[str], str],
         params: Optional[str] = "",
-    ) -> Iterator[GenerationChunk]:
+    ) -> Iterator[Dict]:
         """
         NLP predict using inline input string.
@@ -600,20 +647,32 @@ class SSEndpointHandler:
         :returns: Prediction results
         :rtype: dict
         """
-        if isinstance(input, str):
-            input = [input]
-        if params:
-            data = {"inputs": input, "params": json.loads(params)}
+        if "nlp" in self.api_base_uri:
+            if isinstance(input, str):
+                input = [input]
+            if params:
+                data = {"inputs": input, "params": json.loads(params)}
+            else:
+                data = {"inputs": input}
+        elif "generic" in self.api_base_uri:
+            if isinstance(input, list):
+                input = input[0]
+            if params:
+                data = {"instance": input, "params": json.loads(params)}
+            else:
+                data = {"instance": input}
         else:
-            data = {"inputs": input}
+            raise ValueError(
+                f"handling of endpoint uri: {self.api_base_uri} not implemented"
+            )
         # Streaming output
         response = self.http_session.post(
-            self._get_full_url(f"/predict/nlp/stream/{project}/{endpoint}"),
+            self._get_full_url(f"stream/{project}/{endpoint}"),
             headers={"key": key},
             json=data,
             stream=True,
         )
-        for chunk in SSEndpointHandler._process_streaming_response(response):
+        for chunk in self._process_streaming_response(response):
             yield chunk
 
 
@@ -623,6 +682,7 @@ class SambaStudio(LLM):
 
     To use, you should have the environment variables
     ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
+    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio API base URI.
     ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
     ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
     ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
@@ -637,6 +697,7 @@ class SambaStudio(LLM):
         from langchain_community.llms.sambanova import Sambaverse
         SambaStudio(
             sambastudio_base_url="your-SambaStudio-environment-URL",
+            sambastudio_base_uri="your-SambaStudio-base-URI",
             sambastudio_project_id="your-SambaStudio-project-ID",
             sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
             sambastudio_api_key="your-SambaStudio-endpoint-API-key,
@@ -655,6 +716,9 @@ class SambaStudio(LLM):
     sambastudio_base_url: str = ""
     """Base url to use"""
 
+    sambastudio_base_uri: str = ""
+    """Endpoint base URI"""
+
     sambastudio_project_id: str = ""
     """Project id on sambastudio for model"""
 
@@ -695,6 +759,12 @@ class SambaStudio(LLM):
         values["sambastudio_base_url"] = get_from_dict_or_env(
             values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
         )
+        values["sambastudio_base_uri"] = get_from_dict_or_env(
+            values,
+            "sambastudio_base_uri",
+            "SAMBASTUDIO_BASE_URI",
+            default="api/predict/nlp",
+        )
         values["sambastudio_project_id"] = get_from_dict_or_env(
             values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
         )
@@ -718,14 +788,17 @@ class SambaStudio(LLM):
             The tuning parameters as a JSON string.
""" _model_kwargs = self.model_kwargs or {} - _stop_sequences = _model_kwargs.get("stop_sequences", []) - _stop_sequences = stop or _stop_sequences - # _model_kwargs['stop_sequences'] = ','.join( - # f"'{x}'" for x in _stop_sequences) + _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", []) + _stop_sequences = stop or _kwarg_stop_sequences + # if not _kwarg_stop_sequences: + # _model_kwargs["stop_sequences"] = ",".join( + # f'"{x}"' for x in _stop_sequences + # ) tuning_params_dict = { k: {"type": type(v).__name__, "value": str(v)} for k, v in (_model_kwargs.items()) } + # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences tuning_params = json.dumps(tuning_params_dict) return tuning_params @@ -754,12 +827,25 @@ class SambaStudio(LLM): tuning_params, ) if response["status_code"] != 200: - optional_detail = response["detail"] + optional_detail = response.get("detail") + if optional_detail: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response['status_code']}.\n Details: {optional_detail}" + ) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response['status_code']}.\n response {response}" + ) + if "nlp" in self.sambastudio_base_uri: + return response["data"][0]["completion"] + elif "generic" in self.sambastudio_base_uri: + return response["predictions"][0]["completion"] + else: raise ValueError( - f"Sambanova /complete call failed with status code " - f"{response['status_code']}. Details: {optional_detail}" + f"handling of endpoint uri: {self.sambastudio_base_uri} not implemented" ) - return response["data"][0]["completion"] def _handle_completion_requests( self, prompt: Union[List[str], str], stop: Optional[List[str]] @@ -777,7 +863,9 @@ class SambaStudio(LLM): Raises: ValueError: If the prediction fails. """ - ss_endpoint = SSEndpointHandler(self.sambastudio_base_url) + ss_endpoint = SSEndpointHandler( + self.sambastudio_base_url, self.sambastudio_base_uri + ) tuning_params = self._get_tuning_params(stop) return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params) @@ -802,7 +890,36 @@ class SambaStudio(LLM): prompt, tuning_params, ): - yield chunk + if chunk["status_code"] != 200: + error = chunk.get("error") + if error: + optional_code = error.get("code") + optional_details = error.get("details") + optional_message = error.get("message") + raise ValueError( + f"Sambanova /complete call failed with status code " + f"{chunk['status_code']}.\n" + f"Message: {optional_message}\n" + f"Details: {optional_details}\n" + f"Code: {optional_code}\n" + ) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{chunk['status_code']}." + f"{chunk}." + ) + if "nlp" in self.sambastudio_base_uri: + text = json.loads(chunk["data"])["stream_token"] + elif "generic" in self.sambastudio_base_uri: + text = chunk["result"]["responses"][0]["stream_token"] + else: + raise ValueError( + f"handling of endpoint uri: {self.sambastudio_base_uri}" + f"not implemented" + ) + generated_chunk = GenerationChunk(text=text) + yield generated_chunk def _stream( self, @@ -820,7 +937,9 @@ class SambaStudio(LLM): Returns: The string generated by the model. """ - ss_endpoint = SSEndpointHandler(self.sambastudio_base_url) + ss_endpoint = SSEndpointHandler( + self.sambastudio_base_url, self.sambastudio_base_uri + ) tuning_params = self._get_tuning_params(stop) try: if self.streaming: