diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py index b728eb56781..187c1708fe4 100644 --- a/libs/community/langchain_community/llms/sambanova.py +++ b/libs/community/langchain_community/llms/sambanova.py @@ -514,7 +514,7 @@ class SSEndpointHandler: response: requests.Response, ) -> Generator[Dict, None, None]: """Process the streaming response""" - if "nlp" in self.api_base_uri: + if "api/predict/nlp" in self.api_base_uri: try: import sseclient except ImportError: @@ -535,14 +535,15 @@ class SSEndpointHandler: yield chunk if close_conn: client.close() - elif "generic" in self.api_base_uri: + elif ( + "api/v2/predict/generic" in self.api_base_uri + or "api/predict/generic" in self.api_base_uri + ): try: for line in response.iter_lines(): chunk = json.loads(line) if "status_code" not in chunk: chunk["status_code"] = response.status_code - if chunk["status_code"] == 200 and chunk.get("error"): - chunk["result"] = {"responses": [{"stream_token": ""}]} yield chunk except Exception as e: raise RuntimeError(f"Error processing streaming response: {e}") @@ -583,12 +584,18 @@ class SSEndpointHandler: """ if isinstance(input, str): input = [input] - if "nlp" in self.api_base_uri: + if "api/predict/nlp" in self.api_base_uri: if params: data = {"inputs": input, "params": json.loads(params)} else: data = {"inputs": input} - elif "generic" in self.api_base_uri: + elif "api/v2/predict/generic" in self.api_base_uri: + items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)] + if params: + data = {"items": items, "params": json.loads(params)} + else: + data = {"items": items} + elif "api/predict/generic" in self.api_base_uri: if params: data = {"instances": input, "params": json.loads(params)} else: @@ -623,14 +630,22 @@ class SSEndpointHandler: :returns: Prediction results :type: dict """ - if "nlp" in self.api_base_uri: + if "api/predict/nlp" in self.api_base_uri: if isinstance(input, str): input = [input] if params: data = {"inputs": input, "params": json.loads(params)} else: data = {"inputs": input} - elif "generic" in self.api_base_uri: + elif "api/v2/predict/generic" in self.api_base_uri: + if isinstance(input, str): + input = [input] + items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)] + if params: + data = {"items": items, "params": json.loads(params)} + else: + data = {"items": items} + elif "api/predict/generic" in self.api_base_uri: if isinstance(input, list): input = input[0] if params: @@ -770,10 +785,13 @@ class SambaStudio(LLM): # _model_kwargs["stop_sequences"] = ",".join( # f'"{x}"' for x in _stop_sequences # ) - tuning_params_dict = { - k: {"type": type(v).__name__, "value": str(v)} - for k, v in (_model_kwargs.items()) - } + if "api/v2/predict/generic" in self.sambastudio_base_uri: + tuning_params_dict = _model_kwargs + else: + tuning_params_dict = { + k: {"type": type(v).__name__, "value": str(v)} + for k, v in (_model_kwargs.items()) + } # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences tuning_params = json.dumps(tuning_params_dict) return tuning_params @@ -814,9 +832,11 @@ class SambaStudio(LLM): f"Sambanova /complete call failed with status code " f"{response['status_code']}.\n response {response}" ) - if "nlp" in self.sambastudio_base_uri: + if "api/predict/nlp" in self.sambastudio_base_uri: return response["data"][0]["completion"] - elif "generic" in self.sambastudio_base_uri: + elif "api/v2/predict/generic" in self.sambastudio_base_uri: + return response["items"][0]["value"]["completion"] + elif "api/predict/generic" in self.sambastudio_base_uri: return response["predictions"][0]["completion"] else: raise ValueError( @@ -885,10 +905,15 @@ class SambaStudio(LLM): f"{chunk['status_code']}." f"{chunk}." ) - if "nlp" in self.sambastudio_base_uri: + if "api/predict/nlp" in self.sambastudio_base_uri: text = json.loads(chunk["data"])["stream_token"] - elif "generic" in self.sambastudio_base_uri: - text = chunk["result"]["responses"][0]["stream_token"] + elif "api/v2/predict/generic" in self.sambastudio_base_uri: + text = chunk["result"]["items"][0]["value"]["stream_token"] + elif "api/predict/generic" in self.sambastudio_base_uri: + if len(chunk["result"]["responses"]) > 0: + text = chunk["result"]["responses"][0]["stream_token"] + else: + text = "" else: raise ValueError( f"handling of endpoint uri: {self.sambastudio_base_uri}"