Mirror of https://github.com/csunny/DB-GPT.git
feat(core): APP use new SDK component (#1050)
@@ -1,19 +1,20 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Any, Tuple, Type, Callable
 import logging
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.adapter.template import (
+    ConversationAdapter,
+    ConversationAdapterFactory,
+    get_conv_template,
+)
 from dbgpt.model.base import ModelType
 from dbgpt.model.parameter import (
     BaseModelParameters,
-    ModelParameters,
     LlamaCppModelParameters,
+    ModelParameters,
     ProxyModelParameters,
 )
-from dbgpt.model.adapter.template import (
-    get_conv_template,
-    ConversationAdapter,
-    ConversationAdapterFactory,
-)
 
 logger = logging.getLogger(__name__)
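The reordering above follows isort's default convention: within each section, plain `import` statements come before `from` imports, sections run standard library, then third-party, then first-party, and the names inside each `from` import are alphabetized. A minimal sketch of checking this programmatically, assuming isort is installed (`pip install isort`); the snippet and its input are illustrative, not part of the commit:

    import isort

    # Deliberately unsorted input, mirroring the kind of lines this commit fixes.
    messy = "from typing import Tuple, List\nimport logging\nimport os\n"
    print(isort.code(messy))
    # Expected output with default settings:
    # import logging
    # import os
    # from typing import List, Tuple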
@@ -2,17 +2,17 @@
 You can import fastchat only in this file, so that the user does not need to install fastchat if he does not use it.
 """
+import logging
 import os
 import threading
-import logging
 from functools import cache
-from typing import TYPE_CHECKING, Callable, Tuple, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 try:
     from fastchat.conversation import (
         Conversation,
-        register_conv_template,
         SeparatorStyle,
+        register_conv_template,
     )
 except ImportError as exc:
     raise ValueError(
@@ -20,8 +20,8 @@ except ImportError as exc:
         "Please install fastchat by command `pip install fschat` "
     ) from exc
 
-from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 
 if TYPE_CHECKING:
     from fastchat.model.model_adapter import BaseModelAdapter
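The try/except block above is the standard optional-dependency guard: attempt the import at module load and turn a missing package into an actionable error. A minimal self-contained sketch of the same pattern, using a deliberately fake package name for illustration:

    try:
        import some_optional_package  # hypothetical; stands in for fastchat here
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: some_optional_package. "
            "Please install it with `pip install some-optional-package`."
        ) from exc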
@@ -1,10 +1,10 @@
-from abc import ABC, abstractmethod
-from typing import Dict, Optional, List, Any
 import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
 
 from dbgpt.core import ModelMessage
-from dbgpt.model.base import ModelType
 from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
+from dbgpt.model.base import ModelType
 
 logger = logging.getLogger(__name__)
@@ -42,7 +42,7 @@ class NewHFChatModelAdapter(LLMModelAdapter, ABC):
     def load(self, model_path: str, from_pretrained_kwargs: dict):
         try:
             import transformers
-            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+            from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
         except ImportError as exc:
             raise ValueError(
                 "Could not import depend python package "
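The guarded transformers import above feeds the adapter's `load` method. A minimal sketch of loading through the Hugging Face auto classes; the model id is a placeholder, and in the real method the keyword arguments come from `from_pretrained_kwargs`:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "facebook/opt-125m"  # placeholder; any causal-LM repo id or local path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)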
@@ -1,21 +1,15 @@
 from __future__ import annotations
 
-from typing import (
-    List,
-    Type,
-    Optional,
-)
 import logging
-import threading
 import os
+import threading
 from functools import cache
+from typing import List, Optional, Type
 
-from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
-from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
 from dbgpt.model.base import ModelType
 from dbgpt.model.parameter import BaseModelParameters
+from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
+from dbgpt.model.adapter.template import (
+    ConversationAdapter,
+    ConversationAdapterFactory,
+)
 
 logger = logging.getLogger(__name__)
@@ -64,9 +58,9 @@ def get_llm_model_adapter(
     if use_fastchat and not must_use_old:
         logger.info("Use fastcat adapter")
         from dbgpt.model.adapter.fschat_adapter import (
-            _get_fastchat_model_adapter,
-            _fastchat_get_adapter_monkey_patch,
             FastChatLLMModelAdapterWrapper,
+            _fastchat_get_adapter_monkey_patch,
+            _get_fastchat_model_adapter,
         )
 
         adapter = _get_fastchat_model_adapter(
@@ -79,11 +73,11 @@ def get_llm_model_adapter(
         result_adapter = FastChatLLMModelAdapterWrapper(adapter)
 
     else:
+        from dbgpt.app.chat_adapter import get_llm_chat_adapter
+        from dbgpt.model.adapter.old_adapter import OldLLMModelAdapterWrapper
         from dbgpt.model.adapter.old_adapter import (
             get_llm_model_adapter as _old_get_llm_model_adapter,
-            OldLLMModelAdapterWrapper,
         )
-        from dbgpt.app.chat_adapter import get_llm_chat_adapter
 
         logger.info("Use DB-GPT old adapter")
         result_adapter = OldLLMModelAdapterWrapper(
@@ -139,12 +133,12 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
     Returns:
         Optional[List[Type[BaseModelParameters]]]: The model parameters class list.
     """
-    from dbgpt.util.parameter_utils import _SimpleArgParser
     from dbgpt.model.parameter import (
+        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
         EmbeddingModelParameters,
         WorkerType,
-        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    )
+    from dbgpt.util.parameter_utils import _SimpleArgParser
 
     pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
     pre_args.parse()
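Both hunks above keep their imports inside the function body, so merely importing the module stays cheap and optional backends are only loaded when the branch that needs them runs. A minimal sketch of that lazy-import pattern, with stdlib modules standing in for the heavy backends:

    def get_serializer(use_json: bool):
        # Import only the backend that is actually selected.
        if use_json:
            import json
            return json.dumps
        import pickle
        return pickle.dumps

    print(get_serializer(True)({"model_name": "vicuna-13b-v1.5"}))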
@@ -5,31 +5,26 @@ We have integrated fastchat. For details, see: dbgpt/model/model_adapter.py
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import logging
 import os
 import re
-import logging
-from pathlib import Path
-from typing import List, Tuple, TYPE_CHECKING, Optional
 from functools import cache
-from transformers import (
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    LlamaTokenizer,
-)
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
+from dbgpt._private.config import Config
+from dbgpt.configs.model_config import get_device
 from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
-
-from dbgpt.model.conversation import Conversation
 from dbgpt.model.parameter import (
-    ModelParameters,
     LlamaCppModelParameters,
+    ModelParameters,
     ProxyModelParameters,
 )
-from dbgpt.configs.model_config import get_device
-from dbgpt._private.config import Config
+from dbgpt.model.conversation import Conversation
 
 if TYPE_CHECKING:
     from dbgpt.app.chat_adapter import BaseChatAdpter
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Optional, Tuple, Union, List
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 if TYPE_CHECKING:
     from fastchat.conversation import Conversation
@@ -124,6 +124,7 @@ def get_conv_template(name: str) -> ConversationAdapter:
         Conversation: The conversation template.
     """
     from fastchat.conversation import get_conv_template
+
     from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
 
     conv_template = get_conv_template(name)
@@ -1,12 +1,13 @@
 import dataclasses
 import logging
-from dbgpt.model.base import ModelType
 
 from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
+from dbgpt.model.base import ModelType
 from dbgpt.model.parameter import BaseModelParameters
 from dbgpt.util.parameter_utils import (
-    _extract_parameter_details,
     _build_parameter_class,
+    _extract_parameter_details,
     _get_dataclass_print_str,
 )
@@ -27,6 +28,7 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
     def model_param_class(self, model_type: str = None) -> BaseModelParameters:
         import argparse
+
         from vllm.engine.arg_utils import AsyncEngineArgs
 
         parser = argparse.ArgumentParser()
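`model_param_class` above builds a parameter class from vLLM's own CLI argument definitions. A minimal sketch of harvesting option names and defaults from an argparse parser; the two options here are placeholders rather than vLLM's real argument set, and `_actions` is a private argparse attribute used only for illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tensor-parallel-size", type=int, default=1)
    parser.add_argument("--dtype", type=str, default="auto")

    # Collect each option's destination name and default value.
    defaults = {a.dest: a.default for a in parser._actions if a.dest != "help"}
    print(defaults)  # {'tensor_parallel_size': 1, 'dtype': 'auto'}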
@@ -56,9 +58,9 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
         return _build_parameter_class(descs)
 
     def load_from_params(self, params):
+        import torch
         from vllm import AsyncLLMEngine
         from vllm.engine.arg_utils import AsyncEngineArgs
-        import torch
 
         num_gpus = torch.cuda.device_count()
         if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
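`load_from_params` above checks the visible GPU count before configuring tensor parallelism. A minimal sketch of that check; the assignment inside the branch is an assumption about what the truncated hunk does, not code taken from the diff:

    import torch

    def apply_tensor_parallelism(params) -> None:
        num_gpus = torch.cuda.device_count()
        if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
            # Assumed behavior: spread the engine across all visible GPUs.
            params.tensor_parallel_size = num_gpus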