feat(core): Support multi round conversation operator (#986)

This commit is contained in:
Fangyin Cheng
2023-12-27 23:26:28 +08:00
committed by GitHub
parent 9aec636b02
commit b13d3f6d92
63 changed files with 2011 additions and 314 deletions

View File

@@ -1,5 +1,6 @@
from dbgpt.serve.core.schemas import Result
from dbgpt.serve.core.config import BaseServeConfig
from dbgpt.serve.core.service import BaseService
from dbgpt.serve.core.serve import BaseServe
__ALL__ = ["Result", "BaseServeConfig", "BaseService"]
__ALL__ = ["Result", "BaseServeConfig", "BaseService", "BaseServe"]

60
dbgpt/serve/core/serve.py Normal file
View File

@@ -0,0 +1,60 @@
from abc import ABC
from typing import Optional, Union, List
import logging
from dbgpt.component import BaseComponent, SystemApp, ComponentType
from sqlalchemy import URL
from dbgpt.storage.metadata import DatabaseManager
logger = logging.getLogger(__name__)
class BaseServe(BaseComponent, ABC):
"""Base serve component for DB-GPT"""
name = "dbgpt_serve_base"
def __init__(
self,
system_app: SystemApp,
api_prefix: str,
api_tags: List[str],
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
self._system_app = system_app
self._api_prefix = api_prefix
self._api_tags = api_tags
self._db_url_or_db = db_url_or_db
self._try_create_tables = try_create_tables
self._not_create_table = True
self._app_has_initiated = False
def create_or_get_db_manager(self) -> DatabaseManager:
"""Create or get the database manager.
This method must be called after the application is initialized
Returns:
DatabaseManager: The database manager
"""
from dbgpt.storage.metadata import Model, db, UnifiedDBManagerFactory
# If you need to use the database, you can get the database manager here
db_manager_factory: UnifiedDBManagerFactory = self._system_app.get_component(
ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY,
UnifiedDBManagerFactory,
default_component=None,
)
if db_manager_factory is not None and db_manager_factory.create():
init_db = db_manager_factory.create()
else:
init_db = self._db_url_or_db or db
init_db = DatabaseManager.build_from(init_db, base=Model)
if self._try_create_tables and self._not_create_table:
try:
init_db.create_all()
except Exception as e:
logger.warning(f"Failed to create tables: {e}")
finally:
self._not_create_table = False
return init_db

View File

@@ -8,15 +8,6 @@ from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.util import AppConfig
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
def create_system_app(param: Dict) -> SystemApp:
app_config = param.get("app_config", {})
@@ -24,7 +15,17 @@ def create_system_app(param: Dict) -> SystemApp:
app_config = AppConfig(configs=app_config)
elif not isinstance(app_config, AppConfig):
raise RuntimeError("app_config must be AppConfig or dict")
return SystemApp(app, app_config)
test_app = FastAPI()
test_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
return SystemApp(test_app, app_config)
@pytest_asyncio.fixture
@@ -51,9 +52,12 @@ async def client(request, asystem_app: SystemApp):
del param["api_keys"]
if client_api_key:
headers["Authorization"] = "Bearer " + client_api_key
async with AsyncClient(app=app, base_url=base_url, headers=headers) as client:
test_app = asystem_app.app
async with AsyncClient(app=test_app, base_url=base_url, headers=headers) as client:
for router in routers:
app.include_router(router)
test_app.include_router(router)
if app_caller:
app_caller(app, asystem_app)
app_caller(test_app, asystem_app)
yield client