From b4ce217ded95f409919a078147dbcb0e19e9add6 Mon Sep 17 00:00:00 2001 From: Qiyuan Jiao <753525025@qq.com> Date: Tue, 5 Nov 2024 13:39:59 +0800 Subject: [PATCH] Use fuzzy matching when searching dbgpts (#2110) Co-authored-by: jiaoqiyuan --- dbgpt/serve/dbgpts/hub/models/models.py | 44 +++++++++++++++++++++-- dbgpt/serve/dbgpts/hub/service/service.py | 4 +-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/dbgpt/serve/dbgpts/hub/models/models.py b/dbgpt/serve/dbgpts/hub/models/models.py index 1bda0c8b2..b0352ea79 100644 --- a/dbgpt/serve/dbgpts/hub/models/models.py +++ b/dbgpt/serve/dbgpts/hub/models/models.py @@ -2,11 +2,12 @@ You can define your own models and DAOs here """ from datetime import datetime -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union -from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint +from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc from dbgpt.storage.metadata import BaseDao, Model, db +from dbgpt.util.pagination_utils import PaginationResult from ..api.schemas import ServeRequest, ServerResponse from ..config import SERVER_APP_TABLE_NAME, ServeConfig @@ -109,3 +110,42 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): gmt_created=gmt_created_str, gmt_modified=gmt_modified_str, ) + + def dbgpts_list( + self, + query_request: ServeRequest, + page: int, + page_size: int, + desc_order_column: Optional[str] = None, + ) -> PaginationResult[ServerResponse]: + """Get a page of dbgpts. + + Args: + query_request (ServeRequest): The request schema object or dict for query. + page (int): The page number. + page_size (int): The page size. + desc_order_column(Optional[str]): The column for descending order. + Returns: + PaginationResult: The pagination result. + """ + session = self.get_raw_session() + try: + query = session.query(ServeEntity) + if query_request.name: + query = query.filter(ServeEntity.name.like(f"%{query_request.name}%")) + if desc_order_column: + query = query.order_by(desc(getattr(ServeEntity, desc_order_column))) + total_count = query.count() + items = query.offset((page - 1) * page_size).limit(page_size) + res_items = [self.to_response(item) for item in items] + total_pages = (total_count + page_size - 1) // page_size + finally: + session.close() + + return PaginationResult( + items=res_items, + total_count=total_count, + total_pages=total_pages, + page=page, + page_size=page_size, + ) diff --git a/dbgpt/serve/dbgpts/hub/service/service.py b/dbgpt/serve/dbgpts/hub/service/service.py index a7cd66169..91214061e 100644 --- a/dbgpt/serve/dbgpts/hub/service/service.py +++ b/dbgpt/serve/dbgpts/hub/service/service.py @@ -41,7 +41,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): self._system_app = system_app @property - def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + def dao(self) -> ServeDao: """Returns the internal DAO.""" return self._dao @@ -130,7 +130,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): installed=request.installed, ) - return self.dao.get_list_page(query_request, page, page_size) + return self.dao.dbgpts_list(query_request, page, page_size) def refresh_hub_from_git( self,