Use fuzzy matching when searching dbgpts (#2110)

Co-authored-by: jiaoqiyuan <jiaoqiyuan@ipplus.com>
This commit is contained in:
Qiyuan Jiao 2024-11-05 13:39:59 +08:00 committed by GitHub
parent 52062fd960
commit b4ce217ded
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 4 deletions

View File

@ -2,11 +2,12 @@
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc
from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
@ -109,3 +110,42 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)
def dbgpts_list(
self,
query_request: ServeRequest,
page: int,
page_size: int,
desc_order_column: Optional[str] = None,
) -> PaginationResult[ServerResponse]:
"""Get a page of dbgpts.
Args:
query_request (ServeRequest): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
desc_order_column(Optional[str]): The column for descending order.
Returns:
PaginationResult: The pagination result.
"""
session = self.get_raw_session()
try:
query = session.query(ServeEntity)
if query_request.name:
query = query.filter(ServeEntity.name.like(f"%{query_request.name}%"))
if desc_order_column:
query = query.order_by(desc(getattr(ServeEntity, desc_order_column)))
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
finally:
session.close()
return PaginationResult(
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)

View File

@ -41,7 +41,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao
@ -130,7 +130,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
installed=request.installed,
)
return self.dao.get_list_page(query_request, page, page_size)
return self.dao.dbgpts_list(query_request, page, page_size)
def refresh_hub_from_git(
self,