mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
feat(core): APP use new SDK component (#1050)
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from dbgpt.model.base import ModelType
|
||||
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.parameter import BaseModelParameters
|
||||
from dbgpt.util.parameter_utils import (
|
||||
_extract_parameter_details,
|
||||
_build_parameter_class,
|
||||
_extract_parameter_details,
|
||||
_get_dataclass_print_str,
|
||||
)
|
||||
|
||||
@@ -27,6 +28,7 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> BaseModelParameters:
|
||||
import argparse
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -56,9 +58,9 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
|
||||
return _build_parameter_class(descs)
|
||||
|
||||
def load_from_params(self, params):
|
||||
import torch
|
||||
from vllm import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
import torch
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
|
||||
|
Reference in New Issue
Block a user