Use fuzzy matching when searching dbgpts (#2110)

Co-authored-by: jiaoqiyuan <jiaoqiyuan@ipplus.com>
This commit is contained in:
Qiyuan Jiao 2024-11-05 13:39:59 +08:00 committed by GitHub
parent 52062fd960
commit b4ce217ded
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 4 deletions

View File

@ -2,11 +2,12 @@
You can define your own models and DAOs here You can define your own models and DAOs here
""" """
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Union from typing import Any, Dict, Optional, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc
from dbgpt.storage.metadata import BaseDao, Model, db from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig from ..config import SERVER_APP_TABLE_NAME, ServeConfig
@ -109,3 +110,42 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
gmt_created=gmt_created_str, gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str, gmt_modified=gmt_modified_str,
) )
def dbgpts_list(
self,
query_request: ServeRequest,
page: int,
page_size: int,
desc_order_column: Optional[str] = None,
) -> PaginationResult[ServerResponse]:
"""Get a page of dbgpts.
Args:
query_request (ServeRequest): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
desc_order_column(Optional[str]): The column for descending order.
Returns:
PaginationResult: The pagination result.
"""
session = self.get_raw_session()
try:
query = session.query(ServeEntity)
if query_request.name:
query = query.filter(ServeEntity.name.like(f"%{query_request.name}%"))
if desc_order_column:
query = query.order_by(desc(getattr(ServeEntity, desc_order_column)))
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
finally:
session.close()
return PaginationResult(
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)

View File

@ -41,7 +41,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._system_app = system_app self._system_app = system_app
@property @property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: def dao(self) -> ServeDao:
"""Returns the internal DAO.""" """Returns the internal DAO."""
return self._dao return self._dao
@ -130,7 +130,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
installed=request.installed, installed=request.installed,
) )
return self.dao.get_list_page(query_request, page, page_size) return self.dao.dbgpts_list(query_request, page, page_size)
def refresh_hub_from_git( def refresh_hub_from_git(
self, self,