feat(model): Add new LLMClient and new build tools (#967)

This commit is contained in:
Fangyin Cheng
2023-12-23 16:33:01 +08:00
committed by GitHub
parent 12234ae258
commit 0c46c339ca
30 changed files with 1072 additions and 133 deletions

Makefile (new file, 61 lines)
View File

@@ -0,0 +1,61 @@
.DEFAULT_GOAL := help
SHELL=/bin/bash
VENV = venv
# Detect the operating system and set the virtualenv bin directory
ifeq ($(OS),Windows_NT)
VENV_BIN=$(VENV)/Scripts
else
VENV_BIN=$(VENV)/bin
endif
setup: ## Set up the Python development environment
python3 -m venv $(VENV)
$(VENV_BIN)/pip install --upgrade pip
$(VENV_BIN)/pip install -r requirements/dev-requirements.txt
$(VENV_BIN)/pip install -r requirements/lint-requirements.txt
testenv: setup ## Set up the Python test environment
$(VENV_BIN)/pip install -e ".[simple_framework]"
.PHONY: fmt
fmt: setup ## Format Python code
# TODO: Use isort to sort Python imports.
# https://github.com/PyCQA/isort
# $(VENV_BIN)/isort .
# https://github.com/psf/black
$(VENV_BIN)/black .
# TODO: Use blackdoc to format Python doctests.
# https://blackdoc.readthedocs.io/en/latest/
# $(VENV_BIN)/blackdoc .
# TODO: Type checking of Python code.
# https://github.com/python/mypy
# $(VENV_BIN)/mypy dbgpt
# TODO: Use flake8 to enforce the Python style guide.
# https://flake8.pycqa.org/en/latest/
# $(VENV_BIN)/flake8 dbgpt
.PHONY: pre-commit
pre-commit: fmt test ## Run formatting and unit tests before committing
.PHONY: test
test: testenv ## Run unit tests
$(VENV_BIN)/pytest dbgpt
.PHONY: coverage
coverage: setup ## Run tests and report coverage
$(VENV_BIN)/pytest dbgpt --cov=dbgpt
.PHONY: clean
clean: ## Clean up the environment
rm -rf $(VENV)
find . -type f -name '*.pyc' -delete
find . -type d -name '__pycache__' -delete
find . -type d -name '.pytest_cache' -delete
find . -type d -name '.coverage' -delete
.PHONY: help
help: ## Display this help screen
@echo "Available commands:"
@grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort

View File

@@ -1,9 +1,12 @@
from dbgpt.core.interface.llm import (
ModelInferenceMetrics,
ModelRequest,
ModelOutput,
OpenAILLM,
BaseLLMOperator,
LLMClient,
LLMOperator,
StreamingLLMOperator,
RequestBuildOperator,
ModelMetadata,
)
from dbgpt.core.interface.message import (
ModelMessage,
@@ -37,11 +40,15 @@ from dbgpt.core.interface.storage import (
__ALL__ = [
"ModelInferenceMetrics",
"ModelRequest",
"ModelOutput",
"OpenAILLM",
"BaseLLMOperator",
"Operator",
"RequestBuildOperator",
"ModelMetadata",
"ModelMessage",
"LLMClient",
"LLMOperator",
"StreamingLLMOperator",
"ModelMessageRoleType",
"OnceConversation",
"StorageConversation",

View File

@@ -211,7 +211,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True
)
return out_ctx.current_task_context.task_output.output_stream
def _blocking_call_stream(

View File

@@ -130,8 +130,9 @@ async def _trigger_dag(
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
generator = await end_node.call_stream(call_data={"data": body})
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
generator,
headers=headers,
media_type=media_type,
)

View File

@@ -1,10 +1,10 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Any, Union, AsyncIterator
import time
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
import copy
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
@@ -12,6 +12,7 @@ from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
@dataclass
@PublicAPI(stability="beta")
class ModelInferenceMetrics:
"""A class to represent metrics for assessing the inference performance of a LLM."""
@@ -97,6 +98,7 @@ class ModelInferenceMetrics:
@dataclass
@PublicAPI(stability="beta")
class ModelOutput:
"""A class to represent the output of a LLM.""" ""
@@ -118,6 +120,7 @@ _ModelMessageType = Union[ModelMessage, Dict[str, Any]]
@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
model: str
"""The name of the model."""
@@ -142,7 +145,7 @@ class ModelRequest:
span_id: Optional[str] = None
"""The span id of the model inference."""
def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
new_request = copy.deepcopy(self)
new_request.messages = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), new_request.messages)
@@ -166,6 +169,110 @@ class ModelRequest:
**kwargs,
)
def to_openai_messages(self) -> List[Dict[str, Any]]:
"""Convert the messages to the format of OpenAI API.
This function will move the last user message to the end of the list.
Returns:
List[Dict[str, Any]]: The messages in the format of OpenAI API.
Examples:
.. code-block:: python
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot."),
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you?"),
]
openai_messages = ModelRequest.to_openai_messages(messages)
assert openai_messages == [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hi, I'm a robot."},
{"role": "user", "content": "Who are your"},
]
"""
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in self.messages
]
return ModelMessage.to_openai_messages(messages)
@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
"""A class to represent a LLM model."""
model: str = field(
metadata={"help": "Model name"},
)
context_length: Optional[int] = field(
default=4096,
metadata={"help": "Context length of model"},
)
chat_model: Optional[bool] = field(
default=True,
metadata={"help": "Whether the model is a chat model"},
)
is_function_calling_model: Optional[bool] = field(
default=False,
metadata={"help": "Whether the model is a function calling model"},
)
metadata: Optional[Dict[str, Any]] = field(
default_factory=dict,
metadata={"help": "Model metadata"},
)
@PublicAPI(stability="beta")
class LLMClient(ABC):
"""An abstract class for LLM client."""
@abstractmethod
async def generate(self, request: ModelRequest) -> ModelOutput:
"""Generate a response for a given model request.
Args:
request(ModelRequest): The model request.
Returns:
ModelOutput: The model output.
"""
@abstractmethod
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.
Args:
request(ModelRequest): The model request.
Returns:
AsyncIterator[ModelOutput]: The model output stream.
"""
@abstractmethod
async def models(self) -> List[ModelMetadata]:
"""Get all the models.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
@abstractmethod
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
Args:
model(str): The model name.
prompt(str): The prompt.
Returns:
int: The number of tokens.
"""
class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
def __init__(self, model: str, **kwargs):
@@ -176,85 +283,52 @@ class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
return ModelRequest._build(self._model, input_value)
class BaseLLMOperator(
MapOperator[ModelRequest, ModelOutput],
StreamifyAbsOperator[ModelRequest, ModelOutput],
ABC,
):
class BaseLLM:
"""The abstract operator for a LLM."""
def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client
@PublicAPI(stability="beta")
class OpenAILLM(BaseLLMOperator):
"""The operator for OpenAI LLM.
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
if not self._llm_client:
raise ValueError("llm_client is not set")
return self._llm_client
Examples:
.. code-block:: python
llm = OpenAILLM()
model_request = ModelRequest(model="gpt-3.5-turbo", messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")])
model_output = await llm.map(model_request)
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
This operator will generate a non-streaming response.
"""
def __int__(self):
try:
import openai
except ImportError as e:
raise ImportError("Please install openai package to use OpenAILLM") from e
import importlib.metadata as metadata
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)
if not metadata.version("openai") >= "1.0.0":
raise ImportError("Please upgrade openai package to version 1.0.0 or above")
async def map(self, request: ModelRequest) -> ModelOutput:
return await self.llm_client.generate(request)
async def _send_request(
self, model_request: ModelRequest, stream: Optional[bool] = False
):
import os
from openai import AsyncOpenAI
client = AsyncOpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
base_url=os.environ.get("OPENAI_API_BASE"),
)
messages = ModelMessage.to_openai_messages(model_request._get_messages())
payloads = {
"model": model_request.model,
"stream": stream,
}
if model_request.temperature is not None:
payloads["temperature"] = model_request.temperature
if model_request.max_new_tokens:
payloads["max_tokens"] = model_request.max_new_tokens
class StreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.
return await client.chat.completions.create(messages=messages, **payloads)
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
async def map(self, model_request: ModelRequest) -> ModelOutput:
try:
chat_completion = await self._send_request(model_request, stream=False)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
This operator will generate a streaming response.
"""
async def streamify(
self, model_request: ModelRequest
) -> AsyncIterator[ModelOutput]:
try:
chat_completion = await self._send_request(model_request, stream=True)
text = ""
for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
async for output in self.llm_client.generate_stream(request):
yield output
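
The hunk above replaces the OpenAI-specific operator with a backend-agnostic pair: any object that satisfies the LLMClient contract can be plugged into LLMOperator or StreamingLLMOperator. The sketch below is illustrative only and not part of this commit; the EchoLLMClient class and the "echo" model name are assumptions used to show the contract, while the imported names are the dbgpt.core exports added in this diff.

import asyncio
from typing import AsyncIterator, List

from dbgpt.core import (
    LLMClient,
    ModelMessage,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)


class EchoLLMClient(LLMClient):
    """Hypothetical client that echoes the last message back (illustration only)."""

    async def generate(self, request: ModelRequest) -> ModelOutput:
        last = request.messages[-1]
        text = last.content if isinstance(last, ModelMessage) else last["content"]
        return ModelOutput(text=text, error_code=0)

    async def generate_stream(
        self, request: ModelRequest
    ) -> AsyncIterator[ModelOutput]:
        # A real client would yield incremental chunks; here we yield once.
        yield await self.generate(request)

    async def models(self) -> List[ModelMetadata]:
        return [ModelMetadata(model="echo")]

    async def count_token(self, model: str, prompt: str) -> int:
        return len(prompt.split())


request = ModelRequest._build("echo", "hello")
print(asyncio.run(EchoLLMClient().generate(request)).text)

With such a client in hand, LLMOperator(EchoLLMClient()) can be dropped into an AWEL DAG wherever the removed OpenAILLM operator was used before.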

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import sqlparse
import regex as re
import pandas as pd
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from typing import Any, Iterable, List, Optional, Dict
@@ -383,6 +382,10 @@ class RDBMSDatabase(BaseConnect):
return self.get_simple_fields(table_name)
def run_to_df(self, command: str, fetch: str = "all"):
import pandas as pd
# Pandas has too many dependencies and the import time is too long
# TODO: Remove the dependency on pandas
result_lst = self.run(command, fetch)
colunms = result_lst[0]
values = result_lst[1:]

View File

@@ -1,13 +1,8 @@
import re
import sqlparse
import clickhouse_connect
from typing import List, Optional, Any, Iterable, Dict
from sqlalchemy import text
from urllib.parse import quote
from sqlalchemy.schema import CreateTable
from urllib.parse import quote_plus as urlquote
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from clickhouse_connect.driver import httputil
from dbgpt.storage.schema import DBType
from sqlalchemy import (
MetaData,
@@ -56,6 +51,11 @@ class ClickhouseConnect(RDBMSDatabase):
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
import clickhouse_connect
from clickhouse_connect.driver import httputil
# Lazy import
big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12)
client = clickhouse_connect.get_client(
host=host,

View File

@@ -37,7 +37,14 @@ class SQLiteConnect(RDBMSDatabase):
"""Get table indexes about specified table."""
cursor = self.session.execute(text(f"PRAGMA index_list({table_name})"))
indexes = cursor.fetchall()
return [(index[1], index[3]) for index in indexes]
result = []
for idx in indexes:
index_name = idx[1]
cursor = self.session.execute(text(f"PRAGMA index_info({index_name})"))
index_infos = cursor.fetchall()
column_names = [index_info[2] for index_info in index_infos]
result.append({"name": index_name, "column_names": column_names})
return result
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""

View File

@@ -47,7 +47,8 @@ def test_run_no_throw(db):
def test_get_indexes(db):
db.run("CREATE TABLE test (name TEXT);")
db.run("CREATE INDEX idx_name ON test(name);")
assert db.get_indexes("test") == [("idx_name", "c")]
indexes = db.get_indexes("test")
assert indexes == [{"name": "idx_name", "column_names": ["name"]}]
def test_get_indexes_empty(db):

View File

@@ -0,0 +1,4 @@
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"]

View File

@@ -30,6 +30,15 @@ class EmbeddingsRequest(BaseModel):
span_id: str = None
class CountTokenRequest(BaseModel):
model: str
prompt: str
class ModelMetadataRequest(BaseModel):
model: str
class WorkerApplyRequest(BaseModel):
model: str
apply_type: WorkerApplyType

View File

@@ -0,0 +1,40 @@
from typing import AsyncIterator, List
import asyncio
from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata
from dbgpt.model.parameter import WorkerType
from dbgpt.model.cluster.manager_base import WorkerManager
class DefaultLLMClient(LLMClient):
def __init__(self, worker_manager: WorkerManager):
self._worker_manager = worker_manager
async def generate(self, request: ModelRequest) -> ModelOutput:
return await self._worker_manager.generate(request.to_dict())
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
async for output in self._worker_manager.generate_stream(request.to_dict()):
yield output
async def models(self) -> List[ModelMetadata]:
instances = await self._worker_manager.get_all_model_instances(
WorkerType.LLM.value, healthy_only=True
)
query_metadata_task = []
for instance in instances:
worker_name, _ = WorkerType.parse_worker_key(instance.worker_key)
query_metadata_task.append(
self._worker_manager.get_model_metadata({"model": worker_name})
)
models: List[ModelMetadata] = await asyncio.gather(*query_metadata_task)
model_map = {}
for single_model in models:
model_map[single_model.model] = single_model
return [model_map[model_name] for model_name in sorted(model_map.keys())]
async def count_token(self, model: str, prompt: str) -> int:
return await self._worker_manager.count_token(
{"model": model, "prompt": prompt}
)
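
A hedged sketch of how DefaultLLMClient might be obtained inside a running DB-GPT process; the component lookup mirrors the AWEL example added later in this commit, and "vicuna-13b-v1.5" is only an illustrative model name.

from dbgpt.component import ComponentType, SystemApp
from dbgpt.core import ModelRequest
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory


async def ask(system_app: SystemApp, question: str):
    # Resolve the running WorkerManager from the system app and wrap it.
    worker_manager = system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    client = DefaultLLMClient(worker_manager)

    print(await client.models())  # metadata for every healthy LLM worker
    request = ModelRequest._build("vicuna-13b-v1.5", question)
    return await client.generate(request)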

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import WorkerSupportedModel, WorkerApplyOutput
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@@ -38,6 +38,11 @@ class WorkerRunData:
port = self.port
return f"model {model_name}@{model_type}({host}:{port})"
@property
def stopped(self):
"""Check if the worker is stopped""" ""
return self.stop_event.is_set()
class WorkerManager(ABC):
@abstractmethod
@@ -62,6 +67,20 @@ class WorkerManager(ABC):
) -> List[WorkerRunData]:
"""Asynchronous get model instances by worker type and model name"""
@abstractmethod
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Asynchronous get all model instances
Args:
worker_type (str): worker type
healthy_only (bool, optional): only return healthy instances. Defaults to True.
Returns:
List[WorkerRunData]: worker run data list
"""
@abstractmethod
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
@@ -112,6 +131,25 @@ class WorkerManager(ABC):
We must provide a synchronous version.
"""
@abstractmethod
async def count_token(self, params: Dict) -> int:
"""Count token of prompt
Args:
params (Dict): parameters, e.g. {"prompt": "hello", "model": "vicuna-13b-v1.5"}
Returns:
int: token count
"""
@abstractmethod
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata
Args:
params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}
"""
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
"""Worker apply"""

View File

@@ -3,7 +3,7 @@ import pytest_asyncio
from contextlib import contextmanager, asynccontextmanager
from typing import List, Iterator, Dict, Tuple
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.worker.manager import (
WorkerManager,
@@ -80,6 +80,14 @@ class MockModelWorker(ModelWorker):
output = out
return output
def count_token(self, prompt: str) -> int:
return len(prompt)
def get_model_metadata(self, params: Dict) -> ModelMetadata:
return ModelMetadata(
model=self.model_parameters.model_name,
)
def embeddings(self, params: Dict) -> List[List[float]]:
return self._embeddings

View File

@@ -8,7 +8,7 @@ import traceback
from dbgpt.configs.model_config import get_device
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.core import ModelOutput, ModelInferenceMetrics
from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker
@@ -118,6 +118,8 @@ class DefaultModelWorker(ModelWorker):
f"Parse model max length {model_max_length} from model {self.model_name}."
)
self.context_len = model_max_length
elif hasattr(model_params, "max_context_size"):
self.context_len = model_params.max_context_size
def stop(self) -> None:
if not self.model:
@@ -186,6 +188,22 @@ class DefaultModelWorker(ModelWorker):
output = out
return output
def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)
async def async_count_token(self, prompt: str) -> int:
# TODO: this does not work when the model is deployed with vLLM; we should run the transformers-based _try_to_count_token asynchronously
raise NotImplementedError
def get_model_metadata(self, params: Dict) -> ModelMetadata:
return ModelMetadata(
model=self.model_name,
context_length=self.context_len,
)
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
return self.get_model_metadata(params)
def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError
@@ -436,6 +454,25 @@ def _new_metrics_from_model_output(
return metrics
def _try_to_count_token(prompt: str, tokenizer) -> int:
"""Try to count token of prompt
Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
Returns:
int: token count, or -1 on error
TODO: More implementation
"""
try:
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")
return -1
def _try_import_torch():
global torch
global _torch_imported

View File

@@ -2,6 +2,7 @@ import logging
from typing import Dict, List, Type, Optional
from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
from dbgpt.model.loader import _get_model_real_path
from dbgpt.model.parameter import (
EmbeddingModelParameters,
@@ -89,6 +90,14 @@ class EmbeddingsModelWorker(ModelWorker):
"""Generate non stream result"""
raise NotImplementedError("Not supported generate for embeddings model")
def count_token(self, prompt: str) -> int:
raise NotImplementedError("Not supported count_token for embeddings model")
def get_model_metadata(self, params: Dict) -> ModelMetadata:
raise NotImplementedError(
"Not supported get_model_metadata for embeddings model"
)
def embeddings(self, params: Dict) -> List[List[float]]:
model = params.get("model")
logger.info(f"Receive embeddings request, model: {model}")

View File

@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import (
ModelInstance,
WorkerApplyOutput,
@@ -271,6 +271,18 @@ class LocalWorkerManager(WorkerManager):
) -> List[WorkerRunData]:
return self.sync_get_model_instances(worker_type, model_name, healthy_only)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances = list(itertools.chain(*self.workers.values()))
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.worker_key)
if wt != worker_type or (healthy_only and instance.stopped):
continue
result.append(instance)
return result
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -390,6 +402,43 @@ class LocalWorkerManager(WorkerManager):
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
return worker_run_data.worker.embeddings(params)
async def count_token(self, params: Dict) -> int:
"""Count token of prompt"""
with root_tracer.start_span(
"WorkerManager.count_token", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
prompt = params.get("prompt")
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_count_token(prompt)
else:
return await self.run_blocking_func(
worker_run_data.worker.count_token, prompt
)
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
with root_tracer.start_span(
"WorkerManager.get_model_metadata", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_get_model_metadata(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.get_model_metadata, params
)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
@@ -601,6 +650,13 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return await self.worker_manager.get_all_model_instances(
worker_type, healthy_only
)
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -635,6 +691,12 @@ class WorkerManagerAdapter(WorkerManager):
def sync_embeddings(self, params: Dict) -> List[List[float]]:
return self.worker_manager.sync_embeddings(params)
async def count_token(self, params: Dict) -> int:
return await self.worker_manager.count_token(params)
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
return await self.worker_manager.get_model_metadata(params)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
@@ -696,6 +758,24 @@ async def api_embeddings(request: EmbeddingsRequest):
return await worker_manager.embeddings(params)
@router.post("/worker/count_token")
async def api_count_token(request: CountTokenRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.count_token(params)
@router.post("/worker/model_metadata")
async def api_get_model_metadata(request: ModelMetadataRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.get_model_metadata(params)
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
return await worker_manager.worker_apply(request)
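
The two new routes accept the bodies defined by CountTokenRequest and ModelMetadataRequest. A hedged sketch of calling them with httpx; the base URL, port, and model name are deployment-specific assumptions.

import asyncio

import httpx

WORKER_API = "http://127.0.0.1:8000/api"  # assumption: adjust to your deployment


async def main():
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{WORKER_API}/worker/count_token",
            json={"model": "vicuna-13b-v1.5", "prompt": "hello"},
        )
        print("token count:", resp.json())

        resp = await client.post(
            f"{WORKER_API}/worker/model_metadata",
            json={"model": "vicuna-13b-v1.5"},
        )
        print("model metadata:", resp.json())


asyncio.run(main())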

View File

@@ -133,22 +133,29 @@ class RemoteWorkerManager(LocalWorkerManager):
self, model_name: str, instances: List[ModelInstance]
) -> List[WorkerRunData]:
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
wr = WorkerRunData(
host=ins.host,
port=ins.port,
worker_key=ins.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
for instance in instances:
worker_instances.append(
self._build_single_worker_instance(model_name, instance)
)
worker_instances.append(wr)
return worker_instances
def _build_single_worker_instance(self, model_name: str, instance: ModelInstance):
worker = RemoteModelWorker()
worker.load_worker(
model_name, model_name, host=instance.host, port=instance.port
)
wr = WorkerRunData(
host=instance.host,
port=instance.port,
worker_key=instance.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
)
return wr
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -158,6 +165,20 @@ class RemoteWorkerManager(LocalWorkerManager):
)
return self._build_worker_instances(model_name, instances)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances: List[
ModelInstance
] = await self.model_registry.get_all_model_instances(healthy_only=healthy_only)
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.model_name)
if wt != worker_type:
continue
result.append(self._build_single_worker_instance(name, instance))
return result
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:

View File

@@ -1,7 +1,7 @@
import json
from typing import Dict, Iterator, List
import logging
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker
@@ -90,6 +90,44 @@ class RemoteModelWorker(ModelWorker):
)
return ModelOutput(**response.json())
def count_token(self, prompt: str) -> int:
raise NotImplementedError
async def async_count_token(self, prompt: str) -> int:
import httpx
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/count_token"
logger.debug(f"Send async_count_token to url {url}, params: {prompt}")
response = await client.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=self.timeout,
)
return response.json()
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Asynchronously get model metadata"""
import httpx
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/model_metadata"
logger.debug(
f"Send async_get_model_metadata to url {url}, params: {params}"
)
response = await client.post(
url,
headers=self.headers,
json=params,
timeout=self.timeout,
)
return ModelMetadata(**response.json())
def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
raise NotImplementedError
def embeddings(self, params: Dict) -> List[List[float]]:
"""Get embeddings for input"""
import requests

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Type
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.parameter import ModelParameters, WorkerType
from dbgpt.util.parameter_utils import (
ParameterDescription,
@@ -92,6 +92,42 @@ class ModelWorker(ABC):
"""Asynchronously generate output (non-stream) based on provided parameters."""
raise NotImplementedError
@abstractmethod
def count_token(self, prompt: str) -> int:
"""Count token of prompt
Args:
prompt (str): prompt
Returns:
int: token count
"""
async def async_count_token(self, prompt: str) -> int:
"""Asynchronously count token of prompt
Args:
prompt (str): prompt
Returns:
int: token count
"""
raise NotImplementedError
@abstractmethod
def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata
Args:
params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}
"""
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Asynchronously get model metadata
Args:
params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}
"""
raise NotImplementedError
@abstractmethod
def embeddings(self, params: Dict) -> List[List[float]]:
"""

View File

@@ -70,7 +70,7 @@ def _initialize_openai_v1(params: ProxyModelParameters):
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
base_url = params.proxy_api_base or os.getenv(
"OPENAI_API_TYPE",
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.proxy_api_key or os.getenv(

View File


@@ -0,0 +1,282 @@
from __future__ import annotations
import os
import logging
from dataclasses import dataclass
import importlib.metadata as metadata
from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator
from dbgpt.core.interface.llm import ModelMetadata, LLMClient
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
if TYPE_CHECKING:
import httpx
from httpx._types import ProxiesTypes
from openai import AsyncAzureOpenAI
from openai import AsyncOpenAI
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
logger = logging.getLogger(__name__)
@dataclass
class OpenAIParameters:
"""A class to represent a LLM model."""
api_type: str = "open_ai"
api_base: Optional[str] = None
api_key: Optional[str] = None
api_version: Optional[str] = None
full_url: Optional[str] = None
proxies: Optional["ProxiesTypes"] = None
def _initialize_openai_v1(init_params: OpenAIParameters):
try:
from openai import OpenAI
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
) from exc
if not metadata.version("openai") >= "1.0.0":
raise ImportError("Please upgrade openai package to version 1.0.0 or above")
api_type: Optional[str] = init_params.api_type
api_base: Optional[str] = init_params.api_base
api_key: Optional[str] = init_params.api_key
api_version: Optional[str] = init_params.api_version
full_url: Optional[str] = init_params.full_url
api_type = api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
base_url = api_base or os.getenv(
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = api_version or os.getenv("OPENAI_API_VERSION")
if not base_url and full_url:
base_url = full_url.split("/chat/completions")[0]
if api_key is None:
raise ValueError("api_key is required, please set OPENAI_API_KEY environment")
if base_url is None:
raise ValueError("base_url is required, please set OPENAI_BASE_URL environment")
if base_url.endswith("/"):
base_url = base_url[:-1]
openai_params = {
"api_key": api_key,
"base_url": base_url,
}
return openai_params, api_type, api_version
def _build_openai_client(init_params: OpenAIParameters):
import httpx
openai_params, api_type, api_version = _initialize_openai_v1(init_params)
if api_type == "azure":
from openai import AsyncAzureOpenAI
return AsyncAzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params["base_url"],
http_client=httpx.AsyncClient(proxies=init_params.proxies),
)
else:
from openai import AsyncOpenAI
return AsyncOpenAI(
**openai_params, http_client=httpx.AsyncClient(proxies=init_params.proxies)
)
class OpenAILLMClient(LLMClient):
"""An implementation of LLMClient using OpenAI API.
In order to have as few dependencies as possible, we directly use the http API.
"""
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = "gpt-3.5-turbo",
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "chatgpt_proxyllm",
context_length: Optional[int] = 8192,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
):
self._init_params = OpenAIParameters(
api_type=api_type,
api_base=api_base,
api_key=api_key,
api_version=api_version,
proxies=proxies,
)
self._model = model
self._proxies = proxies
self._timeout = timeout
self._model_alias = model_alias
self._context_length = context_length
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}
@property
def client(self) -> ClientType:
if self._client is None:
self._client = _build_openai_client(init_params=self._init_params)
return self._client
def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"model": request.model or self._model, "stream": stream}
# Apply openai kwargs
for k, v in self._openai_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
return payload
async def generate(self, request: ModelRequest) -> ModelOutput:
messages = request.to_openai_messages()
payload = self._build_request(request)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
messages = request.to_openai_messages()
payload = self._build_request(request, stream=True)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = ""
async for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]
async def get_context_length(self) -> int:
"""Get the context length of the model.
Returns:
int: The context length.
# TODO: This is a temporary solution. We should have a better way to get the context length.
e.g. get the real context length from the OpenAI API.
"""
return self._context_length
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
TODO: Get the real number of tokens from the openai api or tiktoken package
"""
raise NotImplementedError()
async def _to_openai_stream(
model: str, output_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
"""Convert the output_iter to openai stream format.
Args:
model (str): The model name.
output_iter (AsyncIterator[ModelOutput]): The output iterator.
"""
import json
import shortuuid
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
id = f"chatcmpl-{shortuuid.random()}"
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_text = ""
finish_stream_events = []
async for model_output in output_iter:
model_output: ModelOutput = model_output
if model_output.error_code != 0:
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
decoded_unicode = model_output.text.replace("\ufffd", "")
delta_text = decoded_unicode[len(previous_text) :]
previous_text = (
decoded_unicode
if len(decoded_unicode) > len(previous_text)
else previous_text
)
if len(delta_text) == 0:
delta_text = None
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=delta_text),
finish_reason=model_output.finish_reason,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
if delta_text is None:
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"

View File

@@ -1,5 +1,6 @@
import logging
import traceback
from dbgpt.component import SystemApp
from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
@@ -67,8 +68,9 @@ class DBSummaryClient:
try:
self.db_summary_embedding(item["db_name"], item["db_type"])
except Exception as e:
message = traceback.format_exc()
logger.warn(
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, detail: {message}'
)
def init_db_profile(self, db_summary_client, dbname, embeddings):

View File

@@ -0,0 +1,149 @@
"""AWEL: Simple llm client example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello"
}'
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate_stream \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello",
"stream": true
}'
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello"
}'
"""
from typing import Dict, Any, AsyncIterator, Optional, Union, List
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.component import ComponentType
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, TransformStreamAbsOperator
from dbgpt.core import (
ModelMessage,
LLMClient,
LLMOperator,
StreamingLLMOperator,
ModelOutput,
ModelRequest,
)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
class TriggerReqBody(BaseModel):
messages: Union[str, List[Dict[str, str]]] = Field(
..., description="User input messages"
)
model: str = Field(..., description="Model name")
stream: Optional[bool] = Field(default=False, description="Whether return stream")
class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> ModelRequest:
messages = [ModelMessage.build_human_message(input_value.messages)]
await self.current_dag_context.save_to_share_data(
"request_model_name", input_value.model
)
return ModelRequest(
model=input_value.model,
messages=messages,
echo=False,
)
class LLMMixin:
@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self._llm_client = DefaultLLMClient(worker_manager)
return self._llm_client
class MyLLMOperator(LLMMixin, LLMOperator):
def __init__(self, llm_client: LLMClient = None, **kwargs):
super().__init__(llm_client, **kwargs)
class MyStreamingLLMOperator(LLMMixin, StreamingLLMOperator):
def __init__(self, llm_client: LLMClient = None, **kwargs):
super().__init__(llm_client, **kwargs)
class MyLLMStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
from dbgpt.model.utils.chatgpt_utils import _to_openai_stream
model = await self.current_dag_context.get_share_data("request_model_name")
async for output in _to_openai_stream(model, input_value):
yield output
class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]]):
def __init__(self, llm_client: LLMClient = None, **kwargs):
self._llm_client = llm_client
MapOperator.__init__(self, **kwargs)
async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
prompt_tokens = await self.llm_client.count_token(
input_value.model, input_value.messages
)
available_models = await self.llm_client.models()
return {
"prompt_tokens": prompt_tokens,
"available_models": available_models,
}
with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/generate", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
model_task = MyLLMOperator()
model_parse_task = MapOperator(lambda out: out.to_dict())
trigger >> request_handle_task >> model_task >> model_parse_task
with DAG("dbgpt_awel_simple_llm_client_generate_stream") as client_generate_stream_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/generate_stream",
methods="POST",
request_body=TriggerReqBody,
streaming_response=True,
)
request_handle_task = RequestHandleOperator()
model_task = MyStreamingLLMOperator()
openai_format_stream_task = MyLLMStreamingOperator()
trigger >> request_handle_task >> model_task >> openai_format_stream_task
with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/count_token",
methods="POST",
request_body=TriggerReqBody,
)
model_task = MyModelToolOperator()
trigger >> model_task

View File

@@ -1,13 +1,19 @@
import asyncio
from dbgpt.core.awel import DAG
from dbgpt.core import BaseOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
from dbgpt.core import (
BaseOutputParser,
RequestBuildOperator,
PromptTemplate,
LLMOperator,
)
from dbgpt.model import OpenAILLMClient
with DAG("simple_sdk_llm_example_dag") as dag:
prompt_task = PromptTemplate.from_template(
"Write a SQL of {dialect} to query all data of {table_name}."
)
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
llm_task = OpenAILLM()
llm_task = LLMOperator(OpenAILLMClient())
out_parse_task = BaseOutputParser()
prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
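
The rest of the example file is unchanged; invoking the DAG directly might look like the sketch below. The call_data wrapping with a "data" key follows the convention visible in the trigger code earlier in this commit, and the template variables match the prompt above; treat the exact call shape as an assumption.

output = asyncio.run(
    out_parse_task.call(call_data={"data": {"dialect": "sqlite", "table_name": "user"}})
)
print(output)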

View File

@@ -8,10 +8,16 @@ from dbgpt.core.awel import (
JoinOperator,
MapOperator,
)
from dbgpt.core import SQLOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
from dbgpt.core import (
SQLOutputParser,
LLMOperator,
RequestBuildOperator,
PromptTemplate,
)
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator
from dbgpt.model import OpenAILLMClient
def _create_temporary_connection():
@@ -115,7 +121,7 @@ with DAG("simple_sdk_llm_sql_example") as dag:
prompt_input_task = JoinOperator(combine_function=_join_func)
prompt_task = PromptTemplate.from_template(_sql_prompt())
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
llm_task = OpenAILLM()
llm_task = LLMOperator(OpenAILLMClient())
out_parse_task = SQLOutputParser()
sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
db_query_task = DatasourceOperator(connection=db_connection)

View File

@@ -10,7 +10,7 @@ pytest-mock
pytest-recording
pytesseract==0.3.10
aioresponses
# python code format, usage `black .`
black
# for git hooks
pre-commit
# Type checking
mypy==0.991

View File

@@ -0,0 +1,11 @@
# python code format, usage `black .`
black==22.8.0
blackdoc==0.3.7
flake8==5.0.4
flake8-bugbear==22.10.25
flake8-comprehensions==3.10.0
flake8-docstrings==1.6.0
flake8-simplify==0.19.3
flake8-tidy-imports==4.8.0
isort==5.10.1
pyupgrade==3.1.0

View File

@@ -364,23 +364,40 @@ def core_requires():
"prettytable",
"cachetools",
]
setup_spec.extras["framework"] = [
"coloredlogs",
# Used only by DB-GPT internally; we should find the smallest dependency set needed to run the core unit tests.
# The dependency "framework" is too large for now.
setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [
"pydantic<2,>=1",
"httpx",
"fastapi==0.98.0",
"shortuuid",
# Changed from the pinned version 2.0.22 to a version range, because other dependencies require >=1.4 (for example, pydoris requires <2)
"SQLAlchemy>=1.4,<3",
# for cache
"msgpack",
# for cache; TODO: pympler has not been updated for a long time, we need to find a replacement
"pympler",
"sqlparse==0.4.4",
"duckdb==0.8.1",
"duckdb-engine",
]
# TODO: remove fschat from simple_framework
if BUILD_FROM_SOURCE:
setup_spec.extras["simple_framework"].append(
f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
)
else:
setup_spec.extras["simple_framework"].append("fschat")
setup_spec.extras["framework"] = setup_spec.extras["simple_framework"] + [
"coloredlogs",
"seaborn",
# https://github.com/eosphoros-ai/DB-GPT/issues/551
"pandas==2.0.3",
"auto-gpt-plugin-template",
"gTTS==2.3.1",
"langchain>=0.0.286",
# Changed from the pinned version 2.0.22 to a version range, because other dependencies require >=1.4 (for example, pydoris requires <2)
"SQLAlchemy>=1.4,<3",
"fastapi==0.98.0",
"pymysql",
"duckdb==0.8.1",
"duckdb-engine",
"jsonschema",
# TODO move transformers to default
# "transformers>=4.31.0",
@@ -390,20 +407,10 @@ def core_requires():
"openpyxl==3.1.2",
"chardet==5.1.0",
"xlrd==2.0.1",
# for cache; TODO: pympler has not been updated for a long time, we need to find a replacement
"pympler",
"aiofiles",
# for cache
"msgpack",
# for agent
"GitPython",
]
if BUILD_FROM_SOURCE:
setup_spec.extras["framework"].append(
f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
)
else:
setup_spec.extras["framework"].append("fschat")
def knowledge_requires():