Mirror of https://github.com/csunny/DB-GPT.git
chore: Add pylint for DB-GPT core lib (#1076)
@@ -8,6 +8,9 @@ import threading
 from functools import cache
 from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
+
 try:
     from fastchat.conversation import (
         Conversation,
@@ -20,8 +23,6 @@ except ImportError as exc:
         "Please install fastchat by command `pip install fschat` "
     ) from exc
 
-from dbgpt.model.adapter.base import LLMModelAdapter
-from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 
 if TYPE_CHECKING:
     from fastchat.model.model_adapter import BaseModelAdapter
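Note on the reshuffled imports above: the dbgpt adapter imports move to the top of the module, while names needed only for type annotations stay under the `if TYPE_CHECKING:` guard so fastchat is not imported at runtime. A minimal sketch of that guard pattern (the `describe` helper is illustrative, not part of DB-GPT):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; never executed at runtime.
        from fastchat.model.model_adapter import BaseModelAdapter


    def describe(adapter: "BaseModelAdapter") -> str:
        # The annotation is a string, so the name need not exist at runtime.
        return type(adapter).__name__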
@@ -196,7 +196,14 @@ class DefaultModelWorker(ModelWorker):
         return _try_to_count_token(prompt, self.tokenizer, self.model)
 
     async def async_count_token(self, prompt: str) -> int:
-        # TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
+        # TODO if we deploy the model by vllm, it can't work, we should run
+        # transformer _try_to_count_token to async
+        from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+
+        if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
+            return await self.model.proxy_llm_client.count_token(
+                self.model.proxy_llm_client.default_model, prompt
+            )
         raise NotImplementedError
 
     def get_model_metadata(self, params: Dict) -> ModelMetadata:
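The branch added above delegates token counting to the proxy client when the worker wraps a proxy model, and otherwise keeps raising NotImplementedError. A rough, self-contained sketch of that dispatch (the Fake* classes are stand-ins, not the real DB-GPT types):

    import asyncio


    class FakeProxyClient:
        default_model = "proxy-model"

        async def count_token(self, model: str, prompt: str) -> int:
            # Stand-in for a remote tokenizer call.
            return len(prompt.split())


    class FakeProxyModel:
        def __init__(self, client):
            self.proxy_llm_client = client


    async def async_count_token(model, prompt: str) -> int:
        # Mirrors the control flow in the diff: delegate if possible,
        # otherwise signal that local counting is not implemented.
        if isinstance(model, FakeProxyModel) and model.proxy_llm_client:
            client = model.proxy_llm_client
            return await client.count_token(client.default_model, prompt)
        raise NotImplementedError


    result = asyncio.run(async_count_token(FakeProxyModel(FakeProxyClient()), "hello world"))
    print(result)  # 2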
@@ -118,9 +118,6 @@ def replace_llama_attn_with_non_inplace_operations():
     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
 
 
-import transformers
-
-
 def replace_llama_attn_with_non_inplace_operations():
     """Avoid bugs in mps backend by not using in-place operations."""
     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
@@ -196,6 +196,15 @@ class ProxyLLMClient(LLMClient):
         """
         return self._models()
 
+    @property
+    def default_model(self) -> str:
+        """Get default model name
+
+        Returns:
+            str: default model name
+        """
+        return self.model_names[0]
+
     @cache
     def _models(self) -> List[ModelMetadata]:
         results = []
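Turning default_model into a read-only @property on ProxyLLMClient (first name in model_names by default) is what drives the per-provider edits below: each client stops assigning self.default_model in __init__, because assigning to a property with no setter raises AttributeError, and overrides the property instead. A small sketch of the pattern with placeholder classes:

    class BaseClient:
        def __init__(self, model_names):
            self.model_names = model_names

        @property
        def default_model(self) -> str:
            # Default: the first registered model name.
            return self.model_names[0]


    class GeminiLikeClient(BaseClient):
        def __init__(self, model: str):
            super().__init__(model_names=[model, "alias-model"])
            self._model = model

        @property
        def default_model(self) -> str:
            # Override the property instead of assigning self.default_model.
            return self._model


    print(GeminiLikeClient("gemini-pro").default_model)  # gemini-pro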
@@ -237,6 +246,7 @@ class ProxyLLMClient(LLMClient):
         Returns:
             int: token count, -1 if failed
         """
-        return await blocking_func_to_async(
+        counts = await blocking_func_to_async(
             self.executor, self.proxy_tokenizer.count_token, model, [prompt]
-        )[0]
+        )
+        return counts[0]
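The count_token change above is a behavior fix, not just style. Since blocking_func_to_async is awaited in the new code, it returns an awaitable; in the old form the subscript binds before await (`await f(...)[0]` parses as `await (f(...)[0])`), so the un-awaited coroutine gets indexed and the call fails. A self-contained illustration with a dummy coroutine:

    import asyncio


    async def fake_count_tokens(prompts):
        # Stand-in for blocking_func_to_async(...): one count per prompt.
        return [len(p.split()) for p in prompts]


    async def main():
        counts = await fake_count_tokens(["hello world"])
        print(counts[0])  # 2, the corrected form

        try:
            # Shaped like the old code: subscripts the coroutine before awaiting.
            await fake_count_tokens(["hello world"])[0]
        except TypeError as exc:
            print(f"old form fails: {exc}")


    asyncio.run(main())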
@@ -86,6 +86,11 @@ class OpenAILLMClient(ProxyLLMClient):
         self._openai_kwargs = openai_kwargs or {}
         super().__init__(model_names=[model_alias], context_length=context_length)
 
+        if self._openai_less_then_v1:
+            from dbgpt.model.utils.chatgpt_utils import _initialize_openai
+
+            _initialize_openai(self._init_params)
+
     @classmethod
     def new_client(
         cls,
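The constructor above now applies the legacy module-level configuration only when the installed openai SDK is pre-1.0, gated by the _openai_less_then_v1 flag set elsewhere in the class. One plausible way such a flag is computed is sketched below; this is an assumption about the mechanism, not the exact DB-GPT code:

    from importlib.metadata import PackageNotFoundError, version


    def openai_is_less_than_v1() -> bool:
        # Best-effort check of the installed openai SDK major version.
        try:
            major = int(version("openai").split(".")[0])
        except PackageNotFoundError:
            return False
        return major < 1


    if openai_is_less_than_v1():
        # Legacy SDKs (<1.0) are configured through module-level globals
        # (openai.api_key, openai.api_base, ...), which is what
        # _initialize_openai does in the hunk further down.
        pass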
@@ -114,7 +114,6 @@ class GeminiLLMClient(ProxyLLMClient):
         self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY")
         self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE")
         self._model = model
-        self.default_model = self._model
         if not self._api_key:
             raise RuntimeError("api_key can't be empty")
 
@@ -148,6 +147,10 @@ class GeminiLLMClient(ProxyLLMClient):
             executor=default_executor,
         )
 
+    @property
+    def default_model(self) -> str:
+        return self._model
+
     def sync_generate_stream(
         self,
         request: ModelRequest,
@@ -8,9 +8,6 @@ from datetime import datetime
 from time import mktime
 from typing import Iterator, Optional
 from urllib.parse import urlencode, urlparse
-from wsgiref.handlers import format_date_time
-
-from websockets.sync.client import connect
 
 from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
 from dbgpt.model.parameter import ProxyModelParameters
@@ -56,6 +53,8 @@ def spark_generate_stream(
 
 
 def get_response(request_url, data):
+    from websockets.sync.client import connect
+
     with connect(request_url) as ws:
         ws.send(json.dumps(data, ensure_ascii=False))
         result = ""
@@ -87,6 +86,8 @@ class SparkAPI:
         self.spark_url = spark_url
 
     def gen_url(self):
+        from wsgiref.handlers import format_date_time
+
         # Generate an RFC1123-format timestamp
         now = datetime.now()
         date = format_date_time(mktime(now.timetuple()))
@@ -145,7 +146,6 @@ class SparkLLMClient(ProxyLLMClient):
         if not api_domain:
             api_domain = domain
         self._model = model
-        self.default_model = self._model
         self._model_version = model_version
         self._api_base = api_base
         self._domain = api_domain
@@ -183,6 +183,10 @@ class SparkLLMClient(ProxyLLMClient):
             executor=default_executor,
         )
 
+    @property
+    def default_model(self) -> str:
+        return self._model
+
     def sync_generate_stream(
         self,
         request: ModelRequest,
@@ -51,7 +51,6 @@ class TongyiLLMClient(ProxyLLMClient):
         if api_region:
             dashscope.api_region = api_region
         self._model = model
-        self.default_model = self._model
 
         super().__init__(
             model_names=[model, model_alias],
@@ -73,6 +72,10 @@ class TongyiLLMClient(ProxyLLMClient):
             executor=default_executor,
         )
 
+    @property
+    def default_model(self) -> str:
+        return self._model
+
     def sync_generate_stream(
         self,
         request: ModelRequest,
@@ -121,7 +121,6 @@ class WenxinLLMClient(ProxyLLMClient):
         self._api_key = api_key
         self._api_secret = api_secret
         self._model_version = model_version
-        self.default_model = self._model
 
         super().__init__(
             model_names=[model, model_alias],
@@ -145,6 +144,10 @@ class WenxinLLMClient(ProxyLLMClient):
             executor=default_executor,
         )
 
+    @property
+    def default_model(self) -> str:
+        return self._model
+
     def sync_generate_stream(
         self,
         request: ModelRequest,
@@ -54,7 +54,6 @@ class ZhipuLLMClient(ProxyLLMClient):
         if api_key:
             zhipuai.api_key = api_key
         self._model = model
-        self.default_model = self._model
 
         super().__init__(
             model_names=[model, model_alias],
@@ -76,6 +75,10 @@ class ZhipuLLMClient(ProxyLLMClient):
             executor=default_executor,
         )
 
+    @property
+    def default_model(self) -> str:
+        return self._model
+
     def sync_generate_stream(
         self,
         request: ModelRequest,
@@ -88,6 +88,42 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
     return openai_params, api_type, api_version
 
 
+def _initialize_openai(params: OpenAIParameters):
+    try:
+        import openai
+    except ImportError as exc:
+        raise ValueError(
+            "Could not import python package: openai "
+            "Please install openai by command `pip install openai` "
+        ) from exc
+
+    api_type = params.api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
+
+    api_base = params.api_base or os.getenv(
+        "OPENAI_API_TYPE",
+        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
+    )
+    api_key = params.api_key or os.getenv(
+        "OPENAI_API_KEY",
+        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
+    )
+    api_version = params.api_version or os.getenv("OPENAI_API_VERSION")
+
+    if not api_base and params.full_url:
+        # Adapt previous proxy_server_url configuration
+        api_base = params.full_url.split("/chat/completions")[0]
+    if api_type:
+        openai.api_type = api_type
+    if api_base:
+        openai.api_base = api_base
+    if api_key:
+        openai.api_key = api_key
+    if api_version:
+        openai.api_version = api_version
+    if params.proxies:
+        openai.proxy = params.proxies
+
+
 def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
     import httpx
 
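In the legacy helper added above, each setting resolves as explicit parameter first, then environment variable (with Azure-specific fallbacks when api_type is "azure"), and a proxy_server_url-style full_url is trimmed back to an API base. The trimming step can be reproduced standalone; the URL below is just an example value:

    def base_from_full_url(full_url: str) -> str:
        # Adapt a previous proxy_server_url configuration to an API base URL,
        # mirroring the split used in _initialize_openai.
        return full_url.split("/chat/completions")[0]


    print(base_from_full_url("http://127.0.0.1:8100/api/v1/chat/completions"))
    # -> http://127.0.0.1:8100/api/v1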
@@ -112,9 +148,7 @@ def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
 class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
     """Transform ModelOutput to openai stream format."""
 
-    async def transform_stream(
-        self, input_value: AsyncIterator[ModelOutput]
-    ) -> AsyncIterator[str]:
+    async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
         async def model_caller() -> str:
             """Read model name from share data.
             In streaming mode, this transform_stream function will be executed