mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 18:40:10 +00:00
refactor: Refactor storage and new serve template (#947)
This commit is contained in:
@@ -19,12 +19,7 @@ from dbgpt.util.system_utils import get_system_info
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_torch_imported = False
|
||||
try:
|
||||
import torch
|
||||
|
||||
_torch_imported = True
|
||||
except ImportError:
|
||||
pass
|
||||
torch = None
|
||||
|
||||
|
||||
class DefaultModelWorker(ModelWorker):
|
||||
@@ -95,6 +90,8 @@ class DefaultModelWorker(ModelWorker):
|
||||
def start(
|
||||
self, model_params: ModelParameters = None, command_args: List[str] = None
|
||||
) -> None:
|
||||
# Lazy load torch
|
||||
_try_import_torch()
|
||||
if not model_params:
|
||||
model_params = self.parse_parameters(command_args)
|
||||
self._model_params = model_params
|
||||
@@ -436,3 +433,14 @@ def _new_metrics_from_model_output(
|
||||
].available_memory_gb
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _try_import_torch():
    """Attempt a deferred import of ``torch`` into this module's globals.

    On success the module-level name ``torch`` is bound to the imported
    package and ``_torch_imported`` is set to ``True``. If the import raises
    ``ImportError``, both globals are left exactly as they were and the
    failure is swallowed on purpose, so callers can run without torch.
    """
    global torch, _torch_imported
    try:
        import torch
    except ImportError:
        # Best effort: torch is optional, so a missing package is not an error.
        return
    _torch_imported = True
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable
|
||||
|
||||
import httpx
|
||||
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
|
||||
from dbgpt.model.cluster.base import *
|
||||
from dbgpt.model.cluster.registry import ModelRegistry
|
||||
@@ -34,6 +33,9 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
success_handler: Callable = None,
|
||||
error_handler: Callable = None,
|
||||
) -> Any:
|
||||
# Lazy import to avoid high time cost
|
||||
import httpx
|
||||
|
||||
url = worker_run_data.worker.worker_addr + endpoint
|
||||
headers = {**worker_run_data.worker.headers, **(additional_headers or {})}
|
||||
timeout = worker_run_data.worker.timeout
|
||||
|
Reference in New Issue
Block a user