mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 04:36:23 +00:00
feat: Support variables query API
This commit is contained in:
parent
439b5b32e2
commit
494eb587dd
@ -31,6 +31,7 @@ BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
|
||||
BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
|
||||
BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
|
||||
# Not implemented yet
|
||||
BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
|
||||
BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
|
||||
BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
|
||||
|
@ -27,7 +27,7 @@ from dbgpt.util.i18n_utils import _
|
||||
name="auto_convert_message",
|
||||
type=bool,
|
||||
optional=True,
|
||||
default=False,
|
||||
default=True,
|
||||
description=_(
|
||||
"Whether to auto convert the messages that are not supported "
|
||||
"by the LLM to a compatible format"
|
||||
@ -128,7 +128,7 @@ class DefaultLLMClient(LLMClient):
|
||||
name="auto_convert_message",
|
||||
type=bool,
|
||||
optional=True,
|
||||
default=False,
|
||||
default=True,
|
||||
description=_(
|
||||
"Whether to auto convert the messages that are not supported "
|
||||
"by the LLM to a compatible format"
|
||||
@ -158,7 +158,7 @@ class RemoteLLMClient(DefaultLLMClient):
|
||||
def __init__(
|
||||
self,
|
||||
controller_address: str = "http://127.0.0.1:8000",
|
||||
auto_convert_message: bool = False,
|
||||
auto_convert_message: bool = True,
|
||||
):
|
||||
"""Initialize the RemoteLLMClient."""
|
||||
from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
|
||||
|
@ -21,6 +21,7 @@ from .schemas import (
|
||||
RefreshNodeRequest,
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
VariablesKeyResponse,
|
||||
VariablesRequest,
|
||||
VariablesResponse,
|
||||
)
|
||||
@ -359,6 +360,62 @@ async def update_variables(
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/variables",
|
||||
response_model=Result[PaginationResult[VariablesResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def get_variables_by_keys(
|
||||
key: str = Query(..., description="variable key"),
|
||||
scope: Optional[str] = Query(default=None, description="scope"),
|
||||
scope_key: Optional[str] = Query(default=None, description="scope key"),
|
||||
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||
page: int = Query(default=1, description="current page"),
|
||||
page_size: int = Query(default=20, description="page size"),
|
||||
) -> Result[PaginationResult[VariablesResponse]]:
|
||||
"""Get the variables by keys
|
||||
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
res = await get_variable_service().get_list_by_page(
|
||||
key,
|
||||
scope,
|
||||
scope_key,
|
||||
user_name,
|
||||
sys_code,
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/variables/keys",
|
||||
response_model=Result[List[VariablesKeyResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def get_variables_keys(
|
||||
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||
category: Optional[str] = Query(default=None, description="category"),
|
||||
) -> Result[List[VariablesKeyResponse]]:
|
||||
"""Get the variable keys
|
||||
|
||||
Returns:
|
||||
VariablesKeyResponse: The response
|
||||
"""
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app,
|
||||
get_variable_service().list_keys,
|
||||
user_name,
|
||||
sys_code,
|
||||
category,
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
|
||||
async def debug_flow(
|
||||
flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
|
||||
@ -477,10 +534,13 @@ async def import_flow(
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
from .variables_provider import (
|
||||
BuiltinAgentsVariablesProvider,
|
||||
BuiltinAllSecretVariablesProvider,
|
||||
BuiltinAllVariablesProvider,
|
||||
BuiltinDatasourceVariablesProvider,
|
||||
BuiltinEmbeddingsVariablesProvider,
|
||||
BuiltinFlowVariablesProvider,
|
||||
BuiltinKnowledgeSpacesVariablesProvider,
|
||||
BuiltinLLMVariablesProvider,
|
||||
BuiltinNodeVariablesProvider,
|
||||
)
|
||||
@ -494,4 +554,7 @@ def init_endpoints(system_app: SystemApp) -> None:
|
||||
system_app.register(BuiltinAllSecretVariablesProvider)
|
||||
system_app.register(BuiltinLLMVariablesProvider)
|
||||
system_app.register(BuiltinEmbeddingsVariablesProvider)
|
||||
system_app.register(BuiltinDatasourceVariablesProvider)
|
||||
system_app.register(BuiltinAgentsVariablesProvider)
|
||||
system_app.register(BuiltinKnowledgeSpacesVariablesProvider)
|
||||
global_system_app = system_app
|
||||
|
@ -2,7 +2,11 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
|
||||
from dbgpt.core.awel.flow.flow_factory import (
|
||||
FlowPanel,
|
||||
VariablesRequest,
|
||||
_VariablesRequestBase,
|
||||
)
|
||||
from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
@ -28,6 +32,13 @@ class VariablesResponse(VariablesRequest):
|
||||
)
|
||||
|
||||
|
||||
class VariablesKeyResponse(_VariablesRequestBase):
|
||||
"""Variables Key response model.
|
||||
|
||||
Just include the key, for select options in the frontend.
|
||||
"""
|
||||
|
||||
|
||||
class RefreshNodeRequest(BaseModel):
|
||||
"""Flow response model"""
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from dbgpt.core.interface.variables import (
|
||||
BUILTIN_VARIABLES_CORE_AGENTS,
|
||||
BUILTIN_VARIABLES_CORE_DATASOURCES,
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS,
|
||||
BUILTIN_VARIABLES_CORE_FLOW_NODES,
|
||||
BUILTIN_VARIABLES_CORE_FLOWS,
|
||||
BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES,
|
||||
BUILTIN_VARIABLES_CORE_LLMS,
|
||||
BUILTIN_VARIABLES_CORE_SECRETS,
|
||||
BUILTIN_VARIABLES_CORE_VARIABLES,
|
||||
@ -54,6 +57,7 @@ class BuiltinFlowVariablesProvider(BuiltinVariablesProvider):
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
description=flow.description,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
@ -91,6 +95,7 @@ class BuiltinNodeVariablesProvider(BuiltinVariablesProvider):
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
description=metadata.get("description"),
|
||||
)
|
||||
)
|
||||
return variables
|
||||
@ -122,10 +127,14 @@ class BuiltinAllVariablesProvider(BuiltinVariablesProvider):
|
||||
name=var.name,
|
||||
label=var.label,
|
||||
value=var.value,
|
||||
category=var.category,
|
||||
value_type=var.value_type,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
enabled=1 if var.enabled else 0,
|
||||
description=var.description,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
@ -258,3 +267,128 @@ class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider):
|
||||
return await self._get_models(
|
||||
key, scope, scope_key, sys_code, user_name, "text2vec"
|
||||
)
|
||||
|
||||
|
||||
class BuiltinDatasourceVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin datasource variables provider.
|
||||
|
||||
Provide all datasource variables by variables "${dbgpt.core.datasource}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_DATASOURCES
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
from dbgpt.serve.datasource.service.service import (
|
||||
DatasourceServeResponse,
|
||||
Service,
|
||||
)
|
||||
|
||||
all_datasource: List[DatasourceServeResponse] = Service.get_instance(
|
||||
self.system_app
|
||||
).list()
|
||||
|
||||
variables = []
|
||||
for datasource in all_datasource:
|
||||
label = f"[{datasource.db_type}]{datasource.db_name}"
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=datasource.db_name,
|
||||
label=label,
|
||||
value=datasource.db_name,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
description=datasource.comment,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
|
||||
class BuiltinAgentsVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin agents variables provider.
|
||||
|
||||
Provide all agents variables by variables "${dbgpt.core.agent.agents}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_AGENTS
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
from dbgpt.agent.core.agent_manage import get_agent_manager
|
||||
|
||||
agent_manager = get_agent_manager(self.system_app)
|
||||
agents = agent_manager.list_agents()
|
||||
variables = []
|
||||
for agent in agents:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=agent["name"],
|
||||
label=agent["desc"],
|
||||
value=agent["name"],
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
description=agent["desc"],
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
|
||||
class BuiltinKnowledgeSpacesVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin knowledge variables provider.
|
||||
|
||||
Provide all knowledge variables by variables "${dbgpt.core.knowledge_spaces}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
from dbgpt.serve.rag.service.service import Service, SpaceServeRequest
|
||||
|
||||
# TODO: Query with user_name and sys_code
|
||||
knowledge_list = Service.get_instance(self.system_app).get_list(
|
||||
SpaceServeRequest()
|
||||
)
|
||||
variables = []
|
||||
for k in knowledge_list:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=k.name,
|
||||
label=k.name,
|
||||
value=k.name,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
description=k.desc,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
@ -230,7 +230,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
continue
|
||||
# Set state to DEPLOYED
|
||||
flow.state = State.DEPLOYED
|
||||
exist_inst = self.get({"name": flow.name})
|
||||
exist_inst = self.dao.get_one({"name": flow.name})
|
||||
if not exist_inst:
|
||||
self.create_and_save_dag(flow, save_failed_flow=True)
|
||||
elif is_first_load or exist_inst.state != State.RUNNING:
|
||||
|
@ -1,10 +1,25 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.core.interface.variables import (
|
||||
BUILTIN_VARIABLES_CORE_AGENTS,
|
||||
BUILTIN_VARIABLES_CORE_DATASOURCES,
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS,
|
||||
BUILTIN_VARIABLES_CORE_FLOW_NODES,
|
||||
BUILTIN_VARIABLES_CORE_FLOWS,
|
||||
BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES,
|
||||
BUILTIN_VARIABLES_CORE_LLMS,
|
||||
BUILTIN_VARIABLES_CORE_RERANKERS,
|
||||
BUILTIN_VARIABLES_CORE_SECRETS,
|
||||
BUILTIN_VARIABLES_CORE_VARIABLES,
|
||||
StorageVariables,
|
||||
VariablesProvider,
|
||||
)
|
||||
from dbgpt.serve.core import BaseService, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from ..api.schemas import VariablesRequest, VariablesResponse
|
||||
from ..api.schemas import VariablesKeyResponse, VariablesRequest, VariablesResponse
|
||||
from ..config import (
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
|
||||
@ -12,6 +27,93 @@ from ..config import (
|
||||
)
|
||||
from ..models.models import VariablesDao, VariablesEntity
|
||||
|
||||
BUILTIN_VARIABLES = [
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_FLOWS,
|
||||
label=_("All AWEL Flows"),
|
||||
description=_("Fetch all AWEL flows in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_FLOW_NODES,
|
||||
label=_("All AWEL Flow Nodes"),
|
||||
description=_("Fetch all AWEL flow nodes in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_VARIABLES,
|
||||
label=_("All Variables"),
|
||||
description=_("Fetch all variables in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_SECRETS,
|
||||
label=_("All Secrets"),
|
||||
description=_("Fetch all secrets in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_LLMS,
|
||||
label=_("All LLMs"),
|
||||
description=_("Fetch all LLMs in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_EMBEDDINGS,
|
||||
label=_("All Embeddings"),
|
||||
description=_("Fetch all embeddings models in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_RERANKERS,
|
||||
label=_("All Rerankers"),
|
||||
description=_("Fetch all rerankers in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_DATASOURCES,
|
||||
label=_("All Data Sources"),
|
||||
description=_("Fetch all data sources in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_AGENTS,
|
||||
label=_("All Agents"),
|
||||
description=_("Fetch all agents in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
VariablesKeyResponse(
|
||||
key=BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES,
|
||||
label=_("All Knowledge Spaces"),
|
||||
description=_("Fetch all knowledge spaces in the system"),
|
||||
value_type="str",
|
||||
category="common",
|
||||
scope="global",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _is_builtin_variable(key: str) -> bool:
|
||||
return key in [v.key for v in BUILTIN_VARIABLES]
|
||||
|
||||
|
||||
class VariablesService(
|
||||
BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
|
||||
@ -148,5 +250,119 @@ class VariablesService(
|
||||
return self.dao.get_one(query)
|
||||
|
||||
def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
|
||||
"""List all variables."""
|
||||
"""List all variables.
|
||||
|
||||
Please note that this method will return all variables in the system, it may
|
||||
be a large list.
|
||||
"""
|
||||
return self.dao.get_list({"enabled": True, "category": category})
|
||||
|
||||
def list_keys(
|
||||
self,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
) -> List[VariablesKeyResponse]:
|
||||
"""List all keys."""
|
||||
results = []
|
||||
|
||||
# TODO: More high performance way to get the keys
|
||||
all_db_variables = self.dao.get_list(
|
||||
{
|
||||
"enabled": True,
|
||||
"category": category,
|
||||
"user_name": user_name,
|
||||
"sys_code": sys_code,
|
||||
}
|
||||
)
|
||||
if not user_name:
|
||||
# Only return the keys that are not user specific
|
||||
all_db_variables = [v for v in all_db_variables if not v.user_name]
|
||||
if not sys_code:
|
||||
# Only return the keys that are not system specific
|
||||
all_db_variables = [v for v in all_db_variables if not v.sys_code]
|
||||
key_to_db_variable = {}
|
||||
for db_variable in all_db_variables:
|
||||
key = db_variable.key
|
||||
if key not in key_to_db_variable:
|
||||
key_to_db_variable[key] = db_variable
|
||||
|
||||
# Append all builtin variables to the results
|
||||
results.extend(BUILTIN_VARIABLES)
|
||||
|
||||
# Append all db variables to the results
|
||||
for key, db_variable in key_to_db_variable.items():
|
||||
results.append(
|
||||
VariablesKeyResponse(
|
||||
key=key,
|
||||
label=db_variable.label,
|
||||
description=db_variable.description,
|
||||
value_type=db_variable.value_type,
|
||||
category=db_variable.category,
|
||||
scope=db_variable.scope,
|
||||
scope_key=db_variable.scope_key,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
async def get_list_by_page(
|
||||
self,
|
||||
key: str,
|
||||
scope: Optional[str] = None,
|
||||
scope_key: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> PaginationResult[VariablesResponse]:
|
||||
"""Get a list of variables by page."""
|
||||
if not _is_builtin_variable(key):
|
||||
query = {
|
||||
"key": key,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"user_name": user_name,
|
||||
"sys_code": sys_code,
|
||||
}
|
||||
return await blocking_func_to_async(
|
||||
self._system_app,
|
||||
self.dao.get_list_page,
|
||||
query,
|
||||
page,
|
||||
page_size,
|
||||
desc_order_column="gmt_modified",
|
||||
)
|
||||
else:
|
||||
variables: List[
|
||||
StorageVariables
|
||||
] = await self.variables_provider.async_get_variables(
|
||||
key=key,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
result_variables = []
|
||||
for entity in variables:
|
||||
result_variables.append(
|
||||
VariablesResponse(
|
||||
id=-1,
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=entity.value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=True if entity.enabled == 1 else False,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
description=entity.description,
|
||||
)
|
||||
)
|
||||
return PaginationResult.build_from_all(
|
||||
result_variables,
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
|
@ -15,3 +15,29 @@ class PaginationResult(BaseModel, Generic[T]):
|
||||
total_pages: int = Field(..., description="total number of pages")
|
||||
page: int = Field(..., description="Current page number")
|
||||
page_size: int = Field(..., description="Number of items per page")
|
||||
|
||||
@classmethod
|
||||
def build_from_all(
|
||||
cls, all_items: List[T], page: int, page_size: int
|
||||
) -> "PaginationResult[T]":
|
||||
"""Build a pagination result from all items"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
total_count = len(all_items)
|
||||
total_pages = (
|
||||
(total_count + page_size - 1) // page_size if total_count > 0 else 0
|
||||
)
|
||||
page = max(1, min(page, total_pages)) if total_pages > 0 else 0
|
||||
start_index = (page - 1) * page_size if page > 0 else 0
|
||||
end_index = min(start_index + page_size, total_count)
|
||||
items = all_items[start_index:end_index]
|
||||
|
||||
return cls(
|
||||
items=items,
|
||||
total_count=total_count,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
84
dbgpt/util/tests/test_pagination_utils.py
Normal file
84
dbgpt/util/tests/test_pagination_utils.py
Normal file
@ -0,0 +1,84 @@
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
|
||||
def test_build_from_all_normal_case():
|
||||
items = list(range(100))
|
||||
result = PaginationResult.build_from_all(items, page=2, page_size=20)
|
||||
|
||||
assert len(result.items) == 20
|
||||
assert result.items == list(range(20, 40))
|
||||
assert result.total_count == 100
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 2
|
||||
assert result.page_size == 20
|
||||
|
||||
|
||||
def test_build_from_all_empty_list():
|
||||
items = []
|
||||
result = PaginationResult.build_from_all(items, page=1, page_size=5)
|
||||
|
||||
assert result.items == []
|
||||
assert result.total_count == 0
|
||||
assert result.total_pages == 0
|
||||
assert result.page == 0
|
||||
assert result.page_size == 5
|
||||
|
||||
|
||||
def test_build_from_all_last_page():
|
||||
items = list(range(95))
|
||||
result = PaginationResult.build_from_all(items, page=5, page_size=20)
|
||||
|
||||
assert len(result.items) == 15
|
||||
assert result.items == list(range(80, 95))
|
||||
assert result.total_count == 95
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 5
|
||||
assert result.page_size == 20
|
||||
|
||||
|
||||
def test_build_from_all_page_out_of_range():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=10, page_size=10)
|
||||
|
||||
assert len(result.items) == 10
|
||||
assert result.items == list(range(40, 50))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 5
|
||||
assert result.page_size == 10
|
||||
|
||||
|
||||
def test_build_from_all_page_zero():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=0, page_size=10)
|
||||
|
||||
assert len(result.items) == 10
|
||||
assert result.items == list(range(0, 10))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 1
|
||||
assert result.page_size == 10
|
||||
|
||||
|
||||
def test_build_from_all_negative_page():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=-1, page_size=10)
|
||||
|
||||
assert len(result.items) == 10
|
||||
assert result.items == list(range(0, 10))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 1
|
||||
assert result.page_size == 10
|
||||
|
||||
|
||||
def test_build_from_all_page_size_larger_than_total():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=1, page_size=100)
|
||||
|
||||
assert len(result.items) == 50
|
||||
assert result.items == list(range(50))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 1
|
||||
assert result.page == 1
|
||||
assert result.page_size == 100
|
Loading…
Reference in New Issue
Block a user