From 2cf1e73d12db0939ba41189d517f743e038e8f01 Mon Sep 17 00:00:00 2001 From: Hin <31956487+lujingxuansc@users.noreply.github.com> Date: Tue, 2 Jan 2024 06:37:35 +0800 Subject: [PATCH] Feat add volcano embedding (#14693) Description: Volcano Ark is an enterprise-grade large-model service platform for developers, providing a full range of functions and services such as model training, inference, evaluation, fine-tuning. You can visit its homepage at https://www.volcengine.com/docs/82379/1099455 for details. This change could help developers use the platform for embedding. Issue: None Dependencies: volcengine Tag maintainer: @baskaryan Twitter handle: @hinnnnnnnnnnnns --------- Co-authored-by: lujingxuansc --- .../text_embedding/volcengine.ipynb | 123 +++++++++++++++++ .../embeddings/__init__.py | 2 + .../embeddings/volcengine.py | 128 ++++++++++++++++++ .../embeddings/test_volcano.py | 19 +++ .../unit_tests/embeddings/test_imports.py | 1 + 5 files changed, 273 insertions(+) create mode 100644 docs/docs/integrations/text_embedding/volcengine.ipynb create mode 100644 libs/community/langchain_community/embeddings/volcengine.py create mode 100644 libs/community/tests/integration_tests/embeddings/test_volcano.py diff --git a/docs/docs/integrations/text_embedding/volcengine.ipynb b/docs/docs/integrations/text_embedding/volcengine.ipynb new file mode 100644 index 00000000000..c32bfb53aee --- /dev/null +++ b/docs/docs/integrations/text_embedding/volcengine.ipynb @@ -0,0 +1,123 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# Volc Engine\n", + "\n", + "This notebook provides you with a guide on how to load the Volcano Embedding class.\n", + "\n", + "\n", + "## API Initialization\n", + "\n", + "To use the LLM services based on [VolcEngine](https://www.volcengine.com/docs/82379/1099455), you have to initialize these parameters:\n", + "\n", + "You could either choose to init the AK,SK in environment variables or init 
params:\n", + "\n", + "```bash\n", + "export VOLC_ACCESSKEY=XXX\n", + "export VOLC_SECRETKEY=XXX\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "start_time": "2023-12-14T03:05:29.857798Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embed_documents result:\n", + " [0.02929673343896866, -0.009310632012784481, -0.060323506593704224, 0.0031018739100545645, -0.002218986628577113, -0.0023125179577618837, -0.04864659160375595, -2.062115163425915e-05]\n", + " [0.01987231895327568, -0.026041055098176003, -0.08395249396562576, 0.020043574273586273, -0.028862033039331436, 0.004629664588719606, -0.023107370361685753, -0.0342753604054451]\n" + ] + } + ], + "source": [ + "\"\"\"For basic init and call\"\"\"\n", + "import os\n", + "\n", + "from langchain_community.embeddings import VolcanoEmbeddings\n", + "\n", + "os.environ[\"VOLC_ACCESSKEY\"] = \"\"\n", + "os.environ[\"VOLC_SECRETKEY\"] = \"\"\n", + "\n", + "embed = VolcanoEmbeddings(volcano_ak=\"\", volcano_sk=\"\")\n", + "print(\"embed_documents result:\")\n", + "res1 = embed.embed_documents([\"foo\", \"bar\"])\n", + "for r in res1:\n", + " print(\"\", r[:8])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "start_time": "2023-12-14T03:05:29.859276Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embed_query result:\n", + " [0.01987231895327568, -0.026041055098176003, -0.08395249396562576, 0.020043574273586273, -0.028862033039331436, 0.004629664588719606, -0.023107370361685753, -0.0342753604054451]\n" + ] + } + ], + "source": [ + "print(\"embed_query result:\")\n", + "res2 = embed.embed_query(\"foo\")\n", + "print(\"\", r[:8])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "start_time": "2023-12-14T03:05:29.860282Z" + } + }, + "outputs": [], + "source": [] + } + ], + 
"metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "vscode": { + "interpreter": { + "hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index ce9cfc7aa0b..3ae7e8ac4f4 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -78,6 +78,7 @@ from langchain_community.embeddings.sentence_transformer import ( from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings from langchain_community.embeddings.vertexai import VertexAIEmbeddings +from langchain_community.embeddings.volcengine import VolcanoEmbeddings from langchain_community.embeddings.voyageai import VoyageEmbeddings from langchain_community.embeddings.xinference import XinferenceEmbeddings @@ -136,6 +137,7 @@ __all__ = [ "JohnSnowLabsEmbeddings", "VoyageEmbeddings", "BookendEmbeddings", + "VolcanoEmbeddings", ] diff --git a/libs/community/langchain_community/embeddings/volcengine.py b/libs/community/langchain_community/embeddings/volcengine.py new file mode 100644 index 00000000000..98ac729b968 --- /dev/null +++ b/libs/community/langchain_community/embeddings/volcengine.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, root_validator 
class VolcanoEmbeddings(BaseModel, Embeddings):
    """`Volcengine Embeddings` embedding models.

    Credentials are taken from the ``volcano_ak`` / ``volcano_sk`` fields or,
    when those are not passed, from the ``VOLC_ACCESSKEY`` /
    ``VOLC_SECRETKEY`` environment variables. The ``volcengine`` package must
    be installed.
    """

    volcano_ak: Optional[str] = None
    """volcano access key
    learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

    volcano_sk: Optional[str] = None
    """volcano secret key
    learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

    host: str = "maas-api.ml-platform-cn-beijing.volces.com"
    """host
    learn more from https://www.volcengine.com/docs/82379/1174746"""
    region: str = "cn-beijing"
    """region
    learn more from https://www.volcengine.com/docs/82379/1174746"""

    model: str = "bge-large-zh"
    """Model name
    you could get from https://www.volcengine.com/docs/82379/1174746
    the default is bge-large-zh
    """

    version: str = "1.0"
    """ model version """

    chunk_size: int = 100
    """Maximum number of texts sent to the service in a single request"""

    client: Any
    """volcano client, created by `validate_environment`"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access/secret keys and build the volcengine client.

        ``volcano_ak`` and ``volcano_sk`` are looked up in ``values`` first and
        then in the ``VOLC_ACCESSKEY`` / ``VOLC_SECRETKEY`` environment
        variables. A ``MaasService`` client is then created from ``host`` and
        ``region``, given both keys, and stored under ``values["client"]``.

        Args:
            values: the field values of the model being validated.

        Returns:
            The same dictionary with ``volcano_ak``, ``volcano_sk`` and
            ``client`` filled in.

        Raises:
            ImportError: the volcengine package is not installed; install it
                with `pip install volcengine`.

        Note:
            ``get_from_dict_or_env`` is expected to raise when a key is found
            neither in ``values`` nor in the environment -- confirm against
            ``langchain_core.utils``.
        """
        values["volcano_ak"] = get_from_dict_or_env(
            values,
            "volcano_ak",
            "VOLC_ACCESSKEY",
        )
        values["volcano_sk"] = get_from_dict_or_env(
            values,
            "volcano_sk",
            "VOLC_SECRETKEY",
        )

        # Only the import is guarded, so that an unrelated failure while
        # building the client is not misreported as "package not found".
        try:
            from volcengine.maas import MaasService
        except ImportError as e:
            raise ImportError(
                "volcengine package not found, please install it with "
                "`pip install volcengine`"
            ) from e

        client = MaasService(values["host"], values["region"])
        client.set_ak(values["volcano_ak"])
        client.set_sk(values["volcano_sk"])
        values["client"] = client
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: the text to embed.

        Returns:
            The embedding of ``text``, obtained by delegating to
            `embed_documents` with a one-element list.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts with the configured volcengine model.

        The texts are sent in batches of at most ``chunk_size`` items; the
        embeddings returned for each batch are concatenated in request order.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: the ``"embedding"`` entry of every item in each
            response's ``"data"`` list.

        Raises:
            ValueError: ``chunk_size`` is not positive, or the service call
                raised a ``MaasException`` (kept as ``__cause__``).
        """
        if self.chunk_size <= 0:
            # A negative step would otherwise yield no batches and silently
            # return an empty result for non-empty input.
            raise ValueError(
                f"chunk_size must be a positive integer, got {self.chunk_size}"
            )
        if not texts:
            return []

        # Imported once per call rather than once per batch; the package is
        # already known to be importable because the client was built from it.
        from volcengine.maas import MaasException

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            req = {
                "model": {
                    "name": self.model,
                    "version": self.version,
                },
                "input": texts[start : start + self.chunk_size],
            }
            try:
                resp = self.client.embeddings(req)
            except MaasException as e:
                raise ValueError(f"embed by volcengine Error: {e}") from e
            embeddings.extend(res["embedding"] for res in resp["data"])
        return embeddings