refactor: Refactor for core SDK (#1092)

Fangyin Cheng
2024-01-21 09:57:57 +08:00
committed by GitHub
parent ba7248adbb
commit 2d905191f8
45 changed files with 236 additions and 133 deletions

View File

@@ -1,9 +1,12 @@
-from dbgpt.model.cluster.client import DefaultLLMClient
+try:
+    from dbgpt.model.cluster.client import DefaultLLMClient
+except ImportError as exc:
+    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
+    DefaultLLMClient = None
 
 # from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
 from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
 
-__ALL__ = [
-    "DefaultLLMClient",
-    "OpenAILLMClient",
-]
+_exports = []
+if DefaultLLMClient:
+    _exports.append("DefaultLLMClient")
+__ALL__ = _exports
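The guarded import above lets `import dbgpt.model` succeed when only the core SDK is installed: the cluster-backed client degrades to None and is dropped from __ALL__. A minimal sketch of how a caller might consume the optional export (the `require_default_llm_client` helper is hypothetical, not part of this commit):

# Hypothetical consumer of the optional export; not part of this commit.
from dbgpt.model import DefaultLLMClient


def require_default_llm_client():
    # DefaultLLMClient is None when the cluster dependencies are missing.
    if DefaultLLMClient is None:
        raise ImportError(
            "DefaultLLMClient is unavailable; install the cluster dependencies."
        )
    return DefaultLLMClient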

View File

@@ -137,7 +137,7 @@ class ModelLoader:
 def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
     import torch
 
-    from dbgpt.model.compression import compress_module
+    from dbgpt.model.llm.compression import compress_module
 
     device = model_params.device
     max_memory = None

View File

@@ -19,7 +19,7 @@ from dbgpt.configs.model_config import get_device
 from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
-from dbgpt.model.conversation import Conversation
+from dbgpt.model.llm.conversation import Conversation
 from dbgpt.model.parameter import (
     LlamaCppModelParameters,
     ModelParameters,

View File

@@ -1,13 +1,11 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 import time
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Dict, List, Optional
 
 from dbgpt.util.model_utils import GPUInfo
 from dbgpt.util.parameter_utils import ParameterDescription

View File

@@ -12,9 +12,9 @@ from dbgpt.core import (
     ModelOutput,
 )
 from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.parameter import ModelParameters
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj

View File

@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Type
 from dbgpt.configs.model_config import get_device
 from dbgpt.core import ModelMetadata
+from dbgpt.model.adapter.loader import _get_model_real_path
 from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import _get_model_real_path
 from dbgpt.model.parameter import (
     EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
     BaseEmbeddingModelParameters,

View File

@@ -8,7 +8,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
-from typing import Awaitable, Callable, Dict, Iterator, List
+from typing import Awaitable, Callable, Iterator
 
 from fastapi import APIRouter, FastAPI
 from fastapi.responses import StreamingResponse
@@ -16,12 +16,7 @@ from fastapi.responses import StreamingResponse
 from dbgpt.component import SystemApp
 from dbgpt.configs.model_config import LOGDIR
 from dbgpt.core import ModelMetadata, ModelOutput
-from dbgpt.model.base import (
-    ModelInstance,
-    WorkerApplyOutput,
-    WorkerApplyType,
-    WorkerSupportedModel,
-)
+from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
 from dbgpt.model.cluster.base import *
 from dbgpt.model.cluster.manager_base import (
     WorkerManager,
@@ -30,8 +25,8 @@ from dbgpt.model.cluster.manager_base import (
 )
 from dbgpt.model.cluster.registry import ModelRegistry
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.llm_utils import list_supported_models
-from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
+from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
+from dbgpt.model.utils.llm_utils import list_supported_models
 from dbgpt.util.parameter_utils import (
     EnvArgumentParser,
     ParameterDescription,

View File

@@ -18,7 +18,7 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
+from dbgpt.model.utils.llm_utils import is_partial_stop, is_sentence_complete
 
 
 def prepare_logits_processor(

View File

@@ -1,9 +1,9 @@
-from dbgpt.model.operator.llm_operator import (
+from dbgpt.model.operator.llm_operator import (  # noqa: F401
     LLMOperator,
     MixinLLMOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator  # noqa: F401
 
 __ALL__ = [
     "MixinLLMOperator",

View File

@@ -6,7 +6,6 @@ from dbgpt.component import ComponentType
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import BaseOperator
 from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
-from dbgpt.model.cluster import WorkerManagerFactory
 
 logger = logging.getLogger(__name__)
@@ -19,31 +18,30 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
     def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
         super().__init__(default_client)
-        self._default_llm_client = default_client
 
     @property
     def llm_client(self) -> LLMClient:
         if not self._llm_client:
-            worker_manager_factory: WorkerManagerFactory = (
-                self.system_app.get_component(
-                    ComponentType.WORKER_MANAGER_FACTORY,
-                    WorkerManagerFactory,
-                    default_component=None,
-                )
-            )
-            if worker_manager_factory:
+            try:
+                from dbgpt.model.cluster import WorkerManagerFactory
                 from dbgpt.model.cluster.client import DefaultLLMClient
 
-                self._llm_client = DefaultLLMClient(worker_manager_factory.create())
-            else:
-                if self._default_llm_client is None:
-                    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
-
-                    self._default_llm_client = OpenAILLMClient()
-                logger.info(
-                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
+                worker_manager_factory: WorkerManagerFactory = (
+                    self.system_app.get_component(
+                        ComponentType.WORKER_MANAGER_FACTORY,
+                        WorkerManagerFactory,
+                        default_component=None,
+                    )
                 )
-                self._llm_client = self._default_llm_client
+                if worker_manager_factory:
+                    self._llm_client = DefaultLLMClient(worker_manager_factory.create())
+            except Exception as e:
+                logger.warning(f"Load worker manager failed: {e}.")
+            if not self._llm_client:
+                from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
+
+                logger.info("Can't find worker manager factory, use OpenAILLMClient.")
+                self._llm_client = OpenAILLMClient()
         return self._llm_client
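The refactored property imports the cluster machinery lazily and falls back to OpenAILLMClient when no worker manager is registered, so the operator works without the full cluster stack installed. A condensed, self-contained sketch of the same resolution order (`resolve_llm_client` is an illustrative name; only the imports shown in the diff are assumed):

import logging

logger = logging.getLogger(__name__)


def resolve_llm_client(system_app):
    """Illustrative only: prefer the cluster-backed client, else fall back."""
    from dbgpt.component import ComponentType

    client = None
    try:
        # Lazy imports: the cluster package may not be installed.
        from dbgpt.model.cluster import WorkerManagerFactory
        from dbgpt.model.cluster.client import DefaultLLMClient

        factory = system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY,
            WorkerManagerFactory,
            default_component=None,
        )
        if factory:
            client = DefaultLLMClient(factory.create())
    except Exception as e:
        logger.warning(f"Load worker manager failed: {e}.")
    if not client:
        from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

        client = OpenAILLMClient()
    return client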

View File

@@ -6,11 +6,8 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, Optional, Tuple, Union
 
-from dbgpt.model.conversation import conv_templates
 from dbgpt.util.parameter_utils import BaseParameters
 
-suported_prompt_templates = ",".join(conv_templates.keys())
-
 
 class WorkerType(str, Enum):
     LLM = "llm"
@@ -299,7 +296,8 @@ class ModelParameters(BaseModelParameters):
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": f"Prompt template. If None, the prompt template is automatically "
+            f"determined from model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -450,7 +448,8 @@ class ProxyModelParameters(BaseModelParameters):
     proxyllm_backend: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
+            "help": "The model name actually pass to current proxy server url, such "
+            "as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
         },
     )
     model_type: Optional[str] = field(
@@ -463,13 +462,15 @@ class ProxyModelParameters(BaseModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Device to run model. If None, the device is automatically determined"
+            "help": "Device to run model. If None, the device is automatically "
+            "determined"
         },
     )
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": f"Prompt template. If None, the prompt template is automatically "
+            f"determined from model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -478,7 +479,8 @@ class ProxyModelParameters(BaseModelParameters):
     llm_client_class: Optional[str] = field(
        default=None,
         metadata={
-            "help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel"
+            "help": "The class name of llm client, such as "
+            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
         },
     )
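All of the help-string changes above rely on Python's implicit concatenation of adjacent string literals, so the runtime values are unchanged (the retained `f` prefix is inert once the `{suported_prompt_templates}` placeholder is gone); dropping that placeholder is also what allows the module-level `conv_templates` import and `suported_prompt_templates` variable to be deleted. A quick check of the concatenation rule:

# Adjacent string literals are joined at compile time, so wrapping a long
# help string across lines does not change its value.
msg = (
    "Device to run model. If None, the device is automatically "
    "determined"
)
assert msg == "Device to run model. If None, the device is automatically determined"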

View File

@@ -37,8 +37,8 @@ def list_supported_models():
 def _list_supported_models(
     worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
+    from dbgpt.model.adapter.loader import _get_model_real_path
     from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
-    from dbgpt.model.loader import _get_model_real_path
 
     ret = []
     for model_name, model_path in model_config.items():