mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-21 19:31:43 +00:00
feat(model): AI/ML API integration (#2844)
This commit is contained in:
parent
8c14f1981a
commit
845432fea0
29
configs/dbgpt-proxy-aimlapi.toml
Normal file
29
configs/dbgpt-proxy-aimlapi.toml
Normal file
@ -0,0 +1,29 @@
|
||||
[system]
|
||||
language = "${env:DBGPT_LANG:-en}"
|
||||
api_keys = []
|
||||
encrypt_key = "your_secret_key"
|
||||
|
||||
[service.web]
|
||||
host = "0.0.0.0"
|
||||
port = 5670
|
||||
|
||||
[service.web.database]
|
||||
type = "sqlite"
|
||||
path = "pilot/meta_data/dbgpt.db"
|
||||
|
||||
[rag.storage]
|
||||
[rag.storage.vector]
|
||||
type = "chroma"
|
||||
persist_path = "pilot/data"
|
||||
|
||||
[models]
|
||||
[[models.llms]]
|
||||
name = "${env:LLM_MODEL_NAME:-gpt-4o}"
|
||||
provider = "proxy/aimlapi"
|
||||
api_key = "${env:AIMLAPI_API_KEY}"
|
||||
|
||||
[[models.embeddings]]
|
||||
name = "${env:EMBEDDING_MODEL_NAME:-text-embedding-3-small}"
|
||||
provider = "proxy/aimlapi"
|
||||
api_url = "https://api.aimlapi.com/v1/embeddings"
|
||||
api_key = "${env:AIMLAPI_API_KEY}"
|
@ -124,3 +124,6 @@ ENV PATH="${FINAL_VENV_NAME}/bin:$PATH" \
|
||||
VIRTUAL_ENV="${FINAL_VENV_NAME}"
|
||||
# Default command
|
||||
CMD ["dbgpt", "start", "webserver", "--config", "configs/dbgpt-proxy-siliconflow.toml"]
|
||||
|
||||
# Uncomment the following line to use the AI/ML API configuration
|
||||
# CMD ["dbgpt", "start", "webserver", "--config", "configs/dbgpt-proxy-aimlapi.toml"]
|
||||
|
@ -0,0 +1,34 @@
|
||||
---
|
||||
title: "AI/ML API Proxy LLM Configuration"
|
||||
description: "AI/ML API proxy LLM configuration."
|
||||
---
|
||||
|
||||
import { ConfigDetail } from "@site/src/components/mdx/ConfigDetail";
|
||||
|
||||
<ConfigDetail config={{
|
||||
"name": "Llm.v1.gpt-3.5-turbo-1106",
|
||||
"description": "OpenAI-compatible chat completion request schema.",
|
||||
"documentationUrl": "https://api.aimlapi.com/docs-public",
|
||||
"parameters": [
|
||||
{ "name": "model", "type": "string", "required": true, "description": "ID of the model to use." },
|
||||
{ "name": "messages", "type": "array", "required": true, "description": "List of messages comprising the conversation." },
|
||||
{ "name": "max_completion_tokens", "type": "integer", "required": false, "description": "Maximum number of tokens to generate for completion." },
|
||||
{ "name": "max_tokens", "type": "integer", "required": false, "description": "Alias for max_completion_tokens." },
|
||||
{ "name": "stream", "type": "boolean", "required": false, "description": "Whether to stream back partial progress." },
|
||||
{ "name": "stream_options", "type": "object", "required": false, "description": "Additional options to control streaming behavior." },
|
||||
{ "name": "tools", "type": "array", "required": false, "description": "List of tools (functions or APIs) the model may call." },
|
||||
{ "name": "tool_choice", "type": "object", "required": false, "description": "Which tool the model should call, if any." },
|
||||
{ "name": "parallel_tool_calls", "type": "boolean", "required": false, "description": "Whether tools can be called in parallel." },
|
||||
{ "name": "n", "type": "integer", "required": false, "description": "How many completions to generate for each prompt." },
|
||||
{ "name": "stop", "type": "array|string", "required": false, "description": "Sequences where the model will stop generating further tokens." },
|
||||
{ "name": "logprobs", "type": "boolean", "required": false, "description": "Whether to include log probabilities for tokens." },
|
||||
{ "name": "top_logprobs", "type": "integer", "required": false, "description": "Number of most likely tokens to return logprobs for." },
|
||||
{ "name": "logit_bias", "type": "object", "required": false, "description": "Modify likelihood of specified tokens appearing in the completion." },
|
||||
{ "name": "frequency_penalty", "type": "number", "required": false, "description": "How much to penalize new tokens based on frequency." },
|
||||
{ "name": "presence_penalty", "type": "number", "required": false, "description": "How much to penalize new tokens based on whether they appear in the text so far." },
|
||||
{ "name": "seed", "type": "integer", "required": false, "description": "Seed for sampling (for reproducibility)." },
|
||||
{ "name": "temperature", "type": "number", "required": false, "description": "Sampling temperature to use (higher = more random)." },
|
||||
{ "name": "top_p", "type": "number", "required": false, "description": "Nucleus sampling (top-p) cutoff value." },
|
||||
{ "name": "response_format", "type": "object|string", "required": false, "description": "Format to return the completion in, such as 'json' or 'text'." }
|
||||
]
|
||||
}} />
|
@ -82,6 +82,11 @@ import { ConfigClassTable } from '@site/src/components/mdx/ConfigClassTable';
|
||||
"description": "OpenAI Compatible Proxy LLM",
|
||||
"link": "./chatgpt_openaicompatibledeploymodelparameters_c3d426"
|
||||
},
|
||||
{
|
||||
"name": "AimlapiDeployModelParameters",
|
||||
"description": "AI/ML API proxy LLM configuration.",
|
||||
"link": "./aimlapi_aimlapideploymodelparameters_a1b2c3"
|
||||
},
|
||||
{
|
||||
"name": "SiliconFlowDeployModelParameters",
|
||||
"description": "SiliconFlow proxy LLM configuration.",
|
||||
|
@ -24,15 +24,20 @@ docker pull eosphorosai/dbgpt-openai:latest
|
||||
|
||||
2. Run the Docker container
|
||||
|
||||
This example requires you previde a valid API key for the SiliconFlow API. You can obtain one by signing up at [SiliconFlow](https://siliconflow.cn/) and creating an API key at [API Key](https://cloud.siliconflow.cn/account/ak).
|
||||
This example requires you provide a valid API key for the SiliconFlow API. You can obtain one by signing up at [SiliconFlow](https://siliconflow.cn/) and creating an API key at [API Key](https://cloud.siliconflow.cn/account/ak). Alternatively, set `AIMLAPI_API_KEY` to use the AI/ML API service.
|
||||
|
||||
|
||||
```bash
|
||||
docker run -it --rm -e SILICONFLOW_API_KEY=${SILICONFLOW_API_KEY} \
|
||||
-p 5670:5670 --name dbgpt eosphorosai/dbgpt-openai
|
||||
```
|
||||
Or with AI/ML API:
|
||||
```bash
|
||||
docker run -it --rm -e AIMLAPI_API_KEY=${AIMLAPI_API_KEY} \
|
||||
-p 5670:5670 --name dbgpt eosphorosai/dbgpt-openai
|
||||
```
|
||||
|
||||
Please replace `${SILICONFLOW_API_KEY}` with your own API key.
|
||||
Please replace `${SILICONFLOW_API_KEY}` or `${AIMLAPI_API_KEY}` with your own API key.
|
||||
|
||||
|
||||
Then you can visit [http://localhost:5670](http://localhost:5670) in the browser.
|
||||
|
@ -2,12 +2,17 @@
|
||||
|
||||
## Run via Docker-Compose
|
||||
|
||||
This example requires you previde a valid API key for the SiliconFlow API. You can obtain one by signing up at [SiliconFlow](https://siliconflow.cn/) and creating an API key at [API Key](https://cloud.siliconflow.cn/account/ak).
|
||||
This example requires you provide a valid API key for the SiliconFlow API. You can obtain one by signing up at [SiliconFlow](https://siliconflow.cn/) and creating an API key at [API Key](https://cloud.siliconflow.cn/account/ak).
|
||||
Alternatively you can use the [AI/ML API](https://aimlapi.com/) by setting `AIMLAPI_API_KEY`.
|
||||
|
||||
|
||||
```bash
|
||||
SILICONFLOW_API_KEY=${SILICONFLOW_API_KEY} docker compose up -d
|
||||
```
|
||||
Or use AI/ML API:
|
||||
```bash
|
||||
AIMLAPI_API_KEY=${AIMLAPI_API_KEY} docker compose up -d
|
||||
```
|
||||
|
||||
You will see the following output if the deployment is successful.
|
||||
```bash
|
||||
|
24
docs/docs/installation/integrations/aimlapi_llm_install.md
Normal file
24
docs/docs/installation/integrations/aimlapi_llm_install.md
Normal file
@ -0,0 +1,24 @@
|
||||
# AI/ML API
|
||||
|
||||
### AI/ML API provides 300+ AI models including Deepseek, Gemini, ChatGPT. The models run at enterprise-grade rate limits and uptimes.
|
||||
|
||||
### This section describes how to use the AI/ML API provider with DB-GPT.
|
||||
|
||||
1. Sign up at [AI/ML API](https://aimlapi.com/app/?utm_source=db_gpt&utm_medium=github&utm_campaign=integration) and generate an API key.
|
||||
2. Set the environment variable `AIMLAPI_API_KEY` with your key.
|
||||
3. Use the `configs/dbgpt-proxy-aimlapi.toml` configuration when starting DB-GPT.
|
||||
|
||||
### You can look up models at [https://aimlapi.com/models/](https://aimlapi.com/models/?utm_source=db_gpt&utm_medium=github&utm_campaign=integration)
|
||||
|
||||
### Or you can use docker/base/Dockerfile to run DB-GPT with AI/ML API:
|
||||
|
||||
```dockerfile
|
||||
# Expose the port for the web server, if you want to run it directly from the Dockerfile
|
||||
EXPOSE 5670
|
||||
|
||||
# Set the environment variable for the AIMLAPI API key
|
||||
ENV AIMLAPI_API_KEY="***"
|
||||
|
||||
# Just uncomment the following line in the `Dockerfile` to use AI/ML API:
|
||||
CMD ["dbgpt", "start", "webserver", "--config", "configs/dbgpt-proxy-aimlapi.toml"]
|
||||
```
|
@ -3,6 +3,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.model.proxy.llms.aimlapi import AimlapiLLMClient
|
||||
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
|
||||
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
|
||||
from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
|
||||
@ -24,6 +25,7 @@ def __lazy_import(name):
|
||||
"OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt",
|
||||
"ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
|
||||
"GeminiLLMClient": "dbgpt.model.proxy.llms.gemini",
|
||||
"AimlapiLLMClient": "dbgpt.model.proxy.llms.aimlapi",
|
||||
"SiliconFlowLLMClient": "dbgpt.model.proxy.llms.siliconflow",
|
||||
"SparkLLMClient": "dbgpt.model.proxy.llms.spark",
|
||||
"TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi",
|
||||
@ -55,6 +57,7 @@ __all__ = [
|
||||
"TongyiLLMClient",
|
||||
"ZhipuLLMClient",
|
||||
"WenxinLLMClient",
|
||||
"AimlapiLLMClient",
|
||||
"SiliconFlowLLMClient",
|
||||
"SparkLLMClient",
|
||||
"YiLLMClient",
|
||||
|
258
packages/dbgpt-core/src/dbgpt/model/proxy/llms/aimlapi.py
Normal file
258
packages/dbgpt-core/src/dbgpt/model/proxy/llms/aimlapi.py
Normal file
@ -0,0 +1,258 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
|
||||
|
||||
from dbgpt.core import ModelMetadata
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
ResourceCategory,
|
||||
auto_register_resource,
|
||||
)
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from ..base import (
|
||||
AsyncGenerateStreamFunction,
|
||||
GenerateStreamFunction,
|
||||
register_proxy_model_adapter,
|
||||
)
|
||||
from .chatgpt import OpenAICompatibleDeployModelParameters, OpenAILLMClient
|
||||
|
||||
AIMLAPI_HEADERS = {
|
||||
"HTTP-Referer": "https://github.com/eosphoros-ai/DB-GPT",
|
||||
"X-Title": "DB GPT",
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx._types import ProxiesTypes
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
|
||||
|
||||
|
||||
_AIMLAPI_DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
|
||||
@auto_register_resource(
|
||||
label=_("AI/ML API Proxy LLM"),
|
||||
category=ResourceCategory.LLM_CLIENT,
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
description=_("AI/ML API proxy LLM configuration."),
|
||||
documentation_url="https://api.aimlapi.com/v1/",
|
||||
show_in_ui=False,
|
||||
)
|
||||
@dataclass
|
||||
class AimlapiDeployModelParameters(OpenAICompatibleDeployModelParameters):
|
||||
"""Deploy model parameters for AI/ML API."""
|
||||
|
||||
provider: str = "proxy/aimlapi"
|
||||
|
||||
api_base: Optional[str] = field(
|
||||
default="${env:AIMLAPI_API_BASE:-https://api.aimlapi.com/v1}",
|
||||
metadata={"help": _("The base url of the AI/ML API.")},
|
||||
)
|
||||
|
||||
api_key: Optional[str] = field(
|
||||
default="${env:AIMLAPI_API_KEY}",
|
||||
metadata={"help": _("The API key of the AI/ML API."), "tags": "privacy"},
|
||||
)
|
||||
|
||||
|
||||
async def aimlapi_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
client: AimlapiLLMClient = model.proxy_llm_client
|
||||
request = parse_model_request(params, client.default_model, stream=True)
|
||||
async for r in client.generate_stream(request):
|
||||
yield r
|
||||
|
||||
|
||||
class AimlapiLLMClient(OpenAILLMClient):
|
||||
"""AI/ML API LLM Client using OpenAI-compatible endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_type: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
model: Optional[str] = _AIMLAPI_DEFAULT_MODEL,
|
||||
proxies: Optional["ProxiesTypes"] = None,
|
||||
timeout: Optional[int] = 240,
|
||||
model_alias: Optional[str] = _AIMLAPI_DEFAULT_MODEL,
|
||||
context_length: Optional[int] = None,
|
||||
openai_client: Optional["ClientType"] = None,
|
||||
openai_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
api_base = (
|
||||
api_base or os.getenv("AIMLAPI_API_BASE") or "https://api.aimlapi.com/v1"
|
||||
)
|
||||
api_key = api_key or os.getenv("AIMLAPI_API_KEY")
|
||||
model = model or _AIMLAPI_DEFAULT_MODEL
|
||||
if not context_length:
|
||||
if "200k" in model:
|
||||
context_length = 200 * 1024
|
||||
else:
|
||||
context_length = 4096
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"AI/ML API key is required, please set 'AIMLAPI_API_KEY' "
|
||||
"in environment or pass it as an argument."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_type=api_type,
|
||||
api_version=api_version,
|
||||
model=model,
|
||||
proxies=proxies,
|
||||
timeout=timeout,
|
||||
model_alias=model_alias,
|
||||
context_length=context_length,
|
||||
openai_client=openai_client,
|
||||
openai_kwargs=openai_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
self.client.default_headers.update(AIMLAPI_HEADERS)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@property
|
||||
def default_model(self) -> str:
|
||||
model = self._model
|
||||
if not model:
|
||||
model = _AIMLAPI_DEFAULT_MODEL
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def param_class(cls) -> Type[AimlapiDeployModelParameters]:
|
||||
return AimlapiDeployModelParameters
|
||||
|
||||
@classmethod
|
||||
def generate_stream_function(
|
||||
cls,
|
||||
) -> Optional[Union[GenerateStreamFunction, AsyncGenerateStreamFunction]]:
|
||||
return aimlapi_generate_stream
|
||||
|
||||
|
||||
register_proxy_model_adapter(
|
||||
AimlapiLLMClient,
|
||||
supported_models=[
|
||||
ModelMetadata(
|
||||
model=["openai/gpt-4"],
|
||||
context_length=8_000,
|
||||
max_output_length=4_096,
|
||||
description="OpenAI GPT‑4: state‑of‑the‑art language model",
|
||||
link="https://aimlapi.com/models/chat-gpt-4",
|
||||
function_calling=True,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["openai/gpt-4o", "gpt-4o-mini", "openai/gpt-4-turbo"],
|
||||
context_length=128_000,
|
||||
max_output_length=16_384,
|
||||
description="GPT‑4 family (4o, 4o‑mini, 4 Turbo) via AI/ML API",
|
||||
link="https://aimlapi.com/models#openai-gpt-4o",
|
||||
function_calling=True,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["gpt-3.5-turbo"],
|
||||
context_length=16_000,
|
||||
max_output_length=4_096,
|
||||
description="GPT‑3.5 Turbo: fast, high‑quality text generation",
|
||||
link="https://aimlapi.com/models/chat-gpt-3-5-turbo",
|
||||
function_calling=True,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=[
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"meta-llama/Llama-3.1-405B",
|
||||
"Qwen/Qwen2-235B",
|
||||
],
|
||||
context_length=32_000,
|
||||
max_output_length=8_192,
|
||||
description="Instruction‑tuned LLMs with 32k token context window",
|
||||
link="https://aimlapi.com/models",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=[
|
||||
"google/gemini-2-27b-it",
|
||||
"x-ai/grok-2-beta",
|
||||
"bytedance/seedream-3.0",
|
||||
],
|
||||
context_length=8_000,
|
||||
max_output_length=4_096,
|
||||
description="Models with 8k token context window, no function_calling",
|
||||
link="https://aimlapi.com/models",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["claude-3-5-sonnet-20240620"],
|
||||
context_length=8_192,
|
||||
max_output_length=2_048,
|
||||
description="Claude 3.5 Sonnet: advanced multimodal model from Anthropic",
|
||||
link="https://aimlapi.com/models/claude-3-5-sonnet",
|
||||
function_calling=True,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["deepseek-chat"],
|
||||
context_length=128_000,
|
||||
max_output_length=16_000,
|
||||
description="DeepSeek V3: efficient high‑performance LLM",
|
||||
link="https://aimlapi.com/models/deepseek-v3",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["mistralai/Mixtral-8x7B-Instruct-v0.1"],
|
||||
context_length=64_000,
|
||||
max_output_length=8_000,
|
||||
description="Mixtral‑8x7B: sparse mixture‑of‑experts instruction model",
|
||||
link="https://aimlapi.com/models/mixtral-8x7b-instruct-v01",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo"],
|
||||
context_length=131_000,
|
||||
max_output_length=16_000,
|
||||
description="Llama 3.2‑90B: advanced vision‑instruct turbo model",
|
||||
link="https://aimlapi.com/models/llama-3-2-90b-vision-instruct-turbo-api",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["google/gemini-2-0-flash"],
|
||||
context_length=1_000_000,
|
||||
max_output_length=32_768,
|
||||
description="Gemini 2.0 Flash: ultra‑low latency multimodal model",
|
||||
link="https://aimlapi.com/models/gemini-2-0-flash-api",
|
||||
function_calling=True,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["meta-llama/Meta-Llama-3-8B-Instruct-Lite"],
|
||||
context_length=9_000,
|
||||
max_output_length=1_024,
|
||||
description="Llama 3 8B Instruct Lite: compact dialogue model",
|
||||
link="https://aimlapi.com/models/llama-3-8b-instruct-lite-api",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["cohere/command-r-plus"],
|
||||
context_length=128_000,
|
||||
max_output_length=16_000,
|
||||
description="Cohere Command R+: enterprise‑grade chat model",
|
||||
link="https://aimlapi.com/models/command-r-api",
|
||||
function_calling=False,
|
||||
),
|
||||
ModelMetadata(
|
||||
model=["mistralai/codestral-2501"],
|
||||
context_length=256_000,
|
||||
max_output_length=32_000,
|
||||
description="Codestral‑2501: advanced code generation model",
|
||||
link="https://aimlapi.com/models/mistral-codestral-2501-api",
|
||||
function_calling=False,
|
||||
),
|
||||
],
|
||||
)
|
@ -1,5 +1,6 @@
|
||||
"""Module for embedding related classes and functions."""
|
||||
|
||||
from .aimlapi import AimlapiEmbeddings # noqa: F401
|
||||
from .jina import JinaEmbeddings # noqa: F401
|
||||
from .ollama import OllamaEmbeddings # noqa: F401
|
||||
from .qianfan import QianFanEmbeddings # noqa: F401
|
||||
@ -12,4 +13,5 @@ __ALL__ = [
|
||||
"QianFanEmbeddings",
|
||||
"TongYiEmbeddings",
|
||||
"SiliconFlowEmbeddings",
|
||||
"AimlapiEmbeddings",
|
||||
]
|
||||
|
173
packages/dbgpt-ext/src/dbgpt_ext/rag/embeddings/aimlapi.py
Normal file
173
packages/dbgpt-ext/src/dbgpt_ext/rag/embeddings/aimlapi.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""AI/ML API embeddings for RAG."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import EmbeddingModelMetadata, Embeddings
|
||||
from dbgpt.core.interface.parameter import EmbeddingDeployModelParameters
|
||||
from dbgpt.model.adapter.base import register_embedding_adapter
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
AIMLAPI_HEADERS = {
|
||||
"HTTP-Referer": "https://github.com/eosphoros-ai/DB-GPT",
|
||||
"X-Title": "DB GPT",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AimlapiEmbeddingDeployModelParameters(EmbeddingDeployModelParameters):
|
||||
"""AI/ML API Embeddings deploy model parameters."""
|
||||
|
||||
provider: str = "proxy/aimlapi"
|
||||
|
||||
api_key: Optional[str] = field(
|
||||
default="${env:AIMLAPI_API_KEY}",
|
||||
metadata={"help": _("The API key for the embeddings API.")},
|
||||
)
|
||||
backend: Optional[str] = field(
|
||||
default="text-embedding-3-small",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The real model name to pass to the provider, default is None. If "
|
||||
"backend is None, use name as the real model name."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def real_provider_model_name(self) -> str:
|
||||
return self.backend or self.name
|
||||
|
||||
|
||||
class AimlapiEmbeddings(BaseModel, Embeddings):
|
||||
"""The AI/ML API embeddings."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|
||||
api_key: Optional[str] = Field(
|
||||
default=None, description="The API key for the embeddings API."
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-3-small", description="The name of the model to use."
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the AI/ML API Embeddings."""
|
||||
super().__init__(**kwargs)
|
||||
self._api_key = self.api_key
|
||||
|
||||
@classmethod
|
||||
def param_class(cls) -> Type[AimlapiEmbeddingDeployModelParameters]:
|
||||
return AimlapiEmbeddingDeployModelParameters
|
||||
|
||||
@classmethod
|
||||
def from_parameters(
|
||||
cls, parameters: AimlapiEmbeddingDeployModelParameters
|
||||
) -> "Embeddings":
|
||||
return cls(
|
||||
api_key=parameters.api_key, model_name=parameters.real_provider_model_name
|
||||
)
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], max_batch_chunks_size: int = 25
|
||||
) -> List[List[float]]:
|
||||
"""Get the embeddings for a list of texts."""
|
||||
import requests
|
||||
|
||||
embeddings = []
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
headers.update(AIMLAPI_HEADERS)
|
||||
|
||||
for i in range(0, len(texts), max_batch_chunks_size):
|
||||
batch_texts = texts[i : i + max_batch_chunks_size]
|
||||
response = requests.post(
|
||||
url="https://api.aimlapi.com/v1/embeddings",
|
||||
json={"model": self.model_name, "input": batch_texts},
|
||||
headers=headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"Embedding failed: {response.text}")
|
||||
data = response.json()
|
||||
batch_embeddings = data["data"]
|
||||
sorted_embeddings = sorted(batch_embeddings, key=lambda e: e["index"])
|
||||
embeddings.extend([result["embedding"] for result in sorted_embeddings])
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
register_embedding_adapter(
|
||||
AimlapiEmbeddings,
|
||||
supported_models=[
|
||||
EmbeddingModelMetadata(
|
||||
model=["text-embedding-3-large", "text-embedding-ada-002"],
|
||||
dimension=1536,
|
||||
context_length=8000,
|
||||
description=_(
|
||||
"High‑performance embedding models with "
|
||||
"flexible dimensions and superior accuracy."
|
||||
),
|
||||
link="https://aimlapi.com/models",
|
||||
),
|
||||
EmbeddingModelMetadata(
|
||||
model=["BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5"],
|
||||
dimension=1536,
|
||||
context_length=None,
|
||||
description=_(
|
||||
"BAAI BGE models for precise and high‑performance language embeddings."
|
||||
),
|
||||
link="https://aimlapi.com/models",
|
||||
),
|
||||
EmbeddingModelMetadata(
|
||||
model=[
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval",
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
],
|
||||
dimension=1536,
|
||||
context_length=32000,
|
||||
description=_(
|
||||
"High‑capacity embedding models with 32k token "
|
||||
"context window for retrieval and specialized domains."
|
||||
),
|
||||
link="https://aimlapi.com/models",
|
||||
),
|
||||
EmbeddingModelMetadata(
|
||||
model=[
|
||||
"voyage-large-2-instruct",
|
||||
"voyage-law-2",
|
||||
"voyage-code-2",
|
||||
"voyage-large-2",
|
||||
],
|
||||
dimension=1536,
|
||||
context_length=16000,
|
||||
description=_(
|
||||
"Voyage embedding models with 16k token context window, "
|
||||
"optimized for general and instruction tasks."
|
||||
),
|
||||
link="https://aimlapi.com/models",
|
||||
),
|
||||
EmbeddingModelMetadata(
|
||||
model=["voyage-2"],
|
||||
dimension=1536,
|
||||
context_length=4000,
|
||||
description=_("Voyage 2: compact embeddings for smaller contexts."),
|
||||
link="https://aimlapi.com/models",
|
||||
),
|
||||
EmbeddingModelMetadata(
|
||||
model=[
|
||||
"textembedding-gecko@003",
|
||||
"textembedding-gecko-multilingual@001",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
dimension=1536,
|
||||
context_length=2000,
|
||||
description=_(
|
||||
"Gecko and multilingual embedding models with 2k token context window."
|
||||
),
|
||||
link="https://aimlapi.com/models",
|
||||
),
|
||||
],
|
||||
)
|
Loading…
Reference in New Issue
Block a user