From e65652c3e81c248feb15325dc96f3a2bd56aa475 Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Mon, 6 May 2024 15:29:59 -0500 Subject: [PATCH] community: add SambaNova embeddings integration (#21227) - **Description:** SambaNova hosted embeddings integration --- .../text_embedding/sambanova.ipynb | 91 +++++++++++ .../embeddings/__init__.py | 5 + .../embeddings/sambanova.py | 142 ++++++++++++++++++ .../embeddings/test_sambanova.py | 22 +++ .../unit_tests/embeddings/test_imports.py | 1 + 5 files changed, 261 insertions(+) create mode 100644 docs/docs/integrations/text_embedding/sambanova.ipynb create mode 100644 libs/community/langchain_community/embeddings/sambanova.py create mode 100644 libs/community/tests/integration_tests/embeddings/test_sambanova.py diff --git a/docs/docs/integrations/text_embedding/sambanova.ipynb b/docs/docs/integrations/text_embedding/sambanova.ipynb new file mode 100644 index 00000000000..f0e4131aa27 --- /dev/null +++ b/docs/docs/integrations/text_embedding/sambanova.ipynb @@ -0,0 +1,91 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SambaNova\n", + "\n", + "**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n", + "\n", + "This example goes over how to use LangChain to interact with SambaNova embedding models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SambaStudio\n", + "\n", + "**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A SambaStudio environment is required to deploy a model. 
Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Register your environment variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "sambastudio_base_url = \"\"\n", + "sambastudio_project_id = \"\"\n", + "sambastudio_endpoint_id = \"\"\n", + "sambastudio_api_key = \"\"\n", + "\n", + "# Set the environment variables\n", + "os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n", + "os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n", + "os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n", + "os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Call SambaStudio hosted embeddings directly from LangChain!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.embeddings.sambanova import SambaStudioEmbeddings\n", + "\n", + "embeddings = SambaStudioEmbeddings()\n", + "\n", + "text = \"Hello, this is a test\"\n", + "result = embeddings.embed_query(text)\n", + "print(result)\n", + "\n", + "texts = [\"Hello, this is a test\", \"Hello, this is another test\"]\n", + "results = embeddings.embed_documents(texts)\n", + "print(results)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 88e3777801d..499f79453a2 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -178,6 +178,9 @@ if TYPE_CHECKING: from langchain_community.embeddings.sagemaker_endpoint import ( SagemakerEndpointEmbeddings, ) + from langchain_community.embeddings.sambanova import ( + SambaStudioEmbeddings, + ) from langchain_community.embeddings.self_hosted import ( SelfHostedEmbeddings, ) @@ -276,6 +279,7 @@ __all__ = [ "QuantizedBgeEmbeddings", "QuantizedBiEncoderEmbeddings", "SagemakerEndpointEmbeddings", + "SambaStudioEmbeddings", "SelfHostedEmbeddings", "SelfHostedHuggingFaceEmbeddings", "SelfHostedHuggingFaceInstructEmbeddings", @@ -350,6 +354,7 @@ _module_lookup = { "QuantizedBiEncoderEmbeddings": "langchain_community.embeddings.optimum_intel", "OracleEmbeddings": "langchain_community.embeddings.oracleai", "SagemakerEndpointEmbeddings": "langchain_community.embeddings.sagemaker_endpoint", + "SambaStudioEmbeddings": "langchain_community.embeddings.sambanova", "SelfHostedEmbeddings": "langchain_community.embeddings.self_hosted", "SelfHostedHuggingFaceEmbeddings": "langchain_community.embeddings.self_hosted_hugging_face", # noqa: E501 
class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models served from a SambaStudio endpoint.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``,
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings
            embeddings = SambaStudioEmbeddings(
                sambastudio_embeddings_base_url=base_url,
                sambastudio_embeddings_project_id=project_id,
                sambastudio_embeddings_endpoint_id=endpoint_id,
                sambastudio_embeddings_api_key=api_key,
            )
            (or)
            embeddings = SambaStudioEmbeddings()
    """

    API_BASE_PATH: str = "/api/predict/nlp/"
    """Base path to use for the API usage"""

    sambastudio_embeddings_base_url: str = ""
    """Base url to use"""

    sambastudio_embeddings_project_id: str = ""
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ""
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ""
    """sambastudio api key"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve every connection setting from the constructor or environment.

        Each setting is looked up with ``get_from_dict_or_env`` under its field
        name in ``values`` and under the upper-cased field name as the
        environment variable.

        Args:
            values: Field values collected by pydantic for this instance.

        Returns:
            The same mapping with the four connection settings filled in.
        """
        for field_name in (
            "sambastudio_embeddings_base_url",
            "sambastudio_embeddings_project_id",
            "sambastudio_embeddings_endpoint_id",
            "sambastudio_embeddings_api_key",
        ):
            # The environment variable is the field name in upper case, e.g.
            # SAMBASTUDIO_EMBEDDINGS_BASE_URL.
            values[field_name] = get_from_dict_or_env(
                values, field_name, field_name.upper()
            )
        return values

    def _get_full_url(self, path: str) -> str:
        """Return the full API URL for a given path.

        Args:
            path: Sub-path appended after ``API_BASE_PATH``.

        Returns:
            The base URL, the API base path and ``path`` joined together.
        """
        # API_BASE_PATH already starts with "/", so drop any trailing slash on
        # the configured base URL to avoid producing "host//api/...".
        base_url = self.sambastudio_embeddings_base_url.rstrip("/")
        return f"{base_url}{self.API_BASE_PATH}{path}"

    def _get_endpoint_url(self) -> str:
        """Return the prediction URL for the configured project and endpoint."""
        project_id = self.sambastudio_embeddings_project_id
        endpoint_id = self.sambastudio_embeddings_endpoint_id
        return self._get_full_url(f"{project_id}/{endpoint_id}")

    def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
        """Yield consecutive slices of ``texts`` holding at most ``batch_size`` items.

        Args:
            texts: List of strings to embed.
            batch_size: Maximum number of strings per batch.

        Yields:
            List[str]: The next batch; the last one may be shorter.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A negative step would make range() empty and silently drop every
        # text, so reject non-positive sizes explicitly.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        for start in range(0, len(texts), batch_size):
            yield texts[start : start + batch_size]

    def _embed_batch(
        self, session: requests.Session, url: str, batch: List[str]
    ) -> List[List[float]]:
        """POST one batch of texts to the endpoint and return its embeddings.

        Args:
            session: HTTP session used to send the request.
            url: Full prediction URL of the endpoint.
            batch: Texts sent under the ``"inputs"`` key of the JSON body.

        Returns:
            The value stored under the ``"data"`` key of the JSON response.

        Raises:
            requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        """
        response = session.post(
            url,
            headers={"key": self.sambastudio_embeddings_api_key},
            json={"inputs": batch},
        )
        # Fail with the HTTP status instead of an opaque KeyError/JSON error
        # when the endpoint rejects the request.
        response.raise_for_status()
        return response.json()["data"]

    def embed_documents(
        self, texts: List[str], batch_size: int = 32
    ) -> List[List[float]]:
        """Return one embedding per input text.

        The texts are sent to the endpoint in batches of ``batch_size`` and the
        per-batch results are concatenated in order.

        Args:
            texts: List of texts to encode.
            batch_size: Number of texts sent per request.

        Returns:
            List of embeddings, one per text, in the order of ``texts``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            requests.HTTPError: If a request returns a 4xx/5xx status.
        """
        url = self._get_endpoint_url()
        embeddings: List[List[float]] = []
        # The context manager closes the session (and its pooled connections)
        # even when a request fails part-way through.
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(self._embed_batch(http_session, url, batch))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a single text.

        Args:
            text: Text to encode.

        Returns:
            The first embedding returned by the endpoint for the one-element
            batch ``[text]``.

        Raises:
            requests.HTTPError: If the request returns a 4xx/5xx status.
        """
        url = self._get_endpoint_url()
        with requests.Session() as http_session:
            return self._embed_batch(http_session, url, [text])[0]