feat(model): Proxy model support count token (#996)

This commit is contained in:
Fangyin Cheng
2023-12-29 12:01:31 +08:00
committed by GitHub
parent ba0599ebf4
commit 0cdc77abb2
16 changed files with 366 additions and 248 deletions

View File

@@ -189,7 +189,7 @@ class DefaultModelWorker(ModelWorker):
return output
def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)
return _try_to_count_token(prompt, self.tokenizer, self.model)
async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
@@ -454,12 +454,13 @@ def _new_metrics_from_model_output(
return metrics
def _try_to_count_token(prompt: str, tokenizer) -> int:
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
"""Try to count token of prompt
Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
model ([type]): model
Returns:
int: token count, if error return -1
@@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int:
TODO: More implementation
"""
try:
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
if isinstance(model, ProxyModel):
return model.count_token(prompt)
# Only support huggingface model now
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")

View File

@@ -197,7 +197,7 @@ class LocalWorkerManager(WorkerManager):
return True
else:
# TODO Update worker
logger.warn(f"Instance {worker_key} exist")
logger.warning(f"Instance {worker_key} exist")
return False
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
@@ -229,7 +229,7 @@ class LocalWorkerManager(WorkerManager):
)
if not success:
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
logger.warn(f"{msg}, worker_params: {worker_params}")
logger.warning(f"{msg}, worker_params: {worker_params}")
self._remove_worker(worker_params)
raise Exception(msg)
supported_types = WorkerType.values()

View File

@@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):
def __convert_2_gpt_messages(messages: List[ModelMessage]):
chat_round = 0
gpt_messages = []
last_usr_message = ""
system_messages = []
# TODO: We can't change message order in low level
for message in messages:
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
last_usr_message = message.content

View File

@@ -1,9 +1,36 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import logging
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage
logger = logging.getLogger(__name__)
class ProxyModel:
def __init__(self, model_params: ProxyModelParameters) -> None:
self._model_params = model_params
self._tokenizer = ProxyTokenizerWrapper()
def get_params(self) -> ProxyModelParameters:
return self._model_params
def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[int] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[int], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)

View File

@@ -25,6 +25,7 @@ from dbgpt.core.interface.llm import ModelOutput, ModelRequest
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt._private.pydantic import model_to_json
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
import httpx
@@ -152,6 +153,7 @@ class OpenAILLMClient(LLMClient):
self._context_length = context_length
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}
self._tokenizer = ProxyTokenizerWrapper()
@property
def client(self) -> ClientType:
@@ -238,10 +240,11 @@ class OpenAILLMClient(LLMClient):
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
TODO: Get the real number of tokens from the openai api or tiktoken package
Args:
model (str): The model name.
prompt (str): The prompt.
"""
raise NotImplementedError()
return self._tokenizer.count_token(prompt, model)
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import logging
if TYPE_CHECKING:
from dbgpt.core.interface.message import ModelMessage, BaseMessage
logger = logging.getLogger(__name__)
class ProxyTokenizerWrapper:
def __init__(self) -> None:
self._support_encoding = True
self._encoding_model = None
def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[str] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[str], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
if not self._support_encoding:
logger.warning(
"model does not support encoding model, can't count token, returning -1"
)
return -1
encoding = self._get_or_create_encoding_model(model_name)
cnt = 0
if isinstance(messages, str):
cnt = len(encoding.encode(messages, disallowed_special=()))
elif isinstance(messages, BaseMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, ModelMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, list):
for message in messages:
cnt += len(encoding.encode(message.content, disallowed_special=()))
else:
logger.warning(
"unsupported type of messages, can't count token, returning -1"
)
return -1
return cnt
def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
"""Get or create encoding model for given model name
More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
if self._encoding_model:
return self._encoding_model
try:
import tiktoken
logger.info(
"tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
"also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
)
except ImportError:
self._support_encoding = False
logger.warn("tiktoken not installed, cannot count tokens, returning -1")
return -1
try:
if not model_name:
model_name = "gpt-3.5-turbo"
self._encoding_model = tiktoken.model.encoding_for_model(model_name)
except KeyError:
logger.warning(
f"{model_name}'s tokenizer not found, using cl100k_base encoding."
)
self._encoding_model = tiktoken.get_encoding("cl100k_base")
return self._encoding_model

View File

@@ -1,5 +1,3 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
@@ -39,11 +37,9 @@ def test_table_exist():
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
entity = ServeEntity(**default_entity_dict)
session.add(entity)
def test_entity_unique_key(default_entity_dict):
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
pass
def test_entity_update(default_entity_dict):
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
pass
def test_entity_all():

View File

@@ -47,9 +47,11 @@ def test_table_exist():
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
entity: ServeEntity = ServeEntity(**default_entity_dict)
session.add(entity)
session.commit()
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
@@ -63,78 +65,96 @@ def test_entity_create(default_entity_dict):
def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict)
with db.session() as session:
entity = ServeEntity(**default_entity_dict)
session.add(entity)
with pytest.raises(Exception):
ServeEntity.create(
**{
"prompt_name": "my_prompt_1",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
)
with db.session() as session:
entity = ServeEntity(
**{
"prompt_name": "my_prompt_1",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
)
session.add(entity)
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
with db.session() as session:
entity = ServeEntity(**default_entity_dict)
session.add(entity)
session.commit()
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_update(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.update(prompt_name="my_prompt_2")
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_2"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
with db.session() as session:
entity = ServeEntity(**default_entity_dict)
session.add(entity)
session.commit()
entity.prompt_name = "my_prompt_2"
session.merge(entity)
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_2"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_delete(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
with db.session() as session:
entity = ServeEntity(**default_entity_dict)
session.add(entity)
session.commit()
session.delete(entity)
session.commit()
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity is None
def test_entity_all():
for i in range(10):
ServeEntity.create(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan",
sys_code="dbgpt",
)
entities = ServeEntity.all()
assert len(entities) == 10
for entity in entities:
assert entity.chat_scene == "chat_data"
assert entity.sub_chat_scene == "excel"
assert entity.prompt_type == "common"
assert entity.content == "Write a qsort function in python."
assert entity.user_name == "zhangsan"
assert entity.sys_code == "dbgpt"
assert entity.gmt_created is not None
assert entity.gmt_modified is not None
with db.session() as session:
for i in range(10):
entity = ServeEntity(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan",
sys_code="dbgpt",
)
session.add(entity)
with db.session() as session:
entities = session.query(ServeEntity).all()
assert len(entities) == 10
for entity in entities:
assert entity.chat_scene == "chat_data"
assert entity.sub_chat_scene == "excel"
assert entity.prompt_type == "common"
assert entity.content == "Write a qsort function in python."
assert entity.user_name == "zhangsan"
assert entity.sys_code == "dbgpt"
assert entity.gmt_created is not None
assert entity.gmt_modified is not None
def test_dao_create(dao, default_entity_dict):

View File

@@ -75,7 +75,7 @@ def test_config_default_user(service: Service):
def test_service_create(service: Service, default_entity_dict):
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
@@ -92,7 +92,7 @@ def test_service_update(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
@@ -109,7 +109,7 @@ def test_service_get(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"

View File

@@ -1,5 +1,3 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
@@ -39,11 +37,9 @@ def test_table_exist():
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
entity = ServeEntity(**default_entity_dict)
session.add(entity)
def test_entity_unique_key(default_entity_dict):
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
pass
def test_entity_update(default_entity_dict):
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
pass
def test_entity_all():

View File

@@ -105,12 +105,6 @@ class ChatHistoryDao(BaseDao):
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
chat_history.delete()
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
# return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first()
session = self.get_raw_session()
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
result = chat_history.first()
session.close()
return result
def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]:
with self.session(commit=False) as session:
return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first()

View File

@@ -51,7 +51,9 @@ class BaseDao(Generic[T, REQ, RES]):
Example:
.. code-block:: python
user = User(name="Edward Snowden")
session = self.get_raw_session()
session.add(user)
@@ -61,7 +63,7 @@ class BaseDao(Generic[T, REQ, RES]):
return self._db_manager._session()
@contextmanager
def session(self) -> Session:
def session(self, commit: Optional[bool] = True) -> Session:
"""Provide a transactional scope around a series of operations.
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
@@ -71,13 +73,16 @@ class BaseDao(Generic[T, REQ, RES]):
with self.session() as session:
session.query(User).filter(User.name == 'Edward Snowden').first()
Args:
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
Returns:
Session: A session object.
Raises:
Exception: Any exception will be raised.
"""
with self._db_manager.session() as session:
with self._db_manager.session(commit=commit) as session:
yield session
def from_request(self, request: QUERY_SPEC) -> T:

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
import abc
from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
from typing import (
TypeVar,
Generic,
Union,
Dict,
Optional,
Type,
ClassVar,
)
import logging
from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData
@@ -13,8 +20,6 @@ from sqlalchemy.orm import (
declarative_base,
DeclarativeMeta,
)
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str
@@ -27,16 +32,10 @@ T = TypeVar("T", bound="BaseModel")
class _QueryObject:
"""The query object."""
def __init__(self, db_manager: "DatabaseManager"):
self._db_manager = db_manager
def __get__(self, obj, type):
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=self._db_manager._session())
except UnmappedClassError:
return None
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
return model_cls.query_class(
model_cls, session=model_cls.__db_manager__._session()
)
class BaseQuery(orm.Query):
@@ -46,7 +45,9 @@ class BaseQuery(orm.Query):
"""Paginate the query.
Example:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
@@ -58,10 +59,6 @@ class BaseQuery(orm.Query):
pagination = session.query(User).paginate_query(page=1, page_size=10)
print(pagination)
# Or you can use the query object
with db.session() as session:
pagination = User.query.paginate_query(page=1, page_size=10)
print(pagination)
Args:
page (Optional[int], optional): The page number. Defaults to 1.
@@ -86,26 +83,12 @@ class BaseQuery(orm.Query):
class _Model:
"""Base class for SQLAlchemy declarative base model.
"""Base class for SQLAlchemy declarative base model."""
With this class, we can use the query object to query the database.
__db_manager__: ClassVar[DatabaseManager]
query_class = BaseQuery
Examples:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
# User is an instance of _Model, and we can use the query object to query the database.
User.query.filter(User.name == "test").all()
"""
query_class = None
query: Optional[BaseQuery] = None
# query: Optional[BaseQuery] = _QueryObject()
def __repr__(self):
identity = inspect(self).identity
@@ -120,7 +103,9 @@ class DatabaseManager:
"""The database manager.
Examples:
.. code-block:: python
from urllib.parse import quote_plus as urlquote, quote
from dbgpt.storage.metadata import DatabaseManager, create_model
db = DatabaseManager()
@@ -141,21 +126,25 @@ class DatabaseManager:
session.add(User(name="test", fullname="test"))
# db will commit the session automatically default.
# session.commit()
print(User.query.filter(User.name == "test").all())
assert session.query(User).filter(User.name == "test").first().name == "test"
# Use CURDMixin APIs to create, update, delete, query the database.
# More usage:
with db.session() as session:
User.create(**{"name": "test1", "fullname": "test1"})
User.create(**{"name": "test2", "fullname": "test1"})
users = User.all()
session.add(User(name="test1", fullname="test1"))
session.add(User(name="test2", fullname="test1"))
users = session.query(User).all()
print(users)
user = users[0]
user.update(**{"name": "test1_1111"})
user.name = "test1_1111"
session.merge(user)
user2 = users[1]
# Update user2 by save
user2.name = "test2_1111"
user2.save()
session.merge(user2)
session.commit()
# Delete user2
user2.delete()
"""
@@ -189,28 +178,65 @@ class DatabaseManager:
return self._engine is not None and self._session is not None
@contextmanager
def session(self) -> Session:
def session(self, commit: Optional[bool] = True) -> Session:
"""Get the session with context manager.
If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically.
This context manager handles the lifecycle of a SQLAlchemy session.
It automatically commits or rolls back transactions based on
the execution and handles session closure.
Example:
>>> with db.session() as session:
>>> session.query(...)
The `commit` parameter controls whether the session should commit
changes at the end of the block. This is useful for separating
read and write operations.
Returns:
Session: The session.
Examples:
.. code-block:: python
# For write operations (insert, update, delete):
with db.session() as session:
user = User(name="John Doe")
session.add(user)
# session.commit() is called automatically
# For read-only operations:
with db.session(commit=False) as session:
user = session.query(User).filter_by(name="John Doe").first()
# session.commit() is NOT called, as it's unnecessary for read operations
Args:
commit (Optional[bool], optional): Whether to commit the session.
If True (default), the session will commit changes at the end
of the block. Use False for read-only operations or when manual
control over commit is needed. Defaults to True.
Yields:
Session: The SQLAlchemy session object.
Raises:
RuntimeError: The database manager is not initialized.
Exception: Any exception.
RuntimeError: Raised if the database manager is not initialized.
Exception: Propagates any exception that occurred within the block.
Important Notes:
- DetachedInstanceError: This error occurs when trying to access or
modify an instance that has been detached from its session.
DetachedInstanceError can occur in scenarios where the session is
closed, and further interaction with the ORM object is attempted,
especially when accessing lazy-loaded attributes. To avoid this:
a. Ensure required attributes are loaded before session closure.
b. Avoid closing the session before all necessary interactions
with the ORM object are complete.
c. Re-bind the instance to a new session if further interaction
is required after the session is closed.
"""
if not self.is_initialized:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
try:
yield session
session.commit()
if commit:
session.commit()
except:
session.rollback()
raise
@@ -223,7 +249,7 @@ class DatabaseManager:
"""Make the declarative base.
Args:
base (DeclarativeMeta): The base class.
model (DeclarativeMeta): The base class.
Returns:
DeclarativeMeta: The declarative base.
@@ -232,7 +258,8 @@ class DatabaseManager:
model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None):
model.query_class = self.Query
model.query = _QueryObject(self)
# model.query = _QueryObject()
model.__db_manager__ = self
return model
def init_db(
@@ -242,6 +269,7 @@ class DatabaseManager:
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
session_options: Optional[Dict] = None,
):
"""Initialize the database manager.
@@ -251,18 +279,26 @@ class DatabaseManager:
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
"""
if session_options is None:
session_options = {}
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query") or override_query_class:
base.query = _QueryObject(self)
# if not hasattr(base, "query") or override_query_class:
# base.query = _QueryObject()
if not getattr(base, "query_class", None) or override_query_class:
base.query_class = self.Query
if not hasattr(base, "__db_manager__") or override_query_class:
base.__db_manager__ = self
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
session_options.setdefault("class_", Session)
session_options.setdefault("query_cls", self.Query)
session_factory = sessionmaker(bind=self._engine, **session_options)
self._session = scoped_session(session_factory)
self._base.metadata.bind = self._engine
@@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]):
__abstract__ = True
@classmethod
def create(cls: Type[T], **kwargs) -> T:
instance = cls(**kwargs)
return instance.save()
@classmethod
def all(cls: Type[T]) -> List[T]:
return cls.query.all()
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
@abc.abstractmethod
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
@abc.abstractmethod
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
def db(cls) -> DatabaseManager:
"""Get the database manager."""
return cls.__db_manager__
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods."""
__abstract__ = True
@@ -438,28 +451,14 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
_db_manager: DatabaseManager = db_manager
@classmethod
def set_db_manager(cls, db_manager: DatabaseManager):
def set_db(cls, db_manager: DatabaseManager):
# TODO: It is hard to replace to user DB Connection
cls._db_manager = db_manager
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return cls._db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = self._db_manager._session()
session.add(self)
if commit:
session.commit()
return self
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = self._db_manager._session()
session.delete(self)
return commit and session.commit()
def db(cls) -> DatabaseManager:
"""Get the database manager."""
return cls._db_manager
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
"""Base model class that includes CRUD convenience methods."""
@@ -478,6 +477,7 @@ def initialize_db(
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
try_to_create_db: Optional[bool] = False,
session_options: Optional[Dict] = None,
) -> DatabaseManager:
"""Initialize the database manager.
@@ -487,10 +487,11 @@ def initialize_db(
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
Returns:
DatabaseManager: The database manager.
"""
db.init_db(db_url, engine_args, base)
db.init_db(db_url, engine_args, base, session_options=session_options)
if try_to_create_db:
try:
db.create_all()

View File

@@ -100,7 +100,7 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_
# Verify that the user is updated in the database
with db.session() as session:
user = session.query(User).get(created_user_response.id)
user = session.get(User, created_user_response.id)
assert user.age == 35
@@ -121,7 +121,7 @@ def test_update_user_partial(
# Verify that the user is updated in the database
with db.session() as session:
user = session.query(User).get(created_user_response.id)
user = session.get(User, created_user_response.id)
assert user.age == user_req.age
assert user.password == "newpassword"

View File

@@ -53,11 +53,10 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
# Create
with db.session() as session:
user = User.create(name="John Doe")
user = User(name="John Doe")
session.add(user)
session.commit()
# Read
# # Read
with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first()
assert user is not None
@@ -65,12 +64,20 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
# Update
with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first()
user.update(name="Jane Doe")
# Delete
user.name = "Mike Doe"
session.merge(user)
with db.session() as session:
user = session.query(User).filter_by(name="Jane Doe").first()
user.delete()
user = session.query(User).filter_by(name="Mike Doe").first()
assert user is not None
session.query(User).filter(User.name == "John Doe").first() is None
#
# # Delete
with db.session() as session:
user = session.query(User).filter_by(name="Mike Doe").first()
session.delete(user)
with db.session() as session:
assert len(session.query(User).all()) == 0
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
@@ -80,20 +87,7 @@ def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
name = Column(String(50))
db.create_all()
# Create
user = User.create(name="John Doe")
assert User.get(user.id) is not None
users = User.all()
assert len(users) == 1
# Update
user.update(name="Bob Doe")
assert User.get(user.id).name == "Bob Doe"
user = User.get(user.id)
user.delete()
assert User.get(user.id) is None
User.db() == db
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
@@ -108,11 +102,10 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
for i in range(30):
user = User(name=f"User {i}")
session.add(user)
session.commit()
users_page_1 = User.query.paginate_query(page=1, per_page=10)
assert len(users_page_1.items) == 10
assert users_page_1.total_pages == 3
with db.session() as session:
users_page_1 = session.query(User).paginate_query(page=1, per_page=10)
assert len(users_page_1.items) == 10
assert users_page_1.total_pages == 3
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
@@ -124,9 +117,11 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
db.create_all()
with pytest.raises(ValueError):
User.query.paginate_query(page=0, per_page=10)
with db.session() as session:
session.query(User).paginate_query(page=0, per_page=10)
with pytest.raises(ValueError):
User.query.paginate_query(page=1, per_page=-1)
with db.session() as session:
session.query(User).paginate_query(page=1, per_page=-1)
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
@@ -142,14 +137,19 @@ def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
new_db = DatabaseManager.build_from(
f"sqlite:///{filename}", base=Model, override_query_class=True
)
Model.set_db_manager(new_db)
Model.set_db(new_db)
new_db.create_all()
db.create_all()
assert list(new_db.metadata.tables.keys())[0] == "user"
User.create(**{"name": "John Doe"})
with new_db.session() as session:
user = User(name="John Doe")
session.add(user)
with new_db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is not None
with db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is None
assert len(User.query.all()) == 1
assert User.query.filter(User.name == "John Doe").first().name == "John Doe"
with new_db.session() as session:
session.query(User).all() == 1
session.query(User).filter(
User.name == "John Doe"
).first().name == "John Doe"

View File

@@ -7,7 +7,7 @@
Call with non-streaming response.
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000"
DBGPT_SERVER="http://127.0.0.1:5555"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",