mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-27 20:38:30 +00:00
feat(core): Support multi round conversation operator (#986)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from dbgpt.serve.core.schemas import Result
|
||||
from dbgpt.serve.core.config import BaseServeConfig
|
||||
from dbgpt.serve.core.service import BaseService
|
||||
from dbgpt.serve.core.serve import BaseServe
|
||||
|
||||
__ALL__ = ["Result", "BaseServeConfig", "BaseService"]
|
||||
__ALL__ = ["Result", "BaseServeConfig", "BaseService", "BaseServe"]
|
||||
|
60
dbgpt/serve/core/serve.py
Normal file
60
dbgpt/serve/core/serve.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, Union, List
|
||||
import logging
|
||||
from dbgpt.component import BaseComponent, SystemApp, ComponentType
|
||||
from sqlalchemy import URL
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseServe(BaseComponent, ABC):
|
||||
"""Base serve component for DB-GPT"""
|
||||
|
||||
name = "dbgpt_serve_base"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
api_prefix: str,
|
||||
api_tags: List[str],
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
self._system_app = system_app
|
||||
self._api_prefix = api_prefix
|
||||
self._api_tags = api_tags
|
||||
self._db_url_or_db = db_url_or_db
|
||||
self._try_create_tables = try_create_tables
|
||||
self._not_create_table = True
|
||||
self._app_has_initiated = False
|
||||
|
||||
def create_or_get_db_manager(self) -> DatabaseManager:
|
||||
"""Create or get the database manager.
|
||||
This method must be called after the application is initialized
|
||||
|
||||
Returns:
|
||||
DatabaseManager: The database manager
|
||||
"""
|
||||
from dbgpt.storage.metadata import Model, db, UnifiedDBManagerFactory
|
||||
|
||||
# If you need to use the database, you can get the database manager here
|
||||
db_manager_factory: UnifiedDBManagerFactory = self._system_app.get_component(
|
||||
ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY,
|
||||
UnifiedDBManagerFactory,
|
||||
default_component=None,
|
||||
)
|
||||
if db_manager_factory is not None and db_manager_factory.create():
|
||||
init_db = db_manager_factory.create()
|
||||
else:
|
||||
init_db = self._db_url_or_db or db
|
||||
init_db = DatabaseManager.build_from(init_db, base=Model)
|
||||
|
||||
if self._try_create_tables and self._not_create_table:
|
||||
try:
|
||||
init_db.create_all()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create tables: {e}")
|
||||
finally:
|
||||
self._not_create_table = False
|
||||
return init_db
|
@@ -8,15 +8,6 @@ from httpx import AsyncClient
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util import AppConfig
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def create_system_app(param: Dict) -> SystemApp:
|
||||
app_config = param.get("app_config", {})
|
||||
@@ -24,7 +15,17 @@ def create_system_app(param: Dict) -> SystemApp:
|
||||
app_config = AppConfig(configs=app_config)
|
||||
elif not isinstance(app_config, AppConfig):
|
||||
raise RuntimeError("app_config must be AppConfig or dict")
|
||||
return SystemApp(app, app_config)
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
return SystemApp(test_app, app_config)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -51,9 +52,12 @@ async def client(request, asystem_app: SystemApp):
|
||||
del param["api_keys"]
|
||||
if client_api_key:
|
||||
headers["Authorization"] = "Bearer " + client_api_key
|
||||
async with AsyncClient(app=app, base_url=base_url, headers=headers) as client:
|
||||
|
||||
test_app = asystem_app.app
|
||||
|
||||
async with AsyncClient(app=test_app, base_url=base_url, headers=headers) as client:
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
test_app.include_router(router)
|
||||
if app_caller:
|
||||
app_caller(app, asystem_app)
|
||||
app_caller(test_app, asystem_app)
|
||||
yield client
|
||||
|
Reference in New Issue
Block a user