"""Embedding implementations."""
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
"Represent the question for retrieving supporting documents: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
"Represent this question for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"
@register_resource(
_("HuggingFace Embeddings"),
"huggingface_embeddings",
category=ResourceCategory.EMBEDDINGS,
description=_("HuggingFace sentence_transformers embedding models."),
parameters=[
Parameter.build_from(
_("Model Name"),
"model_name",
str,
optional=True,
default=DEFAULT_MODEL_NAME,
description=_("Model name to use."),
),
        # TODO: support more parameters
],
)
class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""HuggingFace sentence_transformers embedding models.
To use, you should have the ``sentence_transformers`` python package installed.
Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
master/libs/langchain/langchain/embeddings>`_.
Example:
.. code-block:: python
from dbgpt.rag.embedding import HuggingFaceEmbeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}
hf = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
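
            # Illustrative embed calls using the instance constructed above.
            texts = ["Hello, world!", "How are you?"]
            doc_vectors = hf.embed_documents(texts)
            query_vector = hf.embed_query("Hello, world!")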
"""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
cache_folder: Optional[str] = Field(None, description="Path of the cache folder.")
"""Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME
environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
multi_process: bool = False
"""Run encode() on multiple GPUs."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
try:
import sentence_transformers
except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc
kwargs["client"] = sentence_transformers.SentenceTransformer(
kwargs.get("model_name") or DEFAULT_MODEL_NAME,
cache_folder=kwargs.get("cache_folder"),
**(kwargs.get("model_kwargs") or {}),
)
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
import sentence_transformers
        texts = [x.replace("\n", " ") for x in texts]
if self.multi_process:
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(texts, pool)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
@register_resource(
_("HuggingFace Instructor Embeddings"),
"huggingface_instructor_embeddings",
category=ResourceCategory.EMBEDDINGS,
description=_("HuggingFace Instructor embeddings."),
parameters=[
Parameter.build_from(
_("Model Name"),
"model_name",
str,
optional=True,
default=DEFAULT_INSTRUCT_MODEL,
description=_("Model name to use."),
),
Parameter.build_from(
_("Embed Instruction"),
"embed_instruction",
str,
optional=True,
default=DEFAULT_EMBED_INSTRUCTION,
description=_("Instruction to use for embedding documents."),
),
Parameter.build_from(
_("Query Instruction"),
"query_instruction",
str,
optional=True,
default=DEFAULT_QUERY_INSTRUCTION,
description=_("Instruction to use for embedding query."),
),
],
)
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
"""Wrapper around sentence_transformers embedding models.
To use, you should have the ``sentence_transformers``
and ``InstructorEmbedding`` python packages installed.
Example:
.. code-block:: python
            from dbgpt.rag.embedding import HuggingFaceInstructEmbeddings
model_name = "hkunlp/instructor-large"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
hf = HuggingFaceInstructEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
"""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = DEFAULT_INSTRUCT_MODEL
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME
environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
"""Instruction to use for embedding documents."""
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
"""Instruction to use for embedding query."""
    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        try:
            from InstructorEmbedding import INSTRUCTOR

            kwargs["client"] = INSTRUCTOR(
                kwargs.get("model_name") or DEFAULT_INSTRUCT_MODEL,
                cache_folder=kwargs.get("cache_folder"),
                **(kwargs.get("model_kwargs") or {}),
            )
        except ImportError as e:
            raise ImportError("Dependencies for InstructorEmbedding not found.") from e
        super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
instruction_pairs = [[self.embed_instruction, text] for text in texts]
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace instruct model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
instruction_pair = [self.query_instruction, text]
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
return embedding.tolist()
# TODO: Support AWEL flow
class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
"""HuggingFace BGE sentence_transformers embedding models.
To use, you should have the ``sentence_transformers`` python package installed.
    Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
master/libs/langchain/langchain/embeddings>`_.
Example:
.. code-block:: python
            from dbgpt.rag.embedding import HuggingFaceBgeEmbeddings
model_name = "BAAI/bge-large-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
hf = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
"""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = DEFAULT_BGE_MODEL
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
    Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
"""Instruction to use for embedding query."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
try:
import sentence_transformers
except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence_transformers`."
) from exc
kwargs["client"] = sentence_transformers.SentenceTransformer(
kwargs.get("model_name"),
cache_folder=kwargs.get("cache_folder"),
**(kwargs.get("model_kwargs") or {}),
)
super().__init__(**kwargs)
if "-zh" in self.model_name:
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
texts = [t.replace("\n", " ") for t in texts]
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
text = text.replace("\n", " ")
embedding = self.client.encode(
self.query_instruction + text, **self.encode_kwargs
)
return embedding.tolist()
@register_resource(
_("HuggingFace Inference API Embeddings"),
"huggingface_inference_api_embeddings",
category=ResourceCategory.EMBEDDINGS,
description=_("HuggingFace Inference API embeddings."),
parameters=[
Parameter.build_from(
_("API Key"),
"api_key",
str,
description=_("Your API key for the HuggingFace Inference API."),
),
Parameter.build_from(
_("Model Name"),
"model_name",
str,
optional=True,
default="sentence-transformers/all-MiniLM-L6-v2",
description=_("The name of the model to use for text embeddings."),
),
],
)
class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
"""Embed texts using the HuggingFace API.
Requires a HuggingFace Inference API key and a model name.
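
    Example:
        .. code-block:: python

            # A minimal usage sketch; "your_api_key" is a placeholder for a
            # real HuggingFace Inference API token.
            from dbgpt.rag.embedding import HuggingFaceInferenceAPIEmbeddings

            hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
                api_key="your_api_key",
                model_name="sentence-transformers/all-MiniLM-L6-v2",
            )
            vectors = hf_embeddings.embed_documents(["Hello, world!"])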
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_key: str
"""Your API key for the HuggingFace Inference API."""
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
"""The name of the model to use for text embeddings."""
@property
def _api_url(self) -> str:
return (
"https://api-inference.huggingface.co"
"/pipeline"
"/feature-extraction"
f"/{self.model_name}"
)
@property
def _headers(self) -> dict:
return {"Authorization": f"Bearer {self.api_key}"}
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
Example:
.. code-block:: python
                from dbgpt.rag.embedding import HuggingFaceInferenceAPIEmbeddings

                hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
                    api_key="your_api_key",
                    model_name="sentence-transformers/all-MiniLM-L6-v2",
                )
texts = ["Hello, world!", "How are you?"]
hf_embeddings.embed_documents(texts)
"""
response = requests.post(
self._api_url,
headers=self._headers,
json={
"inputs": texts,
"options": {"wait_for_model": True, "use_cache": True},
},
        )
        response.raise_for_status()
        return response.json()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
def _handle_request_result(res: requests.Response) -> List[List[float]]:
"""Parse the result from a request.
Args:
res: The response from the request.
Returns:
List[List[float]]: The embeddings.
Raises:
RuntimeError: If the response is not successful.
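
    The expected payload is an OpenAI-compatible embeddings response, e.g.::

        {"data": [{"index": 0, "embedding": [0.1, 0.2]}, ...]}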
"""
res.raise_for_status()
resp = res.json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
embeddings = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]
@register_resource(
_("Jina AI Embeddings"),
"jina_embeddings",
category=ResourceCategory.EMBEDDINGS,
description=_("Jina AI embeddings."),
parameters=[
Parameter.build_from(
_("API Key"),
"api_key",
str,
description=_("Your API key for the Jina AI API."),
),
Parameter.build_from(
_("Model Name"),
"model_name",
str,
optional=True,
default="jina-embeddings-v2-base-en",
description=_("The name of the model to use for text embeddings."),
),
],
)
class JinaEmbeddings(BaseModel, Embeddings):
"""Jina AI embeddings.
This class is used to get embeddings for a list of texts using the Jina AI API.
It requires an API key and a model name. The default model name is
"jina-embeddings-v2-base-en".
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: Any #: :meta private:
session: Any #: :meta private:
api_key: str
"""API key for the Jina AI API.."""
model_name: str = "jina-embeddings-v2-base-en"
"""The name of the model to use for text embeddings. Defaults to
"jina-embeddings-v2-base-en"."""
def __init__(self, **kwargs):
"""Create a new JinaEmbeddings instance."""
try:
import requests
except ImportError:
raise ValueError(
"The requests python package is not installed. Please install it with "
"`pip install requests`"
)
if "api_url" not in kwargs:
kwargs["api_url"] = "https://api.jina.ai/v1/embeddings"
if "session" not in kwargs: # noqa: SIM401
session = requests.Session()
else:
session = kwargs["session"]
api_key = kwargs.get("api_key")
if api_key:
session.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Accept-Encoding": "identity",
}
)
kwargs["session"] = session
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
# Call Jina AI Embedding API
resp = self.session.post( # type: ignore
self.api_url, json={"input": texts, "model": self.model_name}
)
return _handle_request_result(resp)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a Jina AI embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
@register_resource(
_("OpenAPI Embeddings"),
"openapi_embeddings",
category=ResourceCategory.EMBEDDINGS,
description=_("OpenAPI embeddings."),
parameters=[
Parameter.build_from(
_("API URL"),
"api_url",
str,
optional=True,
default="http://localhost:8100/api/v1/embeddings",
description=_("The URL of the embeddings API."),
),
Parameter.build_from(
_("API Key"),
"api_key",
str,
optional=True,
default=None,
description=_("Your API key for the Open API."),
),
Parameter.build_from(
_("Model Name"),
"model_name",
str,
optional=True,
default="text2vec",
description=_("The name of the model to use for text embeddings."),
),
Parameter.build_from(
_("Timeout"),
"timeout",
int,
optional=True,
default=60,
description=_("The timeout for the request in seconds."),
),
],
)
class OpenAPIEmbeddings(BaseModel, Embeddings):
"""The OpenAPI embeddings.
This class is used to get embeddings for a list of texts using the API.
This API is compatible with the OpenAI Embedding API.
Examples:
Using OpenAI's API:
.. code-block:: python
from dbgpt.rag.embedding import OpenAPIEmbeddings
openai_embeddings = OpenAPIEmbeddings(
api_url="https://api.openai.com/v1/embeddings",
api_key="your_api_key",
model_name="text-embedding-3-small",
)
texts = ["Hello, world!", "How are you?"]
openai_embeddings.embed_documents(texts)
Using DB-GPT APIServer's embedding API:
To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
to the `Cluster Deploy
<https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.
A simple example:
        1. Deploy Model Cluster with the following command:
.. code-block:: bash
dbgpt start controller --port 8000
        2. Deploy Embedding Model Worker with the following command:
.. code-block:: bash
dbgpt start worker --model_name text2vec \
--model_path /app/models/text2vec-large-chinese \
--worker_type text2vec \
--port 8003 \
--controller_addr http://127.0.0.1:8000
        3. Deploy API Server with the following command:
.. code-block:: bash
dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
--api_keys my_api_token --port 8100
Now, you can use the API server to get embeddings:
.. code-block:: python
from dbgpt.rag.embedding import OpenAPIEmbeddings
openai_embeddings = OpenAPIEmbeddings(
api_url="http://localhost:8100/api/v1/embeddings",
api_key="my_api_token",
model_name="text2vec",
)
texts = ["Hello, world!", "How are you?"]
openai_embeddings.embed_documents(texts)
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: str = Field(
default="http://localhost:8100/api/v1/embeddings",
description="The URL of the embeddings API.",
)
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
model_name: str = Field(
default="text2vec", description="The name of the model to use."
)
timeout: int = Field(
default=60, description="The timeout for the request in seconds."
)
pass_trace_id: bool = Field(
default=True, description="Whether to pass the trace ID to the API."
)
session: Optional[requests.Session] = None
def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
try:
import requests
except ImportError:
raise ValueError(
"The requests python package is not installed. "
"Please install it with `pip install requests`"
)
if "session" not in kwargs: # noqa: SIM401
session = requests.Session()
else:
session = kwargs["session"]
api_key = kwargs.get("api_key")
if api_key:
session.headers.update({"Authorization": f"Bearer {api_key}"})
kwargs["session"] = session
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
# Call OpenAI Embedding API
headers = {}
current_span_id = root_tracer.get_current_span_id()
if self.pass_trace_id and current_span_id:
# Set the trace ID if available
headers[DBGPT_TRACER_SPAN_ID] = current_span_id
res = self.session.post( # type: ignore
self.api_url,
json={"input": texts, "model": self.model_name},
timeout=self.timeout,
headers=headers,
)
return _handle_request_result(res)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.
Args:
texts: A list of texts to get embeddings for.
Returns:
List[List[float]]: Embedded texts as List[List[float]], where each inner
List[float] corresponds to a single input text.
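
        Example:
            .. code-block:: python

                # A minimal async sketch; it assumes an OpenAI-compatible
                # embeddings service is reachable at ``api_url``.
                import asyncio

                embeddings = OpenAPIEmbeddings(
                    api_url="http://localhost:8100/api/v1/embeddings",
                    api_key="my_api_token",
                    model_name="text2vec",
                )
                vectors = asyncio.run(
                    embeddings.aembed_documents(["Hello, world!"])
                )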
"""
headers = {"Authorization": f"Bearer {self.api_key}"}
current_span_id = root_tracer.get_current_span_id()
if self.pass_trace_id and current_span_id:
# Set the trace ID if available
headers[DBGPT_TRACER_SPAN_ID] = current_span_id
async with aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
async with session.post(
self.api_url, json={"input": texts, "model": self.model_name}
) as resp:
resp.raise_for_status()
data = await resp.json()
if "data" not in data:
raise RuntimeError(data["detail"])
embeddings = data["data"]
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
return [result["embedding"] for result in sorted_embeddings]
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
embeddings = await self.aembed_documents([text])
return embeddings[0]
class OllamaEmbeddings(BaseModel, Embeddings):
"""Ollama proxy embeddings.
This class is used to get embeddings for a list of texts using the Ollama API.
    It requires a proxy server URL `api_url` and a model name `model_name`.
The default model name is "llama2".
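
    Example:
        .. code-block:: python

            # A minimal usage sketch; it assumes an Ollama server is running
            # locally on the default port with the ``llama2`` model pulled.
            from dbgpt.rag.embedding import OllamaEmbeddings

            ollama_embeddings = OllamaEmbeddings(
                api_url="http://localhost:11434", model_name="llama2"
            )
            vectors = ollama_embeddings.embed_documents(["Hello, world!"])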
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: str = Field(
default="http://localhost:11434",
description="The URL of the embeddings API.",
)
model_name: str = Field(
default="llama2", description="The name of the model to use."
)
def __init__(self, **kwargs):
"""Initialize the OllamaEmbeddings."""
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
try:
import ollama
from ollama import Client
        except ImportError as e:
            raise ValueError(
                "The ollama python package is not installed. "
                "Please install it with `pip install ollama`"
            ) from e
try:
return (
Client(self.api_url).embeddings(model=self.model_name, prompt=text)
)["embedding"]
except ollama.ResponseError as e:
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.
Args:
texts: A list of texts to get embeddings for.
Returns:
List[List[float]]: Embedded texts as List[List[float]], where each inner
List[float] corresponds to a single input text.
"""
embeddings = []
for text in texts:
embedding = await self.aembed_query(text)
embeddings.append(embedding)
return embeddings
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
try:
import ollama
from ollama import AsyncClient
except ImportError:
raise ValueError(
"The ollama python package is not installed. "
"Please install it with `pip install ollama`"
)
try:
embedding = await AsyncClient(host=self.api_url).embeddings(
model=self.model_name, prompt=text
)
return embedding["embedding"]
except ollama.ResponseError as e:
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")
class TongYiEmbeddings(BaseModel, Embeddings):
    """The TongYi embeddings.

    Example:
        .. code-block:: python

            from http import HTTPStatus

            import dashscope
            from dashscope import TextEmbedding

            dashscope.api_key = ""

            def embed_with_list_of_str():
                resp = TextEmbedding.call(
                    model=TextEmbedding.Models.text_embedding_v1,
                    # Supports at most 10 texts per call, each up to 2048 tokens.
                    input=["风急天高猿啸哀", "渚清沙白鸟飞回", "无边落木萧萧下", "不尽长江滚滚来"],
                )
                if resp.status_code == HTTPStatus.OK:
                    print(resp)
                else:
                    print(resp)

            if __name__ == "__main__":
                embed_with_list_of_str()
    """
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
model_name: str = Field(
default="text-embedding-v1", description="The name of the model to use."
)
def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
try:
import dashscope # type: ignore
except ImportError as exc:
            raise ValueError(
                "Could not import the dashscope python package. "
                "Please install it with `pip install dashscope`."
            ) from exc
dashscope.TextEmbedding.api_key = kwargs.get("api_key")
super().__init__(**kwargs)
self._api_key = kwargs.get("api_key")
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
from dashscope import TextEmbedding
        # The API supports at most 10 texts per call, each up to 2048 tokens.
resp = TextEmbedding.call(
model=self.model_name, input=texts, api_key=self._api_key
)
if "output" not in resp:
raise RuntimeError(resp["message"])
embeddings = resp["output"]["embeddings"]
sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"])
return [result["embedding"] for result in sorted_embeddings]
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
class QianFanEmbeddings(BaseModel, Embeddings):
"""Baidu Qianfan Embeddings embedding models.
Embed:
.. code-block:: python
# embed the documents
vectors = embeddings.embed_documents([text1, text2, ...])
# embed the query
vectors = embeddings.embed_query(text)
""" # noqa: E501
client: Any
chunk_size: int = 16
endpoint: str = ""
"""Endpoint of the Qianfan Embedding, required if custom model used."""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
api_secret: Optional[str] = Field(
default=None, description="The Secret key for the embeddings API."
)
"""Model name
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
for now, we support Embedding-V1 and
- Embedding-V1 (默认模型)
- bge-large-en
- bge-large-zh
preset models are mapping to an endpoint.
`model` will be ignored if `endpoint` is set
"""
model_name: str = Field(
default="text-embedding-v1", description="The name of the model to use."
)
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""extra params for model invoke using with `do`."""
def __init__(self, **kwargs):
"""Initialize the QianFanEmbeddings."""
try:
import qianfan
except ImportError as exc:
raise ValueError(
"Could not import python package: qianfan. "
"Please install qianfan by running `pip install qianfan`."
) from exc
qianfan_ak = kwargs.get("api_key")
qianfan_sk = kwargs.get("api_secret")
model_name = kwargs.get("model_name")
if not qianfan_ak or not qianfan_sk or not model_name:
raise ValueError(
"API key, API secret, and model name are required to initialize "
"QianFanEmbeddings."
)
params = {
"model": model_name,
"ak": qianfan_ak,
"sk": qianfan_sk,
}
# Initialize the qianfan.Embedding client
kwargs["client"] = qianfan.Embedding(**params)
super().__init__(**kwargs)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a QianFan embedding model."""
resp = self.embed_documents([text])
return resp[0]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
        Embeds a list of text documents using the Qianfan embedding API.
Args:
texts (List[str]): A list of text documents to embed.
Returns:
List[List[float]]: A list of embeddings for each document in the input list.
Each embedding is represented as a list of float values.
"""
text_in_chunks = [
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []
for chunk in text_in_chunks:
resp = self.client.do(texts=chunk, **self.model_kwargs)
lst.extend([res["embedding"] for res in resp["data"]])
return lst