Feat add volcano embedding (#14693)

Description: Volcano Ark is an enterprise-grade large-model service
platform for developers, providing a full range of functions and
services such as model training, inference, evaluation, fine-tuning. You
can visit its homepage at https://www.volcengine.com/docs/82379/1099455
for details. This change could help developers use the platform for
embedding.
Issue: None
Dependencies: volcengine
Tag maintainer: @baskaryan
Twitter handle: @hinnnnnnnnnnnns

---------

Co-authored-by: lujingxuansc <lujingxuansc@bytedance.com>
This commit is contained in:
Hin 2024-01-02 06:37:35 +08:00 committed by GitHub
parent 81a7a83b21
commit 2cf1e73d12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 273 additions and 0 deletions

View File

@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# Volc Engine\n",
"\n",
"This notebook provides you with a guide on how to load the Volcano Embedding class.\n",
"\n",
"\n",
"## API Initialization\n",
"\n",
"To use the LLM services based on [VolcEngine](https://www.volcengine.com/docs/82379/1099455), you have to initialize these parameters:\n",
"\n",
"You could either choose to init the AK,SK in environment variables or init params:\n",
"\n",
"```base\n",
"export VOLC_ACCESSKEY=XXX\n",
"export VOLC_SECRETKEY=XXX\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.857798Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed_documents result:\n",
" [0.02929673343896866, -0.009310632012784481, -0.060323506593704224, 0.0031018739100545645, -0.002218986628577113, -0.0023125179577618837, -0.04864659160375595, -2.062115163425915e-05]\n",
" [0.01987231895327568, -0.026041055098176003, -0.08395249396562576, 0.020043574273586273, -0.028862033039331436, 0.004629664588719606, -0.023107370361685753, -0.0342753604054451]\n"
]
}
],
"source": [
"\"\"\"For basic init and call\"\"\"\n",
"import os\n",
"\n",
"from langchain_community.embeddings import VolcanoEmbeddings\n",
"\n",
"os.environ[\"VOLC_ACCESSKEY\"] = \"\"\n",
"os.environ[\"VOLC_SECRETKEY\"] = \"\"\n",
"\n",
"embed = VolcanoEmbeddings(volcano_ak=\"\", volcano_sk=\"\")\n",
"print(\"embed_documents result:\")\n",
"res1 = embed.embed_documents([\"foo\", \"bar\"])\n",
"for r in res1:\n",
" print(\"\", r[:8])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.859276Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed_query result:\n",
" [0.01987231895327568, -0.026041055098176003, -0.08395249396562576, 0.020043574273586273, -0.028862033039331436, 0.004629664588719606, -0.023107370361685753, -0.0342753604054451]\n"
]
}
],
"source": [
"print(\"embed_query result:\")\n",
"res2 = embed.embed_query(\"foo\")\n",
"print(\"\", r[:8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.860282Z"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"vscode": {
"interpreter": {
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -78,6 +78,7 @@ from langchain_community.embeddings.sentence_transformer import (
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings
from langchain_community.embeddings.vertexai import VertexAIEmbeddings
from langchain_community.embeddings.volcengine import VolcanoEmbeddings
from langchain_community.embeddings.voyageai import VoyageEmbeddings
from langchain_community.embeddings.xinference import XinferenceEmbeddings
@ -136,6 +137,7 @@ __all__ = [
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class VolcanoEmbeddings(BaseModel, Embeddings):
"""`Volcengine Embeddings` embedding models."""
volcano_ak: Optional[str] = None
"""volcano access key
learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""
volcano_sk: Optional[str] = None
"""volcano secret key
learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""
host: str = "maas-api.ml-platform-cn-beijing.volces.com"
"""host
learn more from https://www.volcengine.com/docs/82379/1174746"""
region: str = "cn-beijing"
"""region
learn more from https://www.volcengine.com/docs/82379/1174746"""
model: str = "bge-large-zh"
"""Model name
you could get from https://www.volcengine.com/docs/82379/1174746
for now, we support bge_large_zh
"""
version: str = "1.0"
""" model version """
chunk_size: int = 100
"""Chunk size when multiple texts are input"""
client: Any
"""volcano client"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
Validate whether volcano_ak and volcano_sk in the environment variables or
configuration file are available or not.
init volcano embedding client with `ak`, `sk`, `host`, `region`
Args:
values: a dictionary containing configuration information, must include the
fields of volcano_ak and volcano_sk
Returns:
a dictionary containing configuration information. If volcano_ak and
volcano_sk are not provided in the environment variables or configuration
file,the original values will be returned; otherwise, values containing
volcano_ak and volcano_sk will be returned.
Raises:
ValueError: volcengine package not found, please install it with
`pip install volcengine`
"""
values["volcano_ak"] = get_from_dict_or_env(
values,
"volcano_ak",
"VOLC_ACCESSKEY",
)
values["volcano_sk"] = get_from_dict_or_env(
values,
"volcano_sk",
"VOLC_SECRETKEY",
)
try:
from volcengine.maas import MaasService
client = MaasService(values["host"], values["region"])
client.set_ak(values["volcano_ak"])
client.set_sk(values["volcano_sk"])
values["client"] = client
except ImportError:
raise ImportError(
"volcengine package not found, please install it with "
"`pip install volcengine`"
)
return values
def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Embeds a list of text documents using the AutoVOT algorithm.
Args:
texts (List[str]): A list of text documents to embed.
Returns:
List[List[float]]: A list of embeddings for each document in the input list.
Each embedding is represented as a list of float values.
"""
text_in_chunks = [
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []
for chunk in text_in_chunks:
req = {
"model": {
"name": self.model,
"version": self.version,
},
"input": chunk,
}
try:
from volcengine.maas import MaasException
resp = self.client.embeddings(req)
lst.extend([res["embedding"] for res in resp["data"]])
except MaasException as e:
raise ValueError(f"embed by volcengine Error: {e}")
return lst

View File

@ -0,0 +1,19 @@
"""Test Bytedance Volcano Embedding."""
from langchain_community.embeddings import VolcanoEmbeddings
def test_embedding_documents() -> None:
"""Test embeddings for documents."""
documents = ["foo", "bar"]
embedding = VolcanoEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 1024
def test_embedding_query() -> None:
"""Test embeddings for query."""
document = "foo bar"
embedding = VolcanoEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1024

View File

@ -53,6 +53,7 @@ EXPECTED_ALL = [
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]