refactor: Refactor storage and new serve template (#947)

This commit is contained in:
Fangyin Cheng
2023-12-18 19:30:40 +08:00
committed by GitHub
parent 22d95b444b
commit 511a43b849
63 changed files with 1891 additions and 229 deletions

View File

@@ -19,12 +19,7 @@ from dbgpt.util.system_utils import get_system_info
logger = logging.getLogger(__name__)
_torch_imported = False
try:
import torch
_torch_imported = True
except ImportError:
pass
torch = None
class DefaultModelWorker(ModelWorker):
@@ -95,6 +90,8 @@ class DefaultModelWorker(ModelWorker):
def start(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
# Lazy load torch
_try_import_torch()
if not model_params:
model_params = self.parse_parameters(command_args)
self._model_params = model_params
@@ -436,3 +433,14 @@ def _new_metrics_from_model_output(
].available_memory_gb
return metrics
def _try_import_torch():
global torch
global _torch_imported
try:
import torch
_torch_imported = True
except ImportError:
pass

View File

@@ -1,7 +1,6 @@
import asyncio
from typing import Any, Callable
import httpx
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.registry import ModelRegistry
@@ -34,6 +33,9 @@ class RemoteWorkerManager(LocalWorkerManager):
success_handler: Callable = None,
error_handler: Callable = None,
) -> Any:
# Lazy import to avoid high time cost
import httpx
url = worker_run_data.worker.worker_addr + endpoint
headers = {**worker_run_data.worker.headers, **(additional_headers or {})}
timeout = worker_run_data.worker.timeout