diff --git a/libs/community/langchain_community/embeddings/sambanova.py b/libs/community/langchain_community/embeddings/sambanova.py index 6f86b417454..57601b5c629 100644 --- a/libs/community/langchain_community/embeddings/sambanova.py +++ b/libs/community/langchain_community/embeddings/sambanova.py @@ -98,10 +98,13 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): Returns: The tuning parameters as a JSON string. """ - tuning_params_dict = { - k: {"type": type(v).__name__, "value": str(v)} - for k, v in (self.model_kwargs.items()) - } + if "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri: + tuning_params_dict = self.model_kwargs + else: + tuning_params_dict = { + k: {"type": type(v).__name__, "value": str(v)} + for k, v in (self.model_kwargs.items()) + } tuning_params = json.dumps(tuning_params_dict) return tuning_params @@ -148,7 +151,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): params = json.loads(self._get_tuning_params()) embeddings = [] - if "nlp" in self.sambastudio_embeddings_base_uri: + if "api/predict/nlp" in self.sambastudio_embeddings_base_uri: for batch in self._iterate_over_batches(texts, batch_size): data = {"inputs": batch, "params": params} response = http_session.post( @@ -156,6 +159,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): headers={"key": self.sambastudio_embeddings_api_key}, json=data, ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.\n Details: {response.text}" + ) try: embedding = response.json()["data"] embeddings.extend(embedding) @@ -165,7 +173,32 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): response.json(), ) - elif "generic" in self.sambastudio_embeddings_base_uri: + elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri: + for batch in self._iterate_over_batches(texts, batch_size): + items = [ + {"id": f"item{i}", "value": item} for i, item in enumerate(batch) + ] + data = {"items": items, "params": params} + response = http_session.post( + url, + headers={"key": self.sambastudio_embeddings_api_key}, + json=data, + ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.\n Details: {response.text}" + ) + try: + embedding = [item["value"] for item in response.json()["items"]] + embeddings.extend(embedding) + except KeyError: + raise KeyError( + "'items' not found in endpoint response", + response.json(), + ) + + elif "api/predict/generic" in self.sambastudio_embeddings_base_uri: for batch in self._iterate_over_batches(texts, batch_size): data = {"instances": batch, "params": params} response = http_session.post( @@ -173,6 +206,11 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): headers={"key": self.sambastudio_embeddings_api_key}, json=data, ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.\n Details: {response.text}" + ) try: if params.get("select_expert"): embedding = response.json()["predictions"][0] @@ -207,13 +245,18 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): ) params = json.loads(self._get_tuning_params()) - if "nlp" in self.sambastudio_embeddings_base_uri: + if "api/predict/nlp" in self.sambastudio_embeddings_base_uri: data = {"inputs": [text], "params": params} response = http_session.post( url, headers={"key": self.sambastudio_embeddings_api_key}, json=data, ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.\n Details: {response.text}" + ) try: embedding = response.json()["data"][0] except KeyError: @@ -222,13 +265,38 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): response.json(), ) - elif "generic" in self.sambastudio_embeddings_base_uri: + elif "api/v2/predict/generic" in self.sambastudio_embeddings_base_uri: + data = {"items": [{"id": "item0", "value": text}], "params": params} + response = http_session.post( + url, + headers={"key": self.sambastudio_embeddings_api_key}, + json=data, + ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.\n Details: {response.text}" + ) + try: + embedding = response.json()["items"][0]["value"] + except KeyError: + raise KeyError( + "'items' not found in endpoint response", + response.json(), + ) + + elif "api/predict/generic" in self.sambastudio_embeddings_base_uri: data = {"instances": [text], "params": params} response = http_session.post( url, headers={"key": self.sambastudio_embeddings_api_key}, json=data, ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.\n Details: {response.text}" + ) try: if params.get("select_expert"): embedding = response.json()["predictions"][0][0]