mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
feat(model): Proxy model support count token (#996)
This commit is contained in:
@@ -189,7 +189,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
return output
|
||||
|
||||
def count_token(self, prompt: str) -> int:
|
||||
return _try_to_count_token(prompt, self.tokenizer)
|
||||
return _try_to_count_token(prompt, self.tokenizer, self.model)
|
||||
|
||||
async def async_count_token(self, prompt: str) -> int:
|
||||
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
|
||||
@@ -454,12 +454,13 @@ def _new_metrics_from_model_output(
|
||||
return metrics
|
||||
|
||||
|
||||
def _try_to_count_token(prompt: str, tokenizer) -> int:
|
||||
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
|
||||
"""Try to count token of prompt
|
||||
|
||||
Args:
|
||||
prompt (str): prompt
|
||||
tokenizer ([type]): tokenizer
|
||||
model ([type]): model
|
||||
|
||||
Returns:
|
||||
int: token count, if error return -1
|
||||
@@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int:
|
||||
TODO: More implementation
|
||||
"""
|
||||
try:
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||
|
||||
if isinstance(model, ProxyModel):
|
||||
return model.count_token(prompt)
|
||||
# Only support huggingface model now
|
||||
return len(tokenizer(prompt).input_ids[0])
|
||||
except Exception as e:
|
||||
logger.warning(f"Count token error, detail: {e}, return -1")
|
||||
|
@@ -197,7 +197,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
return True
|
||||
else:
|
||||
# TODO Update worker
|
||||
logger.warn(f"Instance {worker_key} exist")
|
||||
logger.warning(f"Instance {worker_key} exist")
|
||||
return False
|
||||
|
||||
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
|
||||
@@ -229,7 +229,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
)
|
||||
if not success:
|
||||
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
|
||||
logger.warn(f"{msg}, worker_params: {worker_params}")
|
||||
logger.warning(f"{msg}, worker_params: {worker_params}")
|
||||
self._remove_worker(worker_params)
|
||||
raise Exception(msg)
|
||||
supported_types = WorkerType.values()
|
||||
|
@@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
||||
|
||||
|
||||
def __convert_2_gpt_messages(messages: List[ModelMessage]):
|
||||
chat_round = 0
|
||||
gpt_messages = []
|
||||
last_usr_message = ""
|
||||
system_messages = []
|
||||
|
||||
# TODO: We can't change message order in low level
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
|
||||
last_usr_message = message.content
|
||||
|
@@ -1,9 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
import logging
|
||||
from dbgpt.model.parameter import ProxyModelParameters
|
||||
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.core.interface.message import ModelMessage, BaseMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyModel:
    """Holds the parameters of a proxy model and counts tokens for it.

    Token counting is delegated to a ``ProxyTokenizerWrapper`` created in the
    constructor.
    """

    def __init__(self, model_params: "ProxyModelParameters") -> None:
        """Store the model parameters and create the tokenizer wrapper.

        Args:
            model_params (ProxyModelParameters): parameters of the proxy model
        """
        self._model_params = model_params
        self._tokenizer = ProxyTokenizerWrapper()

    def get_params(self) -> "ProxyModelParameters":
        """Return the parameters this proxy model was created with."""
        return self._model_params

    def count_token(
        self,
        # Quoted: the message classes are only imported under TYPE_CHECKING.
        messages: "Union[str, BaseMessage, ModelMessage, List[ModelMessage]]",
        model_name: Optional[str] = None,
    ) -> int:
        """Count token of given messages

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
            model_name (Optional[str], optional): model name, forwarded unchanged
                to the tokenizer wrapper. Defaults to None.

        Returns:
            int: token count, -1 if failed
        """
        return self._tokenizer.count_token(messages, model_name)
|
||||
|
@@ -25,6 +25,7 @@ from dbgpt.core.interface.llm import ModelOutput, ModelRequest
|
||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt._private.pydantic import model_to_json
|
||||
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
@@ -152,6 +153,7 @@ class OpenAILLMClient(LLMClient):
|
||||
self._context_length = context_length
|
||||
self._client = openai_client
|
||||
self._openai_kwargs = openai_kwargs or {}
|
||||
self._tokenizer = ProxyTokenizerWrapper()
|
||||
|
||||
@property
|
||||
def client(self) -> ClientType:
|
||||
@@ -238,10 +240,11 @@ class OpenAILLMClient(LLMClient):
|
||||
async def count_token(self, model: str, prompt: str) -> int:
|
||||
"""Count the number of tokens in a given prompt.
|
||||
|
||||
TODO: Get the real number of tokens from the openai api or tiktoken package
|
||||
Args:
|
||||
model (str): The model name.
|
||||
prompt (str): The prompt.
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
return self._tokenizer.count_token(prompt, model)
|
||||
|
||||
|
||||
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
||||
|
80
dbgpt/model/utils/token_utils.py
Normal file
80
dbgpt/model/utils/token_utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.core.interface.message import ModelMessage, BaseMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyTokenizerWrapper:
    """Count tokens for proxy models with a lazily created ``tiktoken`` encoding.

    The encoding is created on the first call and cached on the instance. If
    ``tiktoken`` cannot be imported, counting is disabled for this instance and
    every call returns -1.

    NOTE: only one encoding is cached per instance. The first encoding that is
    created is reused for every later call, whatever ``model_name`` is passed.
    """

    def __init__(self) -> None:
        # Set to False once importing tiktoken has failed, so later calls
        # return -1 immediately instead of retrying the import.
        self._support_encoding = True
        # Cached encoding object; created on the first successful call.
        self._encoding_model = None

    def count_token(
        self,
        # Quoted: the message classes are only imported under TYPE_CHECKING,
        # so they must not be evaluated at runtime.
        messages: "Union[str, BaseMessage, ModelMessage, List[ModelMessage]]",
        model_name: Optional[str] = None,
    ) -> int:
        """Count token of given messages

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token.
                A string is encoded directly; a list contributes the sum over
                each item's ``content``; any other object is counted by its
                ``content`` attribute.
            model_name (Optional[str], optional): model name. Defaults to None.

        Returns:
            int: token count, -1 if failed
        """
        if not self._support_encoding:
            logger.warning(
                "model does not support encoding model, can't count token, returning -1"
            )
            return -1
        encoding = self._get_or_create_encoding_model(model_name)
        if encoding is None:
            # tiktoken is unavailable; the helper has already logged the reason.
            return -1

        def token_len(text: str) -> int:
            # disallowed_special=() makes special-token text count as plain
            # text instead of raising.
            return len(encoding.encode(text, disallowed_special=()))

        if isinstance(messages, str):
            return token_len(messages)
        if isinstance(messages, list):
            return sum(token_len(message.content) for message in messages)
        # The message classes are not importable here at runtime, so single
        # messages are recognised by their ``content`` attribute rather than
        # by an isinstance check.
        if hasattr(messages, "content"):
            return token_len(messages.content)
        logger.warning("unsupported type of messages, can't count token, returning -1")
        return -1

    def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
        """Get or create encoding model for given model name

        More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

        Args:
            model_name (Optional[str], optional): model name used to look up the
                encoding; "gpt-3.5-turbo" is used when empty. Defaults to None.

        Returns:
            The cached or newly created encoding, or None when ``tiktoken`` is
            not installed (``_support_encoding`` is then set to False).
        """
        if self._encoding_model is not None:
            return self._encoding_model
        try:
            import tiktoken
        except ImportError:
            self._support_encoding = False
            logger.warning("tiktoken not installed, cannot count tokens, returning -1")
            return None
        logger.info(
            "tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
            "also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
        )
        if not model_name:
            model_name = "gpt-3.5-turbo"
        try:
            self._encoding_model = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken does not know this model name: fall back to a default.
            logger.warning(
                f"{model_name}'s tokenizer not found, using cl100k_base encoding."
            )
            self._encoding_model = tiktoken.get_encoding("cl100k_base")
        return self._encoding_model
|
@@ -1,5 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.metadata import db
|
||||
@@ -39,11 +37,9 @@ def test_table_exist():
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
# TODO: implement your test case
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
entity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
|
@@ -47,9 +47,11 @@ def test_table_exist():
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
entity: ServeEntity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
session.commit()
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
@@ -63,78 +65,96 @@ def test_entity_create(default_entity_dict):
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
ServeEntity.create(**default_entity_dict)
|
||||
with db.session() as session:
|
||||
entity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
with pytest.raises(Exception):
|
||||
ServeEntity.create(
|
||||
**{
|
||||
"prompt_name": "my_prompt_1",
|
||||
"sys_code": "dbgpt",
|
||||
"prompt_language": "zh",
|
||||
"model": "vicuna-13b-v1.5",
|
||||
}
|
||||
)
|
||||
with db.session() as session:
|
||||
entity = ServeEntity(
|
||||
**{
|
||||
"prompt_name": "my_prompt_1",
|
||||
"sys_code": "dbgpt",
|
||||
"prompt_language": "zh",
|
||||
"model": "vicuna-13b-v1.5",
|
||||
}
|
||||
)
|
||||
session.add(entity)
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
with db.session() as session:
|
||||
entity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
session.commit()
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.update(prompt_name="my_prompt_2")
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_2"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
with db.session() as session:
|
||||
entity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
session.commit()
|
||||
entity.prompt_name = "my_prompt_2"
|
||||
session.merge(entity)
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_2"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
with db.session() as session:
|
||||
entity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
session.commit()
|
||||
session.delete(entity)
|
||||
session.commit()
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity is None
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
for i in range(10):
|
||||
ServeEntity.create(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
entities = ServeEntity.all()
|
||||
assert len(entities) == 10
|
||||
for entity in entities:
|
||||
assert entity.chat_scene == "chat_data"
|
||||
assert entity.sub_chat_scene == "excel"
|
||||
assert entity.prompt_type == "common"
|
||||
assert entity.content == "Write a qsort function in python."
|
||||
assert entity.user_name == "zhangsan"
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.gmt_created is not None
|
||||
assert entity.gmt_modified is not None
|
||||
with db.session() as session:
|
||||
for i in range(10):
|
||||
entity = ServeEntity(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
session.add(entity)
|
||||
with db.session() as session:
|
||||
entities = session.query(ServeEntity).all()
|
||||
assert len(entities) == 10
|
||||
for entity in entities:
|
||||
assert entity.chat_scene == "chat_data"
|
||||
assert entity.sub_chat_scene == "excel"
|
||||
assert entity.prompt_type == "common"
|
||||
assert entity.content == "Write a qsort function in python."
|
||||
assert entity.user_name == "zhangsan"
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.gmt_created is not None
|
||||
assert entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_dao_create(dao, default_entity_dict):
|
||||
|
@@ -75,7 +75,7 @@ def test_config_default_user(service: Service):
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
@@ -92,7 +92,7 @@ def test_service_update(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
@@ -109,7 +109,7 @@ def test_service_get(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
db_entity: ServeEntity = session.get(ServeEntity, entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
|
@@ -1,5 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.metadata import db
|
||||
@@ -39,11 +37,9 @@ def test_table_exist():
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
# TODO: implement your test case
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
entity = ServeEntity(**default_entity_dict)
|
||||
session.add(entity)
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
@@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict):
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
@@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict):
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
|
@@ -105,12 +105,6 @@ class ChatHistoryDao(BaseDao):
|
||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
||||
chat_history.delete()
|
||||
|
||||
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
|
||||
# return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first()
|
||||
|
||||
session = self.get_raw_session()
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
||||
result = chat_history.first()
|
||||
session.close()
|
||||
return result
|
||||
def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]:
|
||||
with self.session(commit=False) as session:
|
||||
return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first()
|
||||
|
@@ -51,7 +51,9 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
user = User(name="Edward Snowden")
|
||||
session = self.get_raw_session()
|
||||
session.add(user)
|
||||
@@ -61,7 +63,7 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
return self._db_manager._session()
|
||||
|
||||
@contextmanager
|
||||
def session(self) -> Session:
|
||||
def session(self, commit: Optional[bool] = True) -> Session:
|
||||
"""Provide a transactional scope around a series of operations.
|
||||
|
||||
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
|
||||
@@ -71,13 +73,16 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
with self.session() as session:
|
||||
session.query(User).filter(User.name == 'Edward Snowden').first()
|
||||
|
||||
Args:
|
||||
commit (Optional[bool], optional): Whether to commit the session. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Session: A session object.
|
||||
|
||||
Raises:
|
||||
Exception: Any exception will be raised.
|
||||
"""
|
||||
with self._db_manager.session() as session:
|
||||
with self._db_manager.session(commit=commit) as session:
|
||||
yield session
|
||||
|
||||
def from_request(self, request: QUERY_SPEC) -> T:
|
||||
|
@@ -1,8 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Union,
|
||||
Dict,
|
||||
Optional,
|
||||
Type,
|
||||
ClassVar,
|
||||
)
|
||||
import logging
|
||||
from sqlalchemy import create_engine, URL, Engine
|
||||
from sqlalchemy import orm, inspect, MetaData
|
||||
@@ -13,8 +20,6 @@ from sqlalchemy.orm import (
|
||||
declarative_base,
|
||||
DeclarativeMeta,
|
||||
)
|
||||
from sqlalchemy.orm.session import _PKIdentityArgument
|
||||
from sqlalchemy.orm.exc import UnmappedClassError
|
||||
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from dbgpt.util.string_utils import _to_str
|
||||
@@ -27,16 +32,10 @@ T = TypeVar("T", bound="BaseModel")
|
||||
class _QueryObject:
|
||||
"""The query object."""
|
||||
|
||||
def __init__(self, db_manager: "DatabaseManager"):
|
||||
self._db_manager = db_manager
|
||||
|
||||
def __get__(self, obj, type):
|
||||
try:
|
||||
mapper = orm.class_mapper(type)
|
||||
if mapper:
|
||||
return type.query_class(mapper, session=self._db_manager._session())
|
||||
except UnmappedClassError:
|
||||
return None
|
||||
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||
return model_cls.query_class(
|
||||
model_cls, session=model_cls.__db_manager__._session()
|
||||
)
|
||||
|
||||
|
||||
class BaseQuery(orm.Query):
|
||||
@@ -46,7 +45,9 @@ class BaseQuery(orm.Query):
|
||||
"""Paginate the query.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.storage.metadata import db, Model
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
@@ -58,10 +59,6 @@ class BaseQuery(orm.Query):
|
||||
pagination = session.query(User).paginate_query(page=1, page_size=10)
|
||||
print(pagination)
|
||||
|
||||
# Or you can use the query object
|
||||
with db.session() as session:
|
||||
pagination = User.query.paginate_query(page=1, page_size=10)
|
||||
print(pagination)
|
||||
|
||||
Args:
|
||||
page (Optional[int], optional): The page number. Defaults to 1.
|
||||
@@ -86,26 +83,12 @@ class BaseQuery(orm.Query):
|
||||
|
||||
|
||||
class _Model:
|
||||
"""Base class for SQLAlchemy declarative base model.
|
||||
"""Base class for SQLAlchemy declarative base model."""
|
||||
|
||||
With this class, we can use the query object to query the database.
|
||||
__db_manager__: ClassVar[DatabaseManager]
|
||||
query_class = BaseQuery
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
from dbgpt.storage.metadata import db, Model
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
fullname = Column(String(50))
|
||||
|
||||
with db.session() as session:
|
||||
# User is an instance of _Model, and we can use the query object to query the database.
|
||||
User.query.filter(User.name == "test").all()
|
||||
"""
|
||||
|
||||
query_class = None
|
||||
query: Optional[BaseQuery] = None
|
||||
# query: Optional[BaseQuery] = _QueryObject()
|
||||
|
||||
def __repr__(self):
|
||||
identity = inspect(self).identity
|
||||
@@ -120,7 +103,9 @@ class DatabaseManager:
|
||||
"""The database manager.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from urllib.parse import quote_plus as urlquote, quote
|
||||
from dbgpt.storage.metadata import DatabaseManager, create_model
|
||||
db = DatabaseManager()
|
||||
@@ -141,21 +126,25 @@ class DatabaseManager:
|
||||
session.add(User(name="test", fullname="test"))
|
||||
# db will commit the session automatically default.
|
||||
# session.commit()
|
||||
print(User.query.filter(User.name == "test").all())
|
||||
assert session.query(User).filter(User.name == "test").first().name == "test"
|
||||
|
||||
|
||||
# Use CURDMixin APIs to create, update, delete, query the database.
|
||||
# More usage:
|
||||
|
||||
with db.session() as session:
|
||||
User.create(**{"name": "test1", "fullname": "test1"})
|
||||
User.create(**{"name": "test2", "fullname": "test1"})
|
||||
users = User.all()
|
||||
session.add(User(name="test1", fullname="test1"))
|
||||
session.add(User(name="test2", fullname="test1"))
|
||||
users = session.query(User).all()
|
||||
print(users)
|
||||
user = users[0]
|
||||
user.update(**{"name": "test1_1111"})
|
||||
user.name = "test1_1111"
|
||||
session.merge(user)
|
||||
|
||||
user2 = users[1]
|
||||
# Update user2 by save
|
||||
user2.name = "test2_1111"
|
||||
user2.save()
|
||||
session.merge(user2)
|
||||
session.commit()
|
||||
# Delete user2
|
||||
user2.delete()
|
||||
"""
|
||||
@@ -189,28 +178,65 @@ class DatabaseManager:
|
||||
return self._engine is not None and self._session is not None
|
||||
|
||||
@contextmanager
|
||||
def session(self) -> Session:
|
||||
def session(self, commit: Optional[bool] = True) -> Session:
|
||||
"""Get the session with context manager.
|
||||
|
||||
If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically.
|
||||
This context manager handles the lifecycle of a SQLAlchemy session.
|
||||
It automatically commits or rolls back transactions based on
|
||||
the execution and handles session closure.
|
||||
|
||||
Example:
|
||||
>>> with db.session() as session:
|
||||
>>> session.query(...)
|
||||
The `commit` parameter controls whether the session should commit
|
||||
changes at the end of the block. This is useful for separating
|
||||
read and write operations.
|
||||
|
||||
Returns:
|
||||
Session: The session.
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# For write operations (insert, update, delete):
|
||||
with db.session() as session:
|
||||
user = User(name="John Doe")
|
||||
session.add(user)
|
||||
# session.commit() is called automatically
|
||||
|
||||
# For read-only operations:
|
||||
with db.session(commit=False) as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
# session.commit() is NOT called, as it's unnecessary for read operations
|
||||
|
||||
Args:
|
||||
commit (Optional[bool], optional): Whether to commit the session.
|
||||
If True (default), the session will commit changes at the end
|
||||
of the block. Use False for read-only operations or when manual
|
||||
control over commit is needed. Defaults to True.
|
||||
|
||||
Yields:
|
||||
Session: The SQLAlchemy session object.
|
||||
|
||||
Raises:
|
||||
RuntimeError: The database manager is not initialized.
|
||||
Exception: Any exception.
|
||||
RuntimeError: Raised if the database manager is not initialized.
|
||||
Exception: Propagates any exception that occurred within the block.
|
||||
|
||||
Important Notes:
|
||||
- DetachedInstanceError: This error occurs when trying to access or
|
||||
modify an instance that has been detached from its session.
|
||||
DetachedInstanceError can occur in scenarios where the session is
|
||||
closed, and further interaction with the ORM object is attempted,
|
||||
especially when accessing lazy-loaded attributes. To avoid this:
|
||||
a. Ensure required attributes are loaded before session closure.
|
||||
b. Avoid closing the session before all necessary interactions
|
||||
with the ORM object are complete.
|
||||
c. Re-bind the instance to a new session if further interaction
|
||||
is required after the session is closed.
|
||||
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("The database manager is not initialized.")
|
||||
session = self._session()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
if commit:
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
raise
|
||||
@@ -223,7 +249,7 @@ class DatabaseManager:
|
||||
"""Make the declarative base.
|
||||
|
||||
Args:
|
||||
base (DeclarativeMeta): The base class.
|
||||
model (DeclarativeMeta): The base class.
|
||||
|
||||
Returns:
|
||||
DeclarativeMeta: The declarative base.
|
||||
@@ -232,7 +258,8 @@ class DatabaseManager:
|
||||
model = declarative_base(cls=model, name="Model")
|
||||
if not getattr(model, "query_class", None):
|
||||
model.query_class = self.Query
|
||||
model.query = _QueryObject(self)
|
||||
# model.query = _QueryObject()
|
||||
model.__db_manager__ = self
|
||||
return model
|
||||
|
||||
def init_db(
|
||||
@@ -242,6 +269,7 @@ class DatabaseManager:
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
query_class=BaseQuery,
|
||||
override_query_class: Optional[bool] = False,
|
||||
session_options: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize the database manager.
|
||||
|
||||
@@ -251,18 +279,26 @@ class DatabaseManager:
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
"""
|
||||
if session_options is None:
|
||||
session_options = {}
|
||||
self._db_url = db_url
|
||||
if query_class is not None:
|
||||
self.Query = query_class
|
||||
if base is not None:
|
||||
self._base = base
|
||||
if not hasattr(base, "query") or override_query_class:
|
||||
base.query = _QueryObject(self)
|
||||
# if not hasattr(base, "query") or override_query_class:
|
||||
# base.query = _QueryObject()
|
||||
if not getattr(base, "query_class", None) or override_query_class:
|
||||
base.query_class = self.Query
|
||||
if not hasattr(base, "__db_manager__") or override_query_class:
|
||||
base.__db_manager__ = self
|
||||
self._engine = create_engine(db_url, **(engine_args or {}))
|
||||
session_factory = sessionmaker(bind=self._engine)
|
||||
|
||||
session_options.setdefault("class_", Session)
|
||||
session_options.setdefault("query_cls", self.Query)
|
||||
session_factory = sessionmaker(bind=self._engine, **session_options)
|
||||
self._session = scoped_session(session_factory)
|
||||
self._base.metadata.bind = self._engine
|
||||
|
||||
@@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]):
|
||||
__abstract__ = True
|
||||
|
||||
@classmethod
|
||||
def create(cls: Type[T], **kwargs) -> T:
|
||||
instance = cls(**kwargs)
|
||||
return instance.save()
|
||||
|
||||
@classmethod
|
||||
def all(cls: Type[T]) -> List[T]:
|
||||
return cls.query.all()
|
||||
|
||||
@classmethod
|
||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
||||
"""Get a record by its primary key identifier."""
|
||||
|
||||
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
|
||||
"""Update specific fields of a record."""
|
||||
for attr, value in kwargs.items():
|
||||
setattr(self, attr, value)
|
||||
return commit and self.save() or self
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
||||
"""Save the record."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
||||
"""Remove the record from the database."""
|
||||
def db(cls) -> DatabaseManager:
|
||||
"""Get the database manager."""
|
||||
return cls.__db_manager__
|
||||
|
||||
|
||||
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
|
||||
|
||||
"""The base model class that includes CRUD convenience methods."""
|
||||
|
||||
__abstract__ = True
|
||||
@@ -438,28 +451,14 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
_db_manager: DatabaseManager = db_manager
|
||||
|
||||
@classmethod
|
||||
def set_db_manager(cls, db_manager: DatabaseManager):
|
||||
def set_db(cls, db_manager: DatabaseManager):
|
||||
# TODO: It is hard to replace to user DB Connection
|
||||
cls._db_manager = db_manager
|
||||
|
||||
@classmethod
|
||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
||||
"""Get a record by its primary key identifier."""
|
||||
return cls._db_manager._session().get(cls, ident)
|
||||
|
||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
||||
"""Save the record."""
|
||||
session = self._db_manager._session()
|
||||
session.add(self)
|
||||
if commit:
|
||||
session.commit()
|
||||
return self
|
||||
|
||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
||||
"""Remove the record from the database."""
|
||||
session = self._db_manager._session()
|
||||
session.delete(self)
|
||||
return commit and session.commit()
|
||||
def db(cls) -> DatabaseManager:
|
||||
"""Get the database manager."""
|
||||
return cls._db_manager
|
||||
|
||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
|
||||
"""Base model class that includes CRUD convenience methods."""
|
||||
@@ -478,6 +477,7 @@ def initialize_db(
|
||||
engine_args: Optional[Dict] = None,
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
try_to_create_db: Optional[bool] = False,
|
||||
session_options: Optional[Dict] = None,
|
||||
) -> DatabaseManager:
|
||||
"""Initialize the database manager.
|
||||
|
||||
@@ -487,10 +487,11 @@ def initialize_db(
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
db.init_db(db_url, engine_args, base)
|
||||
db.init_db(db_url, engine_args, base, session_options=session_options)
|
||||
if try_to_create_db:
|
||||
try:
|
||||
db.create_all()
|
||||
|
@@ -100,7 +100,7 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_
|
||||
|
||||
# Verify that the user is updated in the database
|
||||
with db.session() as session:
|
||||
user = session.query(User).get(created_user_response.id)
|
||||
user = session.get(User, created_user_response.id)
|
||||
assert user.age == 35
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ def test_update_user_partial(
|
||||
|
||||
# Verify that the user is updated in the database
|
||||
with db.session() as session:
|
||||
user = session.query(User).get(created_user_response.id)
|
||||
user = session.get(User, created_user_response.id)
|
||||
assert user.age == user_req.age
|
||||
assert user.password == "newpassword"
|
||||
|
||||
|
@@ -53,11 +53,10 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
|
||||
# Create
|
||||
with db.session() as session:
|
||||
user = User.create(name="John Doe")
|
||||
user = User(name="John Doe")
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Read
|
||||
# # Read
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
assert user is not None
|
||||
@@ -65,12 +64,20 @@ def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
# Update
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
user.update(name="Jane Doe")
|
||||
|
||||
# Delete
|
||||
user.name = "Mike Doe"
|
||||
session.merge(user)
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="Jane Doe").first()
|
||||
user.delete()
|
||||
user = session.query(User).filter_by(name="Mike Doe").first()
|
||||
assert user is not None
|
||||
session.query(User).filter(User.name == "John Doe").first() is None
|
||||
#
|
||||
# # Delete
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="Mike Doe").first()
|
||||
session.delete(user)
|
||||
|
||||
with db.session() as session:
|
||||
assert len(session.query(User).all()) == 0
|
||||
|
||||
|
||||
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
@@ -80,20 +87,7 @@ def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
name = Column(String(50))
|
||||
|
||||
db.create_all()
|
||||
|
||||
# Create
|
||||
user = User.create(name="John Doe")
|
||||
assert User.get(user.id) is not None
|
||||
users = User.all()
|
||||
assert len(users) == 1
|
||||
|
||||
# Update
|
||||
user.update(name="Bob Doe")
|
||||
assert User.get(user.id).name == "Bob Doe"
|
||||
|
||||
user = User.get(user.id)
|
||||
user.delete()
|
||||
assert User.get(user.id) is None
|
||||
User.db() == db
|
||||
|
||||
|
||||
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
@@ -108,11 +102,10 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
for i in range(30):
|
||||
user = User(name=f"User {i}")
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
users_page_1 = User.query.paginate_query(page=1, per_page=10)
|
||||
assert len(users_page_1.items) == 10
|
||||
assert users_page_1.total_pages == 3
|
||||
with db.session() as session:
|
||||
users_page_1 = session.query(User).paginate_query(page=1, per_page=10)
|
||||
assert len(users_page_1.items) == 10
|
||||
assert users_page_1.total_pages == 3
|
||||
|
||||
|
||||
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
@@ -124,9 +117,11 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
db.create_all()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
User.query.paginate_query(page=0, per_page=10)
|
||||
with db.session() as session:
|
||||
session.query(User).paginate_query(page=0, per_page=10)
|
||||
with pytest.raises(ValueError):
|
||||
User.query.paginate_query(page=1, per_page=-1)
|
||||
with db.session() as session:
|
||||
session.query(User).paginate_query(page=1, per_page=-1)
|
||||
|
||||
|
||||
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
@@ -142,14 +137,19 @@ def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
new_db = DatabaseManager.build_from(
|
||||
f"sqlite:///{filename}", base=Model, override_query_class=True
|
||||
)
|
||||
Model.set_db_manager(new_db)
|
||||
Model.set_db(new_db)
|
||||
new_db.create_all()
|
||||
db.create_all()
|
||||
assert list(new_db.metadata.tables.keys())[0] == "user"
|
||||
User.create(**{"name": "John Doe"})
|
||||
with new_db.session() as session:
|
||||
user = User(name="John Doe")
|
||||
session.add(user)
|
||||
with new_db.session() as session:
|
||||
assert session.query(User).filter_by(name="John Doe").first() is not None
|
||||
with db.session() as session:
|
||||
assert session.query(User).filter_by(name="John Doe").first() is None
|
||||
assert len(User.query.all()) == 1
|
||||
assert User.query.filter(User.name == "John Doe").first().name == "John Doe"
|
||||
with new_db.session() as session:
|
||||
session.query(User).all() == 1
|
||||
session.query(User).filter(
|
||||
User.name == "John Doe"
|
||||
).first().name == "John Doe"
|
||||
|
@@ -7,7 +7,7 @@
|
||||
Call with non-streaming response.
|
||||
.. code-block:: shell
|
||||
|
||||
DBGPT_SERVER="http://127.0.0.1:5000"
|
||||
DBGPT_SERVER="http://127.0.0.1:5555"
|
||||
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
|
||||
-H "Content-Type: application/json" -d '{
|
||||
"model": "proxyllm",
|
||||
|
Reference in New Issue
Block a user