refactor: Refactor for core SDK (#1092)
@@ -1,9 +1,12 @@
-from dbgpt.model.cluster.client import DefaultLLMClient
+try:
+    from dbgpt.model.cluster.client import DefaultLLMClient
+except ImportError as exc:
+    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
+    DefaultLLMClient = None
 
-# from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
 from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
 
-__ALL__ = [
-    "DefaultLLMClient",
-    "OpenAILLMClient",
-]
+_exports = []
+if DefaultLLMClient:
+    _exports.append("DefaultLLMClient")
+
+__ALL__ = _exports
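The hunk above turns the cluster client into an optional dependency: the import is attempted, an ImportError is swallowed, and the export list is assembled from whatever actually loaded. A minimal runnable sketch of the same pattern follows (`heavy_backend` is a placeholder name, not a DB-GPT module); note that Python's star-import machinery only honors the lowercase `__all__` spelling, while `__ALL__` in the diff is an ordinary module attribute.

# optional_import_sketch.py -- the optional-dependency pattern from the hunk above.
try:
    # Placeholder for a heavyweight optional dependency (e.g. a cluster client).
    from heavy_backend import HeavyClient
except ImportError:
    # Keep the package importable when the optional extra is not installed.
    HeavyClient = None

_exports = []
if HeavyClient is not None:
    _exports.append("HeavyClient")

# Downstream code can test the symbol against None before using it.
__all__ = _exports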
@@ -137,7 +137,7 @@ class ModelLoader:
 def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
     import torch
 
-    from dbgpt.model.compression import compress_module
+    from dbgpt.model.llm.compression import compress_module
 
     device = model_params.device
     max_memory = None
@@ -19,7 +19,7 @@ from dbgpt.configs.model_config import get_device
 from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
-from dbgpt.model.conversation import Conversation
+from dbgpt.model.llm.conversation import Conversation
 from dbgpt.model.parameter import (
     LlamaCppModelParameters,
     ModelParameters,
@@ -1,13 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 import time
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Dict, List, Optional
 
 from dbgpt.util.model_utils import GPUInfo
 from dbgpt.util.parameter_utils import ParameterDescription
@@ -12,9 +12,9 @@ from dbgpt.core import (
     ModelOutput,
 )
 from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.parameter import ModelParameters
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Type
 from dbgpt.configs.model_config import get_device
 from dbgpt.core import ModelMetadata
+from dbgpt.model.adapter.loader import _get_model_real_path
 from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import _get_model_real_path
 from dbgpt.model.parameter import (
     EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
     BaseEmbeddingModelParameters,
@@ -8,7 +8,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
-from typing import Awaitable, Callable, Dict, Iterator, List
+from typing import Awaitable, Callable, Iterator
 
 from fastapi import APIRouter, FastAPI
 from fastapi.responses import StreamingResponse
@@ -16,12 +16,7 @@ from fastapi.responses import StreamingResponse
 from dbgpt.component import SystemApp
 from dbgpt.configs.model_config import LOGDIR
 from dbgpt.core import ModelMetadata, ModelOutput
-from dbgpt.model.base import (
-    ModelInstance,
-    WorkerApplyOutput,
-    WorkerApplyType,
-    WorkerSupportedModel,
-)
+from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
 from dbgpt.model.cluster.base import *
 from dbgpt.model.cluster.manager_base import (
     WorkerManager,
@@ -30,8 +25,8 @@ from dbgpt.model.cluster.manager_base import (
 )
 from dbgpt.model.cluster.registry import ModelRegistry
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.llm_utils import list_supported_models
-from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
+from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
+from dbgpt.model.utils.llm_utils import list_supported_models
 from dbgpt.util.parameter_utils import (
     EnvArgumentParser,
     ParameterDescription,
@@ -18,7 +18,7 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
+from dbgpt.model.utils.llm_utils import is_partial_stop, is_sentence_complete
 
 
 def prepare_logits_processor(
@@ -1,9 +1,9 @@
-from dbgpt.model.operator.llm_operator import (
+from dbgpt.model.operator.llm_operator import (  # noqa: F401
     LLMOperator,
     MixinLLMOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator  # noqa: F401
 
 __ALL__ = [
     "MixinLLMOperator",
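Both rewritten import lines only add `# noqa: F401`: in a package `__init__.py` these names are imported purely to be re-exported, so flake8's F401 check ("imported but unused") would otherwise flag them. A tiny runnable illustration, using a stdlib name in place of the DB-GPT operators:

# reexport_sketch.py -- imported-for-reexport pattern; without the noqa
# marker, flake8 reports F401 because the name is never used locally.
from collections import OrderedDict  # noqa: F401

__all__ = ["OrderedDict"]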
@@ -6,7 +6,6 @@ from dbgpt.component import ComponentType
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import BaseOperator
 from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
-from dbgpt.model.cluster import WorkerManagerFactory
 
 logger = logging.getLogger(__name__)
 
@@ -19,31 +18,30 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
 
     def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
         super().__init__(default_client)
-        self._default_llm_client = default_client
 
     @property
     def llm_client(self) -> LLMClient:
         if not self._llm_client:
-            worker_manager_factory: WorkerManagerFactory = (
-                self.system_app.get_component(
-                    ComponentType.WORKER_MANAGER_FACTORY,
-                    WorkerManagerFactory,
-                    default_component=None,
-                )
-            )
-            if worker_manager_factory:
-                self._llm_client = DefaultLLMClient(worker_manager_factory.create())
-            else:
-                if self._default_llm_client is None:
-                    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
-
-                    self._default_llm_client = OpenAILLMClient()
-                logger.info(
-                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
-                )
-                self._llm_client = self._default_llm_client
+            try:
+                from dbgpt.model.cluster import WorkerManagerFactory
+                from dbgpt.model.cluster.client import DefaultLLMClient
+
+                worker_manager_factory: WorkerManagerFactory = (
+                    self.system_app.get_component(
+                        ComponentType.WORKER_MANAGER_FACTORY,
+                        WorkerManagerFactory,
+                        default_component=None,
+                    )
+                )
+                if worker_manager_factory:
+                    self._llm_client = DefaultLLMClient(worker_manager_factory.create())
+            except Exception as e:
+                logger.warning(f"Load worker manager failed: {e}.")
+            if not self._llm_client:
+                from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
+
+                logger.info("Can't find worker manager factory, use OpenAILLMClient.")
+                self._llm_client = OpenAILLMClient()
         return self._llm_client
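The rewritten property above moves the cluster imports inside the accessor and adds two-stage resolution: try to build `DefaultLLMClient` from the registered worker manager factory, and if that import or lookup fails (or no factory is registered), fall back to `OpenAILLMClient`. Below is a stripped-down, runnable sketch of that lazy-resolution shape; `LazyClientHolder` and both factory names are illustrative, not DB-GPT APIs.

from typing import Callable, Optional


class LazyClientHolder:
    """Resolve a client on first access: try the primary factory, log any
    failure, then fall back to the default factory (illustrative names)."""

    def __init__(
        self,
        primary: Callable[[], object],
        fallback: Callable[[], object],
    ):
        self._primary = primary  # e.g. DefaultLLMClient via a worker manager
        self._fallback = fallback  # e.g. OpenAILLMClient
        self._client: Optional[object] = None

    @property
    def client(self) -> object:
        if self._client is None:
            try:
                self._client = self._primary()
            except Exception as e:
                print(f"Load primary client failed: {e}.")
            if self._client is None:
                self._client = self._fallback()
        return self._client


def _no_worker_manager() -> object:
    raise RuntimeError("no worker manager registered")


holder = LazyClientHolder(_no_worker_manager, lambda: "fallback-client")
print(holder.client)  # -> fallback-client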
@@ -6,11 +6,8 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, Optional, Tuple, Union
 
-from dbgpt.model.conversation import conv_templates
 from dbgpt.util.parameter_utils import BaseParameters
 
-suported_prompt_templates = ",".join(conv_templates.keys())
-
 
 class WorkerType(str, Enum):
     LLM = "llm"
@@ -299,7 +296,8 @@ class ModelParameters(BaseModelParameters):
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": f"Prompt template. If None, the prompt template is automatically "
+            f"determined from model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -450,7 +448,8 @@ class ProxyModelParameters(BaseModelParameters):
     proxyllm_backend: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
+            "help": "The model name actually pass to current proxy server url, such "
+            "as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
         },
     )
     model_type: Optional[str] = field(
@@ -463,13 +462,15 @@ class ProxyModelParameters(BaseModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Device to run model. If None, the device is automatically determined"
+            "help": "Device to run model. If None, the device is automatically "
+            "determined"
         },
     )
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": f"Prompt template. If None, the prompt template is automatically "
+            f"determined from model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -478,7 +479,8 @@ class ProxyModelParameters(BaseModelParameters):
     llm_client_class: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel"
+            "help": "The class name of llm client, such as "
+            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
         },
    )
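The last four hunks change no behavior: each long help string is split across two adjacent string literals, which the Python parser concatenates back into one value, purely to satisfy line-length limits. These field(metadata={"help": ...}) entries are what DB-GPT's parameter tooling (e.g. EnvArgumentParser) renders as option descriptions. A runnable sketch of the general metadata-to-argparse pattern follows, assuming a simplified parser rather than the real dbgpt.util.parameter_utils implementation.

import argparse
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class DemoParameters:
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically "
            "determined"  # adjacent literals merge into one help string
        },
    )


def build_parser(cls) -> argparse.ArgumentParser:
    # Turn each dataclass field into a --option with its metadata help text.
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        parser.add_argument(f"--{f.name}", default=f.default, help=f.metadata.get("help", ""))
    return parser


args = build_parser(DemoParameters).parse_args(["--device", "cuda"])
print(args.device)  # -> cuda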
@@ -37,8 +37,8 @@ def list_supported_models():
 def _list_supported_models(
     worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
+    from dbgpt.model.adapter.loader import _get_model_real_path
     from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
-    from dbgpt.model.loader import _get_model_real_path
 
     ret = []
     for model_name, model_path in model_config.items():