feat(model): Proxy model support count token (#996)

This commit is contained in:
Fangyin Cheng
2023-12-29 12:01:31 +08:00
committed by GitHub
parent ba0599ebf4
commit 0cdc77abb2
16 changed files with 366 additions and 248 deletions

View File

@@ -189,7 +189,7 @@ class DefaultModelWorker(ModelWorker):
return output return output
def count_token(self, prompt: str) -> int: def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer) return _try_to_count_token(prompt, self.tokenizer, self.model)
async def async_count_token(self, prompt: str) -> int: async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async # TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
@@ -454,12 +454,13 @@ def _new_metrics_from_model_output(
return metrics return metrics
def _try_to_count_token(prompt: str, tokenizer) -> int: def _try_to_count_token(prompt: str, tokenizer, model) -> int:
"""Try to count token of prompt """Try to count token of prompt
Args: Args:
prompt (str): prompt prompt (str): prompt
tokenizer ([type]): tokenizer tokenizer ([type]): tokenizer
model ([type]): model
Returns: Returns:
int: token count, if error return -1 int: token count, if error return -1
@@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int:
TODO: More implementation TODO: More implementation
""" """
try: try:
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
if isinstance(model, ProxyModel):
return model.count_token(prompt)
# Only support huggingface model now
return len(tokenizer(prompt).input_ids[0]) return len(tokenizer(prompt).input_ids[0])
except Exception as e: except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1") logger.warning(f"Count token error, detail: {e}, return -1")

View File

@@ -197,7 +197,7 @@ class LocalWorkerManager(WorkerManager):
return True return True
else: else:
# TODO Update worker # TODO Update worker
logger.warn(f"Instance {worker_key} exist") logger.warning(f"Instance {worker_key} exist")
return False return False
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None: def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
@@ -229,7 +229,7 @@ class LocalWorkerManager(WorkerManager):
) )
if not success: if not success:
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist" msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
logger.warn(f"{msg}, worker_params: {worker_params}") logger.warning(f"{msg}, worker_params: {worker_params}")
self._remove_worker(worker_params) self._remove_worker(worker_params)
raise Exception(msg) raise Exception(msg)
supported_types = WorkerType.values() supported_types = WorkerType.values()

View File

@@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):
def __convert_2_gpt_messages(messages: List[ModelMessage]): def __convert_2_gpt_messages(messages: List[ModelMessage]):
chat_round = 0
gpt_messages = [] gpt_messages = []
last_usr_message = "" last_usr_message = ""
system_messages = [] system_messages = []
# TODO: We can't change message order in low level
for message in messages: for message in messages:
if message.role == ModelMessageRoleType.HUMAN or message.role == "user": if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
last_usr_message = message.content last_usr_message = message.content

View File

@@ -1,9 +1,36 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import logging
from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage
logger = logging.getLogger(__name__)
class ProxyModel:
    """Hold a proxy model's parameters together with a tokenizer for counting tokens."""

    def __init__(self, model_params: ProxyModelParameters) -> None:
        self._model_params = model_params
        # Token counting is delegated to this wrapper; see count_token().
        self._tokenizer = ProxyTokenizerWrapper()

    def get_params(self) -> ProxyModelParameters:
        """Return the parameters this proxy model was created with."""
        return self._model_params

    def count_token(
        self,
        messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
        model_name: Optional[str] = None,
    ) -> int:
        """Count token of given messages

        The call is forwarded unchanged to the wrapped tokenizer.

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
            model_name (Optional[str], optional): model name, passed through to
                the tokenizer. Defaults to None.

        Returns:
            int: token count, -1 if failed
        """
        return self._tokenizer.count_token(messages, model_name)

View File

@@ -25,6 +25,7 @@ from dbgpt.core.interface.llm import ModelOutput, ModelRequest
from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt._private.pydantic import model_to_json from dbgpt._private.pydantic import model_to_json
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING: if TYPE_CHECKING:
import httpx import httpx
@@ -152,6 +153,7 @@ class OpenAILLMClient(LLMClient):
self._context_length = context_length self._context_length = context_length
self._client = openai_client self._client = openai_client
self._openai_kwargs = openai_kwargs or {} self._openai_kwargs = openai_kwargs or {}
self._tokenizer = ProxyTokenizerWrapper()
@property @property
def client(self) -> ClientType: def client(self) -> ClientType:
@@ -238,10 +240,11 @@ class OpenAILLMClient(LLMClient):
async def count_token(self, model: str, prompt: str) -> int: async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt. """Count the number of tokens in a given prompt.
TODO: Get the real number of tokens from the openai api or tiktoken package Args:
model (str): The model name.
prompt (str): The prompt.
""" """
return self._tokenizer.count_token(prompt, model)
raise NotImplementedError()
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import logging
if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage
logger = logging.getLogger(__name__)
class ProxyTokenizerWrapper:
    """Count tokens of prompts and messages with a lazily loaded tiktoken encoding.

    tiktoken is imported on first use. If the import fails, counting is
    disabled for this instance and every call returns -1.
    """

    def __init__(self) -> None:
        # Set to False once importing tiktoken has failed, so later calls
        # return -1 immediately instead of retrying the import.
        self._support_encoding = True
        # Encoding resolved on the first successful call and reused afterwards.
        # NOTE(review): the cache is not keyed by model name, so the model
        # passed on the first call decides the encoding for all later calls.
        self._encoding_model = None

    def count_token(
        self,
        messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
        model_name: Optional[str] = None,
    ) -> int:
        """Count token of given messages

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
            model_name (Optional[str], optional): model name. Defaults to None.

        Returns:
            int: token count, -1 if failed
        """
        if not self._support_encoding:
            logger.warning(
                "model does not support encoding model, can't count token, returning -1"
            )
            return -1
        encoding = self._get_or_create_encoding_model(model_name)
        if encoding is None:
            # tiktoken is unavailable; the helper has already logged why.
            return -1
        items = messages if isinstance(messages, list) else [messages]
        texts = [self._message_text(item) for item in items]
        if any(text is None for text in texts):
            logger.warning(
                "unsupported type of messages, can't count token, returning -1"
            )
            return -1
        return sum(len(encoding.encode(text, disallowed_special=())) for text in texts)

    @staticmethod
    def _message_text(message) -> Optional[str]:
        """Return the text of a str or message-like object, or None if it has none.

        The message classes are only imported for type checking, so an
        isinstance() check against them would raise NameError at runtime.
        Messages are therefore recognised by their string ``content`` attribute.
        """
        if isinstance(message, str):
            return message
        content = getattr(message, "content", None)
        return content if isinstance(content, str) else None

    def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
        """Get or create encoding model for given model name

        More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

        Returns:
            The cached or newly created tiktoken encoding, or None when
            tiktoken is not installed (token counting is then disabled).
        """
        if self._encoding_model is not None:
            return self._encoding_model
        try:
            import tiktoken
        except ImportError:
            self._support_encoding = False
            logger.warning("tiktoken not installed, cannot count tokens, returning -1")
            return None
        logger.info(
            "tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
            "also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
        )
        if not model_name:
            model_name = "gpt-3.5-turbo"
        try:
            self._encoding_model = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken has no mapping for this model name; use the default encoding.
            logger.warning(
                f"{model_name}'s tokenizer not found, using cl100k_base encoding."
            )
            self._encoding_model = tiktoken.get_encoding("cl100k_base")
        return self._encoding_model

View File

@@ -1,5 +1,3 @@
from typing import List
import pytest import pytest
from dbgpt.storage.metadata import db from dbgpt.storage.metadata import db
@@ -39,11 +37,9 @@ def test_table_exist():
def test_entity_create(default_entity_dict): def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session: with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) entity = ServeEntity(**default_entity_dict)
assert db_entity.id == entity.id session.add(entity)
def test_entity_unique_key(default_entity_dict): def test_entity_unique_key(default_entity_dict):
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
def test_entity_get(default_entity_dict): def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case # TODO: implement your test case
pass
def test_entity_update(default_entity_dict): def test_entity_update(default_entity_dict):
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
def test_entity_delete(default_entity_dict): def test_entity_delete(default_entity_dict):
# TODO: implement your test case # TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict) pass
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all(): def test_entity_all():

View File

@@ -47,9 +47,11 @@ def test_table_exist():
def test_entity_create(default_entity_dict): def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
with db.session() as session: with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) entity: ServeEntity = ServeEntity(**default_entity_dict)
session.add(entity)
session.commit()
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data" assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel" assert db_entity.sub_chat_scene == "excel"
@@ -63,78 +65,96 @@ def test_entity_create(default_entity_dict):
def test_entity_unique_key(default_entity_dict): def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict) with db.session() as session:
entity = ServeEntity(**default_entity_dict)
session.add(entity)
with pytest.raises(Exception): with pytest.raises(Exception):
ServeEntity.create( with db.session() as session:
**{ entity = ServeEntity(
"prompt_name": "my_prompt_1", **{
"sys_code": "dbgpt", "prompt_name": "my_prompt_1",
"prompt_language": "zh", "sys_code": "dbgpt",
"model": "vicuna-13b-v1.5", "prompt_language": "zh",
} "model": "vicuna-13b-v1.5",
) }
)
session.add(entity)
def test_entity_get(default_entity_dict): def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict) with db.session() as session:
db_entity: ServeEntity = ServeEntity.get(entity.id) entity = ServeEntity(**default_entity_dict)
assert db_entity.id == entity.id session.add(entity)
assert db_entity.chat_scene == "chat_data" session.commit()
assert db_entity.sub_chat_scene == "excel" db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.prompt_type == "common" assert db_entity.id == entity.id
assert db_entity.prompt_name == "my_prompt_1" assert db_entity.chat_scene == "chat_data"
assert db_entity.content == "Write a qsort function in python." assert db_entity.sub_chat_scene == "excel"
assert db_entity.user_name == "zhangsan" assert db_entity.prompt_type == "common"
assert db_entity.sys_code == "dbgpt" assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.gmt_created is not None assert db_entity.content == "Write a qsort function in python."
assert db_entity.gmt_modified is not None assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_update(default_entity_dict): def test_entity_update(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict) with db.session() as session:
entity.update(prompt_name="my_prompt_2") entity = ServeEntity(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id) session.add(entity)
assert db_entity.id == entity.id session.commit()
assert db_entity.chat_scene == "chat_data" entity.prompt_name = "my_prompt_2"
assert db_entity.sub_chat_scene == "excel" session.merge(entity)
assert db_entity.prompt_type == "common" db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.prompt_name == "my_prompt_2" assert db_entity.id == entity.id
assert db_entity.content == "Write a qsort function in python." assert db_entity.chat_scene == "chat_data"
assert db_entity.user_name == "zhangsan" assert db_entity.sub_chat_scene == "excel"
assert db_entity.sys_code == "dbgpt" assert db_entity.prompt_type == "common"
assert db_entity.gmt_created is not None assert db_entity.prompt_name == "my_prompt_2"
assert db_entity.gmt_modified is not None assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_delete(default_entity_dict): def test_entity_delete(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict) with db.session() as session:
entity.delete() entity = ServeEntity(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id) session.add(entity)
assert db_entity is None session.commit()
session.delete(entity)
session.commit()
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity is None
def test_entity_all(): def test_entity_all():
for i in range(10): with db.session() as session:
ServeEntity.create( for i in range(10):
chat_scene="chat_data", entity = ServeEntity(
sub_chat_scene="excel", chat_scene="chat_data",
prompt_type="common", sub_chat_scene="excel",
prompt_name=f"my_prompt_{i}", prompt_type="common",
content="Write a qsort function in python.", prompt_name=f"my_prompt_{i}",
user_name="zhangsan", content="Write a qsort function in python.",
sys_code="dbgpt", user_name="zhangsan",
) sys_code="dbgpt",
entities = ServeEntity.all() )
assert len(entities) == 10 session.add(entity)
for entity in entities: with db.session() as session:
assert entity.chat_scene == "chat_data" entities = session.query(ServeEntity).all()
assert entity.sub_chat_scene == "excel" assert len(entities) == 10
assert entity.prompt_type == "common" for entity in entities:
assert entity.content == "Write a qsort function in python." assert entity.chat_scene == "chat_data"
assert entity.user_name == "zhangsan" assert entity.sub_chat_scene == "excel"
assert entity.sys_code == "dbgpt" assert entity.prompt_type == "common"
assert entity.gmt_created is not None assert entity.content == "Write a qsort function in python."
assert entity.gmt_modified is not None assert entity.user_name == "zhangsan"
assert entity.sys_code == "dbgpt"
assert entity.gmt_created is not None
assert entity.gmt_modified is not None
def test_dao_create(dao, default_entity_dict): def test_dao_create(dao, default_entity_dict):

View File

@@ -75,7 +75,7 @@ def test_config_default_user(service: Service):
def test_service_create(service: Service, default_entity_dict): def test_service_create(service: Service, default_entity_dict):
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
with db.session() as session: with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data" assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel" assert db_entity.sub_chat_scene == "excel"
@@ -92,7 +92,7 @@ def test_service_update(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict)) service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict)) entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
with db.session() as session: with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data" assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel" assert db_entity.sub_chat_scene == "excel"
@@ -109,7 +109,7 @@ def test_service_get(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict)) service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict)) entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
with db.session() as session: with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data" assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel" assert db_entity.sub_chat_scene == "excel"

View File

@@ -1,5 +1,3 @@
from typing import List
import pytest import pytest
from dbgpt.storage.metadata import db from dbgpt.storage.metadata import db
@@ -39,11 +37,9 @@ def test_table_exist():
def test_entity_create(default_entity_dict): def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session: with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) entity = ServeEntity(**default_entity_dict)
assert db_entity.id == entity.id session.add(entity)
def test_entity_unique_key(default_entity_dict): def test_entity_unique_key(default_entity_dict):
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
def test_entity_get(default_entity_dict): def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case # TODO: implement your test case
pass
def test_entity_update(default_entity_dict): def test_entity_update(default_entity_dict):
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
def test_entity_delete(default_entity_dict): def test_entity_delete(default_entity_dict):
# TODO: implement your test case # TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict) pass
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all(): def test_entity_all():

View File

@@ -105,12 +105,6 @@ class ChatHistoryDao(BaseDao):
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
chat_history.delete() chat_history.delete()
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity: def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]:
# return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first() with self.session(commit=False) as session:
return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first()
session = self.get_raw_session()
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
result = chat_history.first()
session.close()
return result

View File

@@ -51,7 +51,9 @@ class BaseDao(Generic[T, REQ, RES]):
Example: Example:
.. code-block:: python .. code-block:: python
user = User(name="Edward Snowden") user = User(name="Edward Snowden")
session = self.get_raw_session() session = self.get_raw_session()
session.add(user) session.add(user)
@@ -61,7 +63,7 @@ class BaseDao(Generic[T, REQ, RES]):
return self._db_manager._session() return self._db_manager._session()
@contextmanager @contextmanager
def session(self) -> Session: def session(self, commit: Optional[bool] = True) -> Session:
"""Provide a transactional scope around a series of operations. """Provide a transactional scope around a series of operations.
If raise an exception, the session will be roll back automatically, otherwise it will be committed. If raise an exception, the session will be roll back automatically, otherwise it will be committed.
@@ -71,13 +73,16 @@ class BaseDao(Generic[T, REQ, RES]):
with self.session() as session: with self.session() as session:
session.query(User).filter(User.name == 'Edward Snowden').first() session.query(User).filter(User.name == 'Edward Snowden').first()
Args:
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
Returns: Returns:
Session: A session object. Session: A session object.
Raises: Raises:
Exception: Any exception will be raised. Exception: Any exception will be raised.
""" """
with self._db_manager.session() as session: with self._db_manager.session(commit=commit) as session:
yield session yield session
def from_request(self, request: QUERY_SPEC) -> T: def from_request(self, request: QUERY_SPEC) -> T:

View File

@@ -1,8 +1,15 @@
from __future__ import annotations from __future__ import annotations
import abc
from contextlib import contextmanager from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List from typing import (
TypeVar,
Generic,
Union,
Dict,
Optional,
Type,
ClassVar,
)
import logging import logging
from sqlalchemy import create_engine, URL, Engine from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData from sqlalchemy import orm, inspect, MetaData
@@ -13,8 +20,6 @@ from sqlalchemy.orm import (
declarative_base, declarative_base,
DeclarativeMeta, DeclarativeMeta,
) )
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str from dbgpt.util.string_utils import _to_str
@@ -27,16 +32,10 @@ T = TypeVar("T", bound="BaseModel")
class _QueryObject: class _QueryObject:
"""The query object.""" """The query object."""
def __init__(self, db_manager: "DatabaseManager"): def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
self._db_manager = db_manager return model_cls.query_class(
model_cls, session=model_cls.__db_manager__._session()
def __get__(self, obj, type): )
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=self._db_manager._session())
except UnmappedClassError:
return None
class BaseQuery(orm.Query): class BaseQuery(orm.Query):
@@ -46,7 +45,9 @@ class BaseQuery(orm.Query):
"""Paginate the query. """Paginate the query.
Example: Example:
.. code-block:: python .. code-block:: python
from dbgpt.storage.metadata import db, Model from dbgpt.storage.metadata import db, Model
class User(Model): class User(Model):
__tablename__ = "user" __tablename__ = "user"
@@ -58,10 +59,6 @@ class BaseQuery(orm.Query):
pagination = session.query(User).paginate_query(page=1, page_size=10) pagination = session.query(User).paginate_query(page=1, page_size=10)
print(pagination) print(pagination)
# Or you can use the query object
with db.session() as session:
pagination = User.query.paginate_query(page=1, page_size=10)
print(pagination)
Args: Args:
page (Optional[int], optional): The page number. Defaults to 1. page (Optional[int], optional): The page number. Defaults to 1.
@@ -86,26 +83,12 @@ class BaseQuery(orm.Query):
class _Model: class _Model:
"""Base class for SQLAlchemy declarative base model. """Base class for SQLAlchemy declarative base model."""
With this class, we can use the query object to query the database. __db_manager__: ClassVar[DatabaseManager]
query_class = BaseQuery
Examples: # query: Optional[BaseQuery] = _QueryObject()
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
# User is an instance of _Model, and we can use the query object to query the database.
User.query.filter(User.name == "test").all()
"""
query_class = None
query: Optional[BaseQuery] = None
def __repr__(self): def __repr__(self):
identity = inspect(self).identity identity = inspect(self).identity
@@ -120,7 +103,9 @@ class DatabaseManager:
"""The database manager. """The database manager.
Examples: Examples:
.. code-block:: python .. code-block:: python
from urllib.parse import quote_plus as urlquote, quote from urllib.parse import quote_plus as urlquote, quote
from dbgpt.storage.metadata import DatabaseManager, create_model from dbgpt.storage.metadata import DatabaseManager, create_model
db = DatabaseManager() db = DatabaseManager()
@@ -141,21 +126,25 @@ class DatabaseManager:
session.add(User(name="test", fullname="test")) session.add(User(name="test", fullname="test"))
# db will commit the session automatically default. # db will commit the session automatically default.
# session.commit() # session.commit()
print(User.query.filter(User.name == "test").all()) assert session.query(User).filter(User.name == "test").first().name == "test"
# Use CURDMixin APIs to create, update, delete, query the database. # More usage:
with db.session() as session: with db.session() as session:
User.create(**{"name": "test1", "fullname": "test1"}) session.add(User(name="test1", fullname="test1"))
User.create(**{"name": "test2", "fullname": "test1"}) session.add(User(name="test2", fullname="test1"))
users = User.all() users = session.query(User).all()
print(users) print(users)
user = users[0] user = users[0]
user.update(**{"name": "test1_1111"}) user.name = "test1_1111"
session.merge(user)
user2 = users[1] user2 = users[1]
# Update user2 by save # Update user2 by save
user2.name = "test2_1111" user2.name = "test2_1111"
user2.save() session.merge(user2)
session.commit()
# Delete user2 # Delete user2
user2.delete() user2.delete()
""" """
@@ -189,28 +178,65 @@ class DatabaseManager:
return self._engine is not None and self._session is not None return self._engine is not None and self._session is not None
@contextmanager @contextmanager
def session(self) -> Session: def session(self, commit: Optional[bool] = True) -> Session:
"""Get the session with context manager. """Get the session with context manager.
If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically. This context manager handles the lifecycle of a SQLAlchemy session.
It automatically commits or rolls back transactions based on
the execution and handles session closure.
Example: The `commit` parameter controls whether the session should commit
>>> with db.session() as session: changes at the end of the block. This is useful for separating
>>> session.query(...) read and write operations.
Returns: Examples:
Session: The session.
.. code-block:: python
# For write operations (insert, update, delete):
with db.session() as session:
user = User(name="John Doe")
session.add(user)
# session.commit() is called automatically
# For read-only operations:
with db.session(commit=False) as session:
user = session.query(User).filter_by(name="John Doe").first()
# session.commit() is NOT called, as it's unnecessary for read operations
Args:
commit (Optional[bool], optional): Whether to commit the session.
If True (default), the session will commit changes at the end
of the block. Use False for read-only operations or when manual
control over commit is needed. Defaults to True.
Yields:
Session: The SQLAlchemy session object.
Raises: Raises:
RuntimeError: The database manager is not initialized. RuntimeError: Raised if the database manager is not initialized.
Exception: Any exception. Exception: Propagates any exception that occurred within the block.
Important Notes:
- DetachedInstanceError: This error occurs when trying to access or
modify an instance that has been detached from its session.
DetachedInstanceError can occur in scenarios where the session is
closed, and further interaction with the ORM object is attempted,
especially when accessing lazy-loaded attributes. To avoid this:
a. Ensure required attributes are loaded before session closure.
b. Avoid closing the session before all necessary interactions
with the ORM object are complete.
c. Re-bind the instance to a new session if further interaction
is required after the session is closed.
""" """
if not self.is_initialized: if not self.is_initialized:
raise RuntimeError("The database manager is not initialized.") raise RuntimeError("The database manager is not initialized.")
session = self._session() session = self._session()
try: try:
yield session yield session
session.commit() if commit:
session.commit()
except: except:
session.rollback() session.rollback()
raise raise
@@ -223,7 +249,7 @@ class DatabaseManager:
"""Make the declarative base. """Make the declarative base.
Args: Args:
base (DeclarativeMeta): The base class. model (DeclarativeMeta): The base class.
Returns: Returns:
DeclarativeMeta: The declarative base. DeclarativeMeta: The declarative base.
@@ -232,7 +258,8 @@ class DatabaseManager:
model = declarative_base(cls=model, name="Model") model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None): if not getattr(model, "query_class", None):
model.query_class = self.Query model.query_class = self.Query
model.query = _QueryObject(self) # model.query = _QueryObject()
model.__db_manager__ = self
return model return model
def init_db( def init_db(
@@ -242,6 +269,7 @@ class DatabaseManager:
base: Optional[DeclarativeMeta] = None, base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery, query_class=BaseQuery,
override_query_class: Optional[bool] = False, override_query_class: Optional[bool] = False,
session_options: Optional[Dict] = None,
): ):
"""Initialize the database manager. """Initialize the database manager.
@@ -251,18 +279,26 @@ class DatabaseManager:
base (Optional[DeclarativeMeta]): The base class. Defaults to None. base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False. override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
""" """
if session_options is None:
session_options = {}
self._db_url = db_url self._db_url = db_url
if query_class is not None: if query_class is not None:
self.Query = query_class self.Query = query_class
if base is not None: if base is not None:
self._base = base self._base = base
if not hasattr(base, "query") or override_query_class: # if not hasattr(base, "query") or override_query_class:
base.query = _QueryObject(self) # base.query = _QueryObject()
if not getattr(base, "query_class", None) or override_query_class: if not getattr(base, "query_class", None) or override_query_class:
base.query_class = self.Query base.query_class = self.Query
if not hasattr(base, "__db_manager__") or override_query_class:
base.__db_manager__ = self
self._engine = create_engine(db_url, **(engine_args or {})) self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
session_options.setdefault("class_", Session)
session_options.setdefault("query_cls", self.Query)
session_factory = sessionmaker(bind=self._engine, **session_options)
self._session = scoped_session(session_factory) self._session = scoped_session(session_factory)
self._base.metadata.bind = self._engine self._base.metadata.bind = self._engine
@@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]):
__abstract__ = True __abstract__ = True
@classmethod @classmethod
def create(cls: Type[T], **kwargs) -> T: def db(cls) -> DatabaseManager:
instance = cls(**kwargs) """Get the database manager."""
return instance.save() return cls.__db_manager__
@classmethod
def all(cls: Type[T]) -> List[T]:
return cls.query.all()
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
@abc.abstractmethod
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
@abc.abstractmethod
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]): class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods.""" """The base model class that includes CRUD convenience methods."""
__abstract__ = True __abstract__ = True
@@ -438,28 +451,14 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
_db_manager: DatabaseManager = db_manager _db_manager: DatabaseManager = db_manager
@classmethod @classmethod
def set_db_manager(cls, db_manager: DatabaseManager): def set_db(cls, db_manager: DatabaseManager):
# TODO: It is hard to replace to user DB Connection # TODO: It is hard to replace to user DB Connection
cls._db_manager = db_manager cls._db_manager = db_manager
@classmethod @classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]: def db(cls) -> DatabaseManager:
"""Get a record by its primary key identifier.""" """Get the database manager."""
return cls._db_manager._session().get(cls, ident) return cls._db_manager
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = self._db_manager._session()
session.add(self)
if commit:
session.commit()
return self
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = self._db_manager._session()
session.delete(self)
return commit and session.commit()
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
"""Base model class that includes CRUD convenience methods.""" """Base model class that includes CRUD convenience methods."""
@@ -478,6 +477,7 @@ def initialize_db(
engine_args: Optional[Dict] = None, engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None, base: Optional[DeclarativeMeta] = None,
try_to_create_db: Optional[bool] = False, try_to_create_db: Optional[bool] = False,
session_options: Optional[Dict] = None,
) -> DatabaseManager: ) -> DatabaseManager:
"""Initialize the database manager. """Initialize the database manager.
@@ -487,10 +487,11 @@ def initialize_db(
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None. base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False. try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
Returns: Returns:
DatabaseManager: The database manager. DatabaseManager: The database manager.
""" """
db.init_db(db_url, engine_args, base) db.init_db(db_url, engine_args, base, session_options=session_options)
if try_to_create_db: if try_to_create_db:
try: try:
db.create_all() db.create_all()

View File

@@ -100,7 +100,7 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_
# Verify that the user is updated in the database # Verify that the user is updated in the database
with db.session() as session: with db.session() as session:
user = session.query(User).get(created_user_response.id) user = session.get(User, created_user_response.id)
assert user.age == 35 assert user.age == 35
@@ -121,7 +121,7 @@ def test_update_user_partial(
# Verify that the user is updated in the database # Verify that the user is updated in the database
with db.session() as session: with db.session() as session:
user = session.query(User).get(created_user_response.id) user = session.get(User, created_user_response.id)
assert user.age == user_req.age assert user.age == user_req.age
assert user.password == "newpassword" assert user.password == "newpassword"

View File

@@ -53,11 +53,10 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
# Create # Create
with db.session() as session: with db.session() as session:
user = User.create(name="John Doe") user = User(name="John Doe")
session.add(user) session.add(user)
session.commit()
# Read # # Read
with db.session() as session: with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first() user = session.query(User).filter_by(name="John Doe").first()
assert user is not None assert user is not None
@@ -65,12 +64,20 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
# Update # Update
with db.session() as session: with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first() user = session.query(User).filter_by(name="John Doe").first()
user.update(name="Jane Doe") user.name = "Mike Doe"
session.merge(user)
# Delete
with db.session() as session: with db.session() as session:
user = session.query(User).filter_by(name="Jane Doe").first() user = session.query(User).filter_by(name="Mike Doe").first()
user.delete() assert user is not None
session.query(User).filter(User.name == "John Doe").first() is None
#
# # Delete
with db.session() as session:
user = session.query(User).filter_by(name="Mike Doe").first()
session.delete(user)
with db.session() as session:
assert len(session.query(User).all()) == 0
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]): def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
@@ -80,20 +87,7 @@ def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
name = Column(String(50)) name = Column(String(50))
db.create_all() db.create_all()
User.db() == db
# Create
user = User.create(name="John Doe")
assert User.get(user.id) is not None
users = User.all()
assert len(users) == 1
# Update
user.update(name="Bob Doe")
assert User.get(user.id).name == "Bob Doe"
user = User.get(user.id)
user.delete()
assert User.get(user.id) is None
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]): def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
@@ -108,11 +102,10 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
for i in range(30): for i in range(30):
user = User(name=f"User {i}") user = User(name=f"User {i}")
session.add(user) session.add(user)
session.commit() with db.session() as session:
users_page_1 = session.query(User).paginate_query(page=1, per_page=10)
users_page_1 = User.query.paginate_query(page=1, per_page=10) assert len(users_page_1.items) == 10
assert len(users_page_1.items) == 10 assert users_page_1.total_pages == 3
assert users_page_1.total_pages == 3
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]): def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
@@ -124,9 +117,11 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
db.create_all() db.create_all()
with pytest.raises(ValueError): with pytest.raises(ValueError):
User.query.paginate_query(page=0, per_page=10) with db.session() as session:
session.query(User).paginate_query(page=0, per_page=10)
with pytest.raises(ValueError): with pytest.raises(ValueError):
User.query.paginate_query(page=1, per_page=-1) with db.session() as session:
session.query(User).paginate_query(page=1, per_page=-1)
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]): def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
@@ -142,14 +137,19 @@ def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
new_db = DatabaseManager.build_from( new_db = DatabaseManager.build_from(
f"sqlite:///{filename}", base=Model, override_query_class=True f"sqlite:///{filename}", base=Model, override_query_class=True
) )
Model.set_db_manager(new_db) Model.set_db(new_db)
new_db.create_all() new_db.create_all()
db.create_all() db.create_all()
assert list(new_db.metadata.tables.keys())[0] == "user" assert list(new_db.metadata.tables.keys())[0] == "user"
User.create(**{"name": "John Doe"}) with new_db.session() as session:
user = User(name="John Doe")
session.add(user)
with new_db.session() as session: with new_db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is not None assert session.query(User).filter_by(name="John Doe").first() is not None
with db.session() as session: with db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is None assert session.query(User).filter_by(name="John Doe").first() is None
assert len(User.query.all()) == 1 with new_db.session() as session:
assert User.query.filter(User.name == "John Doe").first().name == "John Doe" session.query(User).all() == 1
session.query(User).filter(
User.name == "John Doe"
).first().name == "John Doe"

View File

@@ -7,7 +7,7 @@
Call with non-streaming response. Call with non-streaming response.
.. code-block:: shell .. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000" DBGPT_SERVER="http://127.0.0.1:5555"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \ curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
-H "Content-Type: application/json" -d '{ -H "Content-Type: application/json" -d '{
"model": "proxyllm", "model": "proxyllm",