mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 11:01:09 +00:00
feat(model): Proxy model support count token (#996)
This commit is contained in:
@@ -189,7 +189,7 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def count_token(self, prompt: str) -> int:
|
def count_token(self, prompt: str) -> int:
|
||||||
return _try_to_count_token(prompt, self.tokenizer)
|
return _try_to_count_token(prompt, self.tokenizer, self.model)
|
||||||
|
|
||||||
async def async_count_token(self, prompt: str) -> int:
|
async def async_count_token(self, prompt: str) -> int:
|
||||||
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
|
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
|
||||||
@@ -454,12 +454,13 @@ def _new_metrics_from_model_output(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def _try_to_count_token(prompt: str, tokenizer) -> int:
|
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
|
||||||
"""Try to count token of prompt
|
"""Try to count token of prompt
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (str): prompt
|
prompt (str): prompt
|
||||||
tokenizer ([type]): tokenizer
|
tokenizer ([type]): tokenizer
|
||||||
|
model ([type]): model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: token count, if error return -1
|
int: token count, if error return -1
|
||||||
@@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int:
|
|||||||
TODO: More implementation
|
TODO: More implementation
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||||
|
|
||||||
|
if isinstance(model, ProxyModel):
|
||||||
|
return model.count_token(prompt)
|
||||||
|
# Only support huggingface model now
|
||||||
return len(tokenizer(prompt).input_ids[0])
|
return len(tokenizer(prompt).input_ids[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Count token error, detail: {e}, return -1")
|
logger.warning(f"Count token error, detail: {e}, return -1")
|
||||||
|
@@ -197,7 +197,7 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
# TODO Update worker
|
# TODO Update worker
|
||||||
logger.warn(f"Instance {worker_key} exist")
|
logger.warning(f"Instance {worker_key} exist")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
|
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
|
||||||
@@ -229,7 +229,7 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
|
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
|
||||||
logger.warn(f"{msg}, worker_params: {worker_params}")
|
logger.warning(f"{msg}, worker_params: {worker_params}")
|
||||||
self._remove_worker(worker_params)
|
self._remove_worker(worker_params)
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
supported_types = WorkerType.values()
|
supported_types = WorkerType.values()
|
||||||
|
@@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
|||||||
|
|
||||||
|
|
||||||
def __convert_2_gpt_messages(messages: List[ModelMessage]):
|
def __convert_2_gpt_messages(messages: List[ModelMessage]):
|
||||||
chat_round = 0
|
|
||||||
gpt_messages = []
|
gpt_messages = []
|
||||||
last_usr_message = ""
|
last_usr_message = ""
|
||||||
system_messages = []
|
system_messages = []
|
||||||
|
|
||||||
|
# TODO: We can't change message order in low level
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
|
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
|
||||||
last_usr_message = message.content
|
last_usr_message = message.content
|
||||||
|
@@ -1,9 +1,36 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Union, List, Optional, TYPE_CHECKING
|
||||||
|
import logging
|
||||||
from dbgpt.model.parameter import ProxyModelParameters
|
from dbgpt.model.parameter import ProxyModelParameters
|
||||||
|
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from dbgpt.core.interface.message import ModelMessage, BaseMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProxyModel:
|
class ProxyModel:
|
||||||
def __init__(self, model_params: ProxyModelParameters) -> None:
|
def __init__(self, model_params: ProxyModelParameters) -> None:
|
||||||
self._model_params = model_params
|
self._model_params = model_params
|
||||||
|
self._tokenizer = ProxyTokenizerWrapper()
|
||||||
|
|
||||||
def get_params(self) -> ProxyModelParameters:
|
def get_params(self) -> ProxyModelParameters:
|
||||||
return self._model_params
|
return self._model_params
|
||||||
|
|
||||||
|
def count_token(
|
||||||
|
self,
|
||||||
|
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
|
||||||
|
model_name: Optional[int] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count token of given messages
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
|
||||||
|
model_name (Optional[int], optional): model name. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: token count, -1 if failed
|
||||||
|
"""
|
||||||
|
return self._tokenizer.count_token(messages, model_name)
|
||||||
|
@@ -25,6 +25,7 @@ from dbgpt.core.interface.llm import ModelOutput, ModelRequest
|
|||||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||||
from dbgpt.model.cluster import WorkerManagerFactory
|
from dbgpt.model.cluster import WorkerManagerFactory
|
||||||
from dbgpt._private.pydantic import model_to_json
|
from dbgpt._private.pydantic import model_to_json
|
||||||
|
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import httpx
|
import httpx
|
||||||
@@ -152,6 +153,7 @@ class OpenAILLMClient(LLMClient):
|
|||||||
self._context_length = context_length
|
self._context_length = context_length
|
||||||
self._client = openai_client
|
self._client = openai_client
|
||||||
self._openai_kwargs = openai_kwargs or {}
|
self._openai_kwargs = openai_kwargs or {}
|
||||||
|
self._tokenizer = ProxyTokenizerWrapper()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> ClientType:
|
def client(self) -> ClientType:
|
||||||
@@ -238,10 +240,11 @@ class OpenAILLMClient(LLMClient):
|
|||||||
async def count_token(self, model: str, prompt: str) -> int:
|
async def count_token(self, model: str, prompt: str) -> int:
|
||||||
"""Count the number of tokens in a given prompt.
|
"""Count the number of tokens in a given prompt.
|
||||||
|
|
||||||
TODO: Get the real number of tokens from the openai api or tiktoken package
|
Args:
|
||||||
|
model (str): The model name.
|
||||||
|
prompt (str): The prompt.
|
||||||
"""
|
"""
|
||||||
|
return self._tokenizer.count_token(prompt, model)
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
||||||
|
80
dbgpt/model/utils/token_utils.py
Normal file
80
dbgpt/model/utils/token_utils.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Union, List, Optional, TYPE_CHECKING
|
||||||
|
import logging
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from dbgpt.core.interface.message import ModelMessage, BaseMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyTokenizerWrapper:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._support_encoding = True
|
||||||
|
self._encoding_model = None
|
||||||
|
|
||||||
|
def count_token(
|
||||||
|
self,
|
||||||
|
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count token of given messages
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
|
||||||
|
model_name (Optional[str], optional): model name. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: token count, -1 if failed
|
||||||
|
"""
|
||||||
|
if not self._support_encoding:
|
||||||
|
logger.warning(
|
||||||
|
"model does not support encoding model, can't count token, returning -1"
|
||||||
|
)
|
||||||
|
return -1
|
||||||
|
encoding = self._get_or_create_encoding_model(model_name)
|
||||||
|
cnt = 0
|
||||||
|
if isinstance(messages, str):
|
||||||
|
cnt = len(encoding.encode(messages, disallowed_special=()))
|
||||||
|
elif isinstance(messages, BaseMessage):
|
||||||
|
cnt = len(encoding.encode(messages.content, disallowed_special=()))
|
||||||
|
elif isinstance(messages, ModelMessage):
|
||||||
|
cnt = len(encoding.encode(messages.content, disallowed_special=()))
|
||||||
|
elif isinstance(messages, list):
|
||||||
|
for message in messages:
|
||||||
|
cnt += len(encoding.encode(message.content, disallowed_special=()))
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"unsupported type of messages, can't count token, returning -1"
|
||||||
|
)
|
||||||
|
return -1
|
||||||
|
return cnt
|
||||||
|
|
||||||
|
def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
|
||||||
|
"""Get or create encoding model for given model name
|
||||||
|
More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
|
"""
|
||||||
|
if self._encoding_model:
|
||||||
|
return self._encoding_model
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
|
||||||
|
"also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
self._support_encoding = False
|
||||||
|
logger.warn("tiktoken not installed, cannot count tokens, returning -1")
|
||||||
|
return -1
|
||||||
|
try:
|
||||||
|
if not model_name:
|
||||||
|
model_name = "gpt-3.5-turbo"
|
||||||
|
self._encoding_model = tiktoken.model.encoding_for_model(model_name)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(
|
||||||
|
f"{model_name}'s tokenizer not found, using cl100k_base encoding."
|
||||||
|
)
|
||||||
|
self._encoding_model = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return self._encoding_model
|
@@ -1,5 +1,3 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from dbgpt.storage.metadata import db
|
from dbgpt.storage.metadata import db
|
||||||
@@ -39,11 +37,9 @@ def test_table_exist():
|
|||||||
|
|
||||||
|
|
||||||
def test_entity_create(default_entity_dict):
|
def test_entity_create(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
|
||||||
# TODO: implement your test case
|
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
entity = ServeEntity(**default_entity_dict)
|
||||||
assert db_entity.id == entity.id
|
session.add(entity)
|
||||||
|
|
||||||
|
|
||||||
def test_entity_unique_key(default_entity_dict):
|
def test_entity_unique_key(default_entity_dict):
|
||||||
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
|
|||||||
|
|
||||||
|
|
||||||
def test_entity_get(default_entity_dict):
|
def test_entity_get(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
|
||||||
assert db_entity.id == entity.id
|
|
||||||
# TODO: implement your test case
|
# TODO: implement your test case
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_entity_update(default_entity_dict):
|
def test_entity_update(default_entity_dict):
|
||||||
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
|
|||||||
|
|
||||||
def test_entity_delete(default_entity_dict):
|
def test_entity_delete(default_entity_dict):
|
||||||
# TODO: implement your test case
|
# TODO: implement your test case
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
pass
|
||||||
entity.delete()
|
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
|
||||||
assert db_entity is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_entity_all():
|
def test_entity_all():
|
||||||
|
@@ -47,9 +47,11 @@ def test_table_exist():
|
|||||||
|
|
||||||
|
|
||||||
def test_entity_create(default_entity_dict):
|
def test_entity_create(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
entity: ServeEntity = ServeEntity(**default_entity_dict)
|
||||||
|
session.add(entity)
|
||||||
|
session.commit()
|
||||||
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
assert db_entity.id == entity.id
|
assert db_entity.id == entity.id
|
||||||
assert db_entity.chat_scene == "chat_data"
|
assert db_entity.chat_scene == "chat_data"
|
||||||
assert db_entity.sub_chat_scene == "excel"
|
assert db_entity.sub_chat_scene == "excel"
|
||||||
@@ -63,78 +65,96 @@ def test_entity_create(default_entity_dict):
|
|||||||
|
|
||||||
|
|
||||||
def test_entity_unique_key(default_entity_dict):
|
def test_entity_unique_key(default_entity_dict):
|
||||||
ServeEntity.create(**default_entity_dict)
|
with db.session() as session:
|
||||||
|
entity = ServeEntity(**default_entity_dict)
|
||||||
|
session.add(entity)
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
ServeEntity.create(
|
with db.session() as session:
|
||||||
**{
|
entity = ServeEntity(
|
||||||
"prompt_name": "my_prompt_1",
|
**{
|
||||||
"sys_code": "dbgpt",
|
"prompt_name": "my_prompt_1",
|
||||||
"prompt_language": "zh",
|
"sys_code": "dbgpt",
|
||||||
"model": "vicuna-13b-v1.5",
|
"prompt_language": "zh",
|
||||||
}
|
"model": "vicuna-13b-v1.5",
|
||||||
)
|
}
|
||||||
|
)
|
||||||
|
session.add(entity)
|
||||||
|
|
||||||
|
|
||||||
def test_entity_get(default_entity_dict):
|
def test_entity_get(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
entity = ServeEntity(**default_entity_dict)
|
||||||
assert db_entity.id == entity.id
|
session.add(entity)
|
||||||
assert db_entity.chat_scene == "chat_data"
|
session.commit()
|
||||||
assert db_entity.sub_chat_scene == "excel"
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
assert db_entity.prompt_type == "common"
|
assert db_entity.id == entity.id
|
||||||
assert db_entity.prompt_name == "my_prompt_1"
|
assert db_entity.chat_scene == "chat_data"
|
||||||
assert db_entity.content == "Write a qsort function in python."
|
assert db_entity.sub_chat_scene == "excel"
|
||||||
assert db_entity.user_name == "zhangsan"
|
assert db_entity.prompt_type == "common"
|
||||||
assert db_entity.sys_code == "dbgpt"
|
assert db_entity.prompt_name == "my_prompt_1"
|
||||||
assert db_entity.gmt_created is not None
|
assert db_entity.content == "Write a qsort function in python."
|
||||||
assert db_entity.gmt_modified is not None
|
assert db_entity.user_name == "zhangsan"
|
||||||
|
assert db_entity.sys_code == "dbgpt"
|
||||||
|
assert db_entity.gmt_created is not None
|
||||||
|
assert db_entity.gmt_modified is not None
|
||||||
|
|
||||||
|
|
||||||
def test_entity_update(default_entity_dict):
|
def test_entity_update(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
with db.session() as session:
|
||||||
entity.update(prompt_name="my_prompt_2")
|
entity = ServeEntity(**default_entity_dict)
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
session.add(entity)
|
||||||
assert db_entity.id == entity.id
|
session.commit()
|
||||||
assert db_entity.chat_scene == "chat_data"
|
entity.prompt_name = "my_prompt_2"
|
||||||
assert db_entity.sub_chat_scene == "excel"
|
session.merge(entity)
|
||||||
assert db_entity.prompt_type == "common"
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
assert db_entity.prompt_name == "my_prompt_2"
|
assert db_entity.id == entity.id
|
||||||
assert db_entity.content == "Write a qsort function in python."
|
assert db_entity.chat_scene == "chat_data"
|
||||||
assert db_entity.user_name == "zhangsan"
|
assert db_entity.sub_chat_scene == "excel"
|
||||||
assert db_entity.sys_code == "dbgpt"
|
assert db_entity.prompt_type == "common"
|
||||||
assert db_entity.gmt_created is not None
|
assert db_entity.prompt_name == "my_prompt_2"
|
||||||
assert db_entity.gmt_modified is not None
|
assert db_entity.content == "Write a qsort function in python."
|
||||||
|
assert db_entity.user_name == "zhangsan"
|
||||||
|
assert db_entity.sys_code == "dbgpt"
|
||||||
|
assert db_entity.gmt_created is not None
|
||||||
|
assert db_entity.gmt_modified is not None
|
||||||
|
|
||||||
|
|
||||||
def test_entity_delete(default_entity_dict):
|
def test_entity_delete(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
with db.session() as session:
|
||||||
entity.delete()
|
entity = ServeEntity(**default_entity_dict)
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
session.add(entity)
|
||||||
assert db_entity is None
|
session.commit()
|
||||||
|
session.delete(entity)
|
||||||
|
session.commit()
|
||||||
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
|
assert db_entity is None
|
||||||
|
|
||||||
|
|
||||||
def test_entity_all():
|
def test_entity_all():
|
||||||
for i in range(10):
|
with db.session() as session:
|
||||||
ServeEntity.create(
|
for i in range(10):
|
||||||
chat_scene="chat_data",
|
entity = ServeEntity(
|
||||||
sub_chat_scene="excel",
|
chat_scene="chat_data",
|
||||||
prompt_type="common",
|
sub_chat_scene="excel",
|
||||||
prompt_name=f"my_prompt_{i}",
|
prompt_type="common",
|
||||||
content="Write a qsort function in python.",
|
prompt_name=f"my_prompt_{i}",
|
||||||
user_name="zhangsan",
|
content="Write a qsort function in python.",
|
||||||
sys_code="dbgpt",
|
user_name="zhangsan",
|
||||||
)
|
sys_code="dbgpt",
|
||||||
entities = ServeEntity.all()
|
)
|
||||||
assert len(entities) == 10
|
session.add(entity)
|
||||||
for entity in entities:
|
with db.session() as session:
|
||||||
assert entity.chat_scene == "chat_data"
|
entities = session.query(ServeEntity).all()
|
||||||
assert entity.sub_chat_scene == "excel"
|
assert len(entities) == 10
|
||||||
assert entity.prompt_type == "common"
|
for entity in entities:
|
||||||
assert entity.content == "Write a qsort function in python."
|
assert entity.chat_scene == "chat_data"
|
||||||
assert entity.user_name == "zhangsan"
|
assert entity.sub_chat_scene == "excel"
|
||||||
assert entity.sys_code == "dbgpt"
|
assert entity.prompt_type == "common"
|
||||||
assert entity.gmt_created is not None
|
assert entity.content == "Write a qsort function in python."
|
||||||
assert entity.gmt_modified is not None
|
assert entity.user_name == "zhangsan"
|
||||||
|
assert entity.sys_code == "dbgpt"
|
||||||
|
assert entity.gmt_created is not None
|
||||||
|
assert entity.gmt_modified is not None
|
||||||
|
|
||||||
|
|
||||||
def test_dao_create(dao, default_entity_dict):
|
def test_dao_create(dao, default_entity_dict):
|
||||||
|
@@ -75,7 +75,7 @@ def test_config_default_user(service: Service):
|
|||||||
def test_service_create(service: Service, default_entity_dict):
|
def test_service_create(service: Service, default_entity_dict):
|
||||||
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
assert db_entity.id == entity.id
|
assert db_entity.id == entity.id
|
||||||
assert db_entity.chat_scene == "chat_data"
|
assert db_entity.chat_scene == "chat_data"
|
||||||
assert db_entity.sub_chat_scene == "excel"
|
assert db_entity.sub_chat_scene == "excel"
|
||||||
@@ -92,7 +92,7 @@ def test_service_update(service: Service, default_entity_dict):
|
|||||||
service.create(ServeRequest(**default_entity_dict))
|
service.create(ServeRequest(**default_entity_dict))
|
||||||
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
|
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
assert db_entity.id == entity.id
|
assert db_entity.id == entity.id
|
||||||
assert db_entity.chat_scene == "chat_data"
|
assert db_entity.chat_scene == "chat_data"
|
||||||
assert db_entity.sub_chat_scene == "excel"
|
assert db_entity.sub_chat_scene == "excel"
|
||||||
@@ -109,7 +109,7 @@ def test_service_get(service: Service, default_entity_dict):
|
|||||||
service.create(ServeRequest(**default_entity_dict))
|
service.create(ServeRequest(**default_entity_dict))
|
||||||
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||||
assert db_entity.id == entity.id
|
assert db_entity.id == entity.id
|
||||||
assert db_entity.chat_scene == "chat_data"
|
assert db_entity.chat_scene == "chat_data"
|
||||||
assert db_entity.sub_chat_scene == "excel"
|
assert db_entity.sub_chat_scene == "excel"
|
||||||
|
@@ -1,5 +1,3 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from dbgpt.storage.metadata import db
|
from dbgpt.storage.metadata import db
|
||||||
@@ -39,11 +37,9 @@ def test_table_exist():
|
|||||||
|
|
||||||
|
|
||||||
def test_entity_create(default_entity_dict):
|
def test_entity_create(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
|
||||||
# TODO: implement your test case
|
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
entity = ServeEntity(**default_entity_dict)
|
||||||
assert db_entity.id == entity.id
|
session.add(entity)
|
||||||
|
|
||||||
|
|
||||||
def test_entity_unique_key(default_entity_dict):
|
def test_entity_unique_key(default_entity_dict):
|
||||||
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
|
|||||||
|
|
||||||
|
|
||||||
def test_entity_get(default_entity_dict):
|
def test_entity_get(default_entity_dict):
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
|
||||||
assert db_entity.id == entity.id
|
|
||||||
# TODO: implement your test case
|
# TODO: implement your test case
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_entity_update(default_entity_dict):
|
def test_entity_update(default_entity_dict):
|
||||||
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
|
|||||||
|
|
||||||
def test_entity_delete(default_entity_dict):
|
def test_entity_delete(default_entity_dict):
|
||||||
# TODO: implement your test case
|
# TODO: implement your test case
|
||||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
pass
|
||||||
entity.delete()
|
|
||||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
|
||||||
assert db_entity is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_entity_all():
|
def test_entity_all():
|
||||||
|
@@ -105,12 +105,6 @@ class ChatHistoryDao(BaseDao):
|
|||||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
||||||
chat_history.delete()
|
chat_history.delete()
|
||||||
|
|
||||||
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
|
def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]:
|
||||||
# return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first()
|
with self.session(commit=False) as session:
|
||||||
|
return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first()
|
||||||
session = self.get_raw_session()
|
|
||||||
chat_history = session.query(ChatHistoryEntity)
|
|
||||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
|
||||||
result = chat_history.first()
|
|
||||||
session.close()
|
|
||||||
return result
|
|
||||||
|
@@ -51,7 +51,9 @@ class BaseDao(Generic[T, REQ, RES]):
|
|||||||
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
user = User(name="Edward Snowden")
|
user = User(name="Edward Snowden")
|
||||||
session = self.get_raw_session()
|
session = self.get_raw_session()
|
||||||
session.add(user)
|
session.add(user)
|
||||||
@@ -61,7 +63,7 @@ class BaseDao(Generic[T, REQ, RES]):
|
|||||||
return self._db_manager._session()
|
return self._db_manager._session()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def session(self) -> Session:
|
def session(self, commit: Optional[bool] = True) -> Session:
|
||||||
"""Provide a transactional scope around a series of operations.
|
"""Provide a transactional scope around a series of operations.
|
||||||
|
|
||||||
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
|
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
|
||||||
@@ -71,13 +73,16 @@ class BaseDao(Generic[T, REQ, RES]):
|
|||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
session.query(User).filter(User.name == 'Edward Snowden').first()
|
session.query(User).filter(User.name == 'Edward Snowden').first()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Session: A session object.
|
Session: A session object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: Any exception will be raised.
|
Exception: Any exception will be raised.
|
||||||
"""
|
"""
|
||||||
with self._db_manager.session() as session:
|
with self._db_manager.session(commit=commit) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
def from_request(self, request: QUERY_SPEC) -> T:
|
def from_request(self, request: QUERY_SPEC) -> T:
|
||||||
|
@@ -1,8 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
|
from typing import (
|
||||||
|
TypeVar,
|
||||||
|
Generic,
|
||||||
|
Union,
|
||||||
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
ClassVar,
|
||||||
|
)
|
||||||
import logging
|
import logging
|
||||||
from sqlalchemy import create_engine, URL, Engine
|
from sqlalchemy import create_engine, URL, Engine
|
||||||
from sqlalchemy import orm, inspect, MetaData
|
from sqlalchemy import orm, inspect, MetaData
|
||||||
@@ -13,8 +20,6 @@ from sqlalchemy.orm import (
|
|||||||
declarative_base,
|
declarative_base,
|
||||||
DeclarativeMeta,
|
DeclarativeMeta,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm.session import _PKIdentityArgument
|
|
||||||
from sqlalchemy.orm.exc import UnmappedClassError
|
|
||||||
|
|
||||||
from sqlalchemy.pool import QueuePool
|
from sqlalchemy.pool import QueuePool
|
||||||
from dbgpt.util.string_utils import _to_str
|
from dbgpt.util.string_utils import _to_str
|
||||||
@@ -27,16 +32,10 @@ T = TypeVar("T", bound="BaseModel")
|
|||||||
class _QueryObject:
|
class _QueryObject:
|
||||||
"""The query object."""
|
"""The query object."""
|
||||||
|
|
||||||
def __init__(self, db_manager: "DatabaseManager"):
|
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||||
self._db_manager = db_manager
|
return model_cls.query_class(
|
||||||
|
model_cls, session=model_cls.__db_manager__._session()
|
||||||
def __get__(self, obj, type):
|
)
|
||||||
try:
|
|
||||||
mapper = orm.class_mapper(type)
|
|
||||||
if mapper:
|
|
||||||
return type.query_class(mapper, session=self._db_manager._session())
|
|
||||||
except UnmappedClassError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class BaseQuery(orm.Query):
|
class BaseQuery(orm.Query):
|
||||||
@@ -46,7 +45,9 @@ class BaseQuery(orm.Query):
|
|||||||
"""Paginate the query.
|
"""Paginate the query.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from dbgpt.storage.metadata import db, Model
|
from dbgpt.storage.metadata import db, Model
|
||||||
class User(Model):
|
class User(Model):
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
@@ -58,10 +59,6 @@ class BaseQuery(orm.Query):
|
|||||||
pagination = session.query(User).paginate_query(page=1, page_size=10)
|
pagination = session.query(User).paginate_query(page=1, page_size=10)
|
||||||
print(pagination)
|
print(pagination)
|
||||||
|
|
||||||
# Or you can use the query object
|
|
||||||
with db.session() as session:
|
|
||||||
pagination = User.query.paginate_query(page=1, page_size=10)
|
|
||||||
print(pagination)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
page (Optional[int], optional): The page number. Defaults to 1.
|
page (Optional[int], optional): The page number. Defaults to 1.
|
||||||
@@ -86,26 +83,12 @@ class BaseQuery(orm.Query):
|
|||||||
|
|
||||||
|
|
||||||
class _Model:
|
class _Model:
|
||||||
"""Base class for SQLAlchemy declarative base model.
|
"""Base class for SQLAlchemy declarative base model."""
|
||||||
|
|
||||||
With this class, we can use the query object to query the database.
|
__db_manager__: ClassVar[DatabaseManager]
|
||||||
|
query_class = BaseQuery
|
||||||
|
|
||||||
Examples:
|
# query: Optional[BaseQuery] = _QueryObject()
|
||||||
.. code-block:: python
|
|
||||||
from dbgpt.storage.metadata import db, Model
|
|
||||||
class User(Model):
|
|
||||||
__tablename__ = "user"
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
name = Column(String(50))
|
|
||||||
fullname = Column(String(50))
|
|
||||||
|
|
||||||
with db.session() as session:
|
|
||||||
# User is an instance of _Model, and we can use the query object to query the database.
|
|
||||||
User.query.filter(User.name == "test").all()
|
|
||||||
"""
|
|
||||||
|
|
||||||
query_class = None
|
|
||||||
query: Optional[BaseQuery] = None
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
identity = inspect(self).identity
|
identity = inspect(self).identity
|
||||||
@@ -120,7 +103,9 @@ class DatabaseManager:
|
|||||||
"""The database manager.
|
"""The database manager.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from urllib.parse import quote_plus as urlquote, quote
|
from urllib.parse import quote_plus as urlquote, quote
|
||||||
from dbgpt.storage.metadata import DatabaseManager, create_model
|
from dbgpt.storage.metadata import DatabaseManager, create_model
|
||||||
db = DatabaseManager()
|
db = DatabaseManager()
|
||||||
@@ -141,21 +126,25 @@ class DatabaseManager:
|
|||||||
session.add(User(name="test", fullname="test"))
|
session.add(User(name="test", fullname="test"))
|
||||||
# db will commit the session automatically default.
|
# db will commit the session automatically default.
|
||||||
# session.commit()
|
# session.commit()
|
||||||
print(User.query.filter(User.name == "test").all())
|
assert session.query(User).filter(User.name == "test").first().name == "test"
|
||||||
|
|
||||||
|
|
||||||
# Use CURDMixin APIs to create, update, delete, query the database.
|
# More usage:
|
||||||
|
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
User.create(**{"name": "test1", "fullname": "test1"})
|
session.add(User(name="test1", fullname="test1"))
|
||||||
User.create(**{"name": "test2", "fullname": "test1"})
|
session.add(User(name="test2", fullname="test1"))
|
||||||
users = User.all()
|
users = session.query(User).all()
|
||||||
print(users)
|
print(users)
|
||||||
user = users[0]
|
user = users[0]
|
||||||
user.update(**{"name": "test1_1111"})
|
user.name = "test1_1111"
|
||||||
|
session.merge(user)
|
||||||
|
|
||||||
user2 = users[1]
|
user2 = users[1]
|
||||||
# Update user2 by save
|
# Update user2 by save
|
||||||
user2.name = "test2_1111"
|
user2.name = "test2_1111"
|
||||||
user2.save()
|
session.merge(user2)
|
||||||
|
session.commit()
|
||||||
# Delete user2
|
# Delete user2
|
||||||
user2.delete()
|
user2.delete()
|
||||||
"""
|
"""
|
||||||
@@ -189,28 +178,65 @@ class DatabaseManager:
|
|||||||
return self._engine is not None and self._session is not None
|
return self._engine is not None and self._session is not None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def session(self) -> Session:
|
def session(self, commit: Optional[bool] = True) -> Session:
|
||||||
"""Get the session with context manager.
|
"""Get the session with context manager.
|
||||||
|
|
||||||
If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically.
|
This context manager handles the lifecycle of a SQLAlchemy session.
|
||||||
|
It automatically commits or rolls back transactions based on
|
||||||
|
the execution and handles session closure.
|
||||||
|
|
||||||
Example:
|
The `commit` parameter controls whether the session should commit
|
||||||
>>> with db.session() as session:
|
changes at the end of the block. This is useful for separating
|
||||||
>>> session.query(...)
|
read and write operations.
|
||||||
|
|
||||||
Returns:
|
Examples:
|
||||||
Session: The session.
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# For write operations (insert, update, delete):
|
||||||
|
with db.session() as session:
|
||||||
|
user = User(name="John Doe")
|
||||||
|
session.add(user)
|
||||||
|
# session.commit() is called automatically
|
||||||
|
|
||||||
|
# For read-only operations:
|
||||||
|
with db.session(commit=False) as session:
|
||||||
|
user = session.query(User).filter_by(name="John Doe").first()
|
||||||
|
# session.commit() is NOT called, as it's unnecessary for read operations
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commit (Optional[bool], optional): Whether to commit the session.
|
||||||
|
If True (default), the session will commit changes at the end
|
||||||
|
of the block. Use False for read-only operations or when manual
|
||||||
|
control over commit is needed. Defaults to True.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Session: The SQLAlchemy session object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: The database manager is not initialized.
|
RuntimeError: Raised if the database manager is not initialized.
|
||||||
Exception: Any exception.
|
Exception: Propagates any exception that occurred within the block.
|
||||||
|
|
||||||
|
Important Notes:
|
||||||
|
- DetachedInstanceError: This error occurs when trying to access or
|
||||||
|
modify an instance that has been detached from its session.
|
||||||
|
DetachedInstanceError can occur in scenarios where the session is
|
||||||
|
closed, and further interaction with the ORM object is attempted,
|
||||||
|
especially when accessing lazy-loaded attributes. To avoid this:
|
||||||
|
a. Ensure required attributes are loaded before session closure.
|
||||||
|
b. Avoid closing the session before all necessary interactions
|
||||||
|
with the ORM object are complete.
|
||||||
|
c. Re-bind the instance to a new session if further interaction
|
||||||
|
is required after the session is closed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not self.is_initialized:
|
if not self.is_initialized:
|
||||||
raise RuntimeError("The database manager is not initialized.")
|
raise RuntimeError("The database manager is not initialized.")
|
||||||
session = self._session()
|
session = self._session()
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
session.commit()
|
if commit:
|
||||||
|
session.commit()
|
||||||
except:
|
except:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
raise
|
raise
|
||||||
@@ -223,7 +249,7 @@ class DatabaseManager:
|
|||||||
"""Make the declarative base.
|
"""Make the declarative base.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
base (DeclarativeMeta): The base class.
|
model (DeclarativeMeta): The base class.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DeclarativeMeta: The declarative base.
|
DeclarativeMeta: The declarative base.
|
||||||
@@ -232,7 +258,8 @@ class DatabaseManager:
|
|||||||
model = declarative_base(cls=model, name="Model")
|
model = declarative_base(cls=model, name="Model")
|
||||||
if not getattr(model, "query_class", None):
|
if not getattr(model, "query_class", None):
|
||||||
model.query_class = self.Query
|
model.query_class = self.Query
|
||||||
model.query = _QueryObject(self)
|
# model.query = _QueryObject()
|
||||||
|
model.__db_manager__ = self
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def init_db(
|
def init_db(
|
||||||
@@ -242,6 +269,7 @@ class DatabaseManager:
|
|||||||
base: Optional[DeclarativeMeta] = None,
|
base: Optional[DeclarativeMeta] = None,
|
||||||
query_class=BaseQuery,
|
query_class=BaseQuery,
|
||||||
override_query_class: Optional[bool] = False,
|
override_query_class: Optional[bool] = False,
|
||||||
|
session_options: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the database manager.
|
"""Initialize the database manager.
|
||||||
|
|
||||||
@@ -251,18 +279,26 @@ class DatabaseManager:
|
|||||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||||
|
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
if session_options is None:
|
||||||
|
session_options = {}
|
||||||
self._db_url = db_url
|
self._db_url = db_url
|
||||||
if query_class is not None:
|
if query_class is not None:
|
||||||
self.Query = query_class
|
self.Query = query_class
|
||||||
if base is not None:
|
if base is not None:
|
||||||
self._base = base
|
self._base = base
|
||||||
if not hasattr(base, "query") or override_query_class:
|
# if not hasattr(base, "query") or override_query_class:
|
||||||
base.query = _QueryObject(self)
|
# base.query = _QueryObject()
|
||||||
if not getattr(base, "query_class", None) or override_query_class:
|
if not getattr(base, "query_class", None) or override_query_class:
|
||||||
base.query_class = self.Query
|
base.query_class = self.Query
|
||||||
|
if not hasattr(base, "__db_manager__") or override_query_class:
|
||||||
|
base.__db_manager__ = self
|
||||||
self._engine = create_engine(db_url, **(engine_args or {}))
|
self._engine = create_engine(db_url, **(engine_args or {}))
|
||||||
session_factory = sessionmaker(bind=self._engine)
|
|
||||||
|
session_options.setdefault("class_", Session)
|
||||||
|
session_options.setdefault("query_cls", self.Query)
|
||||||
|
session_factory = sessionmaker(bind=self._engine, **session_options)
|
||||||
self._session = scoped_session(session_factory)
|
self._session = scoped_session(session_factory)
|
||||||
self._base.metadata.bind = self._engine
|
self._base.metadata.bind = self._engine
|
||||||
|
|
||||||
@@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]):
|
|||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls: Type[T], **kwargs) -> T:
|
def db(cls) -> DatabaseManager:
|
||||||
instance = cls(**kwargs)
|
"""Get the database manager."""
|
||||||
return instance.save()
|
return cls.__db_manager__
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def all(cls: Type[T]) -> List[T]:
|
|
||||||
return cls.query.all()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
|
||||||
"""Get a record by its primary key identifier."""
|
|
||||||
|
|
||||||
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
|
|
||||||
"""Update specific fields of a record."""
|
|
||||||
for attr, value in kwargs.items():
|
|
||||||
setattr(self, attr, value)
|
|
||||||
return commit and self.save() or self
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
|
||||||
"""Save the record."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
|
||||||
"""Remove the record from the database."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
|
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
|
||||||
|
|
||||||
"""The base model class that includes CRUD convenience methods."""
|
"""The base model class that includes CRUD convenience methods."""
|
||||||
|
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
@@ -438,28 +451,14 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
|||||||
_db_manager: DatabaseManager = db_manager
|
_db_manager: DatabaseManager = db_manager
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_db_manager(cls, db_manager: DatabaseManager):
|
def set_db(cls, db_manager: DatabaseManager):
|
||||||
# TODO: It is hard to replace to user DB Connection
|
# TODO: It is hard to replace to user DB Connection
|
||||||
cls._db_manager = db_manager
|
cls._db_manager = db_manager
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
def db(cls) -> DatabaseManager:
|
||||||
"""Get a record by its primary key identifier."""
|
"""Get the database manager."""
|
||||||
return cls._db_manager._session().get(cls, ident)
|
return cls._db_manager
|
||||||
|
|
||||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
|
||||||
"""Save the record."""
|
|
||||||
session = self._db_manager._session()
|
|
||||||
session.add(self)
|
|
||||||
if commit:
|
|
||||||
session.commit()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
|
||||||
"""Remove the record from the database."""
|
|
||||||
session = self._db_manager._session()
|
|
||||||
session.delete(self)
|
|
||||||
return commit and session.commit()
|
|
||||||
|
|
||||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
|
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
|
||||||
"""Base model class that includes CRUD convenience methods."""
|
"""Base model class that includes CRUD convenience methods."""
|
||||||
@@ -478,6 +477,7 @@ def initialize_db(
|
|||||||
engine_args: Optional[Dict] = None,
|
engine_args: Optional[Dict] = None,
|
||||||
base: Optional[DeclarativeMeta] = None,
|
base: Optional[DeclarativeMeta] = None,
|
||||||
try_to_create_db: Optional[bool] = False,
|
try_to_create_db: Optional[bool] = False,
|
||||||
|
session_options: Optional[Dict] = None,
|
||||||
) -> DatabaseManager:
|
) -> DatabaseManager:
|
||||||
"""Initialize the database manager.
|
"""Initialize the database manager.
|
||||||
|
|
||||||
@@ -487,10 +487,11 @@ def initialize_db(
|
|||||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||||
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
|
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
|
||||||
|
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||||
Returns:
|
Returns:
|
||||||
DatabaseManager: The database manager.
|
DatabaseManager: The database manager.
|
||||||
"""
|
"""
|
||||||
db.init_db(db_url, engine_args, base)
|
db.init_db(db_url, engine_args, base, session_options=session_options)
|
||||||
if try_to_create_db:
|
if try_to_create_db:
|
||||||
try:
|
try:
|
||||||
db.create_all()
|
db.create_all()
|
||||||
|
@@ -100,7 +100,7 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_
|
|||||||
|
|
||||||
# Verify that the user is updated in the database
|
# Verify that the user is updated in the database
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
user = session.query(User).get(created_user_response.id)
|
user = session.get(User, created_user_response.id)
|
||||||
assert user.age == 35
|
assert user.age == 35
|
||||||
|
|
||||||
|
|
||||||
@@ -121,7 +121,7 @@ def test_update_user_partial(
|
|||||||
|
|
||||||
# Verify that the user is updated in the database
|
# Verify that the user is updated in the database
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
user = session.query(User).get(created_user_response.id)
|
user = session.get(User, created_user_response.id)
|
||||||
assert user.age == user_req.age
|
assert user.age == user_req.age
|
||||||
assert user.password == "newpassword"
|
assert user.password == "newpassword"
|
||||||
|
|
||||||
|
@@ -53,11 +53,10 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
|
|||||||
|
|
||||||
# Create
|
# Create
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
user = User.create(name="John Doe")
|
user = User(name="John Doe")
|
||||||
session.add(user)
|
session.add(user)
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Read
|
# # Read
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
user = session.query(User).filter_by(name="John Doe").first()
|
user = session.query(User).filter_by(name="John Doe").first()
|
||||||
assert user is not None
|
assert user is not None
|
||||||
@@ -65,12 +64,20 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
|
|||||||
# Update
|
# Update
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
user = session.query(User).filter_by(name="John Doe").first()
|
user = session.query(User).filter_by(name="John Doe").first()
|
||||||
user.update(name="Jane Doe")
|
user.name = "Mike Doe"
|
||||||
|
session.merge(user)
|
||||||
# Delete
|
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
user = session.query(User).filter_by(name="Jane Doe").first()
|
user = session.query(User).filter_by(name="Mike Doe").first()
|
||||||
user.delete()
|
assert user is not None
|
||||||
|
session.query(User).filter(User.name == "John Doe").first() is None
|
||||||
|
#
|
||||||
|
# # Delete
|
||||||
|
with db.session() as session:
|
||||||
|
user = session.query(User).filter_by(name="Mike Doe").first()
|
||||||
|
session.delete(user)
|
||||||
|
|
||||||
|
with db.session() as session:
|
||||||
|
assert len(session.query(User).all()) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
|
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
|
||||||
@@ -80,20 +87,7 @@ def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
|
|||||||
name = Column(String(50))
|
name = Column(String(50))
|
||||||
|
|
||||||
db.create_all()
|
db.create_all()
|
||||||
|
User.db() == db
|
||||||
# Create
|
|
||||||
user = User.create(name="John Doe")
|
|
||||||
assert User.get(user.id) is not None
|
|
||||||
users = User.all()
|
|
||||||
assert len(users) == 1
|
|
||||||
|
|
||||||
# Update
|
|
||||||
user.update(name="Bob Doe")
|
|
||||||
assert User.get(user.id).name == "Bob Doe"
|
|
||||||
|
|
||||||
user = User.get(user.id)
|
|
||||||
user.delete()
|
|
||||||
assert User.get(user.id) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
|
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
|
||||||
@@ -108,11 +102,10 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
|
|||||||
for i in range(30):
|
for i in range(30):
|
||||||
user = User(name=f"User {i}")
|
user = User(name=f"User {i}")
|
||||||
session.add(user)
|
session.add(user)
|
||||||
session.commit()
|
with db.session() as session:
|
||||||
|
users_page_1 = session.query(User).paginate_query(page=1, per_page=10)
|
||||||
users_page_1 = User.query.paginate_query(page=1, per_page=10)
|
assert len(users_page_1.items) == 10
|
||||||
assert len(users_page_1.items) == 10
|
assert users_page_1.total_pages == 3
|
||||||
assert users_page_1.total_pages == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
|
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
|
||||||
@@ -124,9 +117,11 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
|
|||||||
db.create_all()
|
db.create_all()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
User.query.paginate_query(page=0, per_page=10)
|
with db.session() as session:
|
||||||
|
session.query(User).paginate_query(page=0, per_page=10)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
User.query.paginate_query(page=1, per_page=-1)
|
with db.session() as session:
|
||||||
|
session.query(User).paginate_query(page=1, per_page=-1)
|
||||||
|
|
||||||
|
|
||||||
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
|
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
|
||||||
@@ -142,14 +137,19 @@ def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
|
|||||||
new_db = DatabaseManager.build_from(
|
new_db = DatabaseManager.build_from(
|
||||||
f"sqlite:///{filename}", base=Model, override_query_class=True
|
f"sqlite:///{filename}", base=Model, override_query_class=True
|
||||||
)
|
)
|
||||||
Model.set_db_manager(new_db)
|
Model.set_db(new_db)
|
||||||
new_db.create_all()
|
new_db.create_all()
|
||||||
db.create_all()
|
db.create_all()
|
||||||
assert list(new_db.metadata.tables.keys())[0] == "user"
|
assert list(new_db.metadata.tables.keys())[0] == "user"
|
||||||
User.create(**{"name": "John Doe"})
|
with new_db.session() as session:
|
||||||
|
user = User(name="John Doe")
|
||||||
|
session.add(user)
|
||||||
with new_db.session() as session:
|
with new_db.session() as session:
|
||||||
assert session.query(User).filter_by(name="John Doe").first() is not None
|
assert session.query(User).filter_by(name="John Doe").first() is not None
|
||||||
with db.session() as session:
|
with db.session() as session:
|
||||||
assert session.query(User).filter_by(name="John Doe").first() is None
|
assert session.query(User).filter_by(name="John Doe").first() is None
|
||||||
assert len(User.query.all()) == 1
|
with new_db.session() as session:
|
||||||
assert User.query.filter(User.name == "John Doe").first().name == "John Doe"
|
session.query(User).all() == 1
|
||||||
|
session.query(User).filter(
|
||||||
|
User.name == "John Doe"
|
||||||
|
).first().name == "John Doe"
|
||||||
|
@@ -7,7 +7,7 @@
|
|||||||
Call with non-streaming response.
|
Call with non-streaming response.
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
DBGPT_SERVER="http://127.0.0.1:5000"
|
DBGPT_SERVER="http://127.0.0.1:5555"
|
||||||
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
|
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
|
||||||
-H "Content-Type: application/json" -d '{
|
-H "Content-Type: application/json" -d '{
|
||||||
"model": "proxyllm",
|
"model": "proxyllm",
|
||||||
|
Reference in New Issue
Block a user