Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-02 17:45:31 +00:00)
feat(model): Add new LLMClient and new build tools (#967)
Makefile (new file, 61 lines)
@@ -0,0 +1,61 @@
.DEFAULT_GOAL := help

SHELL=/bin/bash
VENV = venv

# Detect the operating system and set the virtualenv bin directory
ifeq ($(OS),Windows_NT)
VENV_BIN=$(VENV)/Scripts
else
VENV_BIN=$(VENV)/bin
endif

setup: ## Set up the Python development environment
python3 -m venv $(VENV)
$(VENV_BIN)/pip install --upgrade pip
$(VENV_BIN)/pip install -r requirements/dev-requirements.txt
$(VENV_BIN)/pip install -r requirements/lint-requirements.txt

testenv: setup ## Set up the Python test environment
$(VENV_BIN)/pip install -e ".[simple_framework]"

.PHONY: fmt
fmt: setup ## Format Python code
# TODO: Use isort to sort Python imports.
# https://github.com/PyCQA/isort
# $(VENV_BIN)/isort .
# https://github.com/psf/black
$(VENV_BIN)/black .
# TODO: Use blackdoc to format Python doctests.
# https://blackdoc.readthedocs.io/en/latest/
# $(VENV_BIN)/blackdoc .
# TODO: Type checking of Python code.
# https://github.com/python/mypy
# $(VENV_BIN)/mypy dbgpt
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
# $(VENV_BIN)/flake8 dbgpt

.PHONY: pre-commit
pre-commit: fmt test ## Run formatting and unit tests before committing

.PHONY: test
test: testenv ## Run unit tests
$(VENV_BIN)/pytest dbgpt

.PHONY: coverage
coverage: setup ## Run tests and report coverage
$(VENV_BIN)/pytest dbgpt --cov=dbgpt

.PHONY: clean
clean: ## Clean up the environment
rm -rf $(VENV)
find . -type f -name '*.pyc' -delete
find . -type d -name '__pycache__' -delete
find . -type d -name '.pytest_cache' -delete
find . -type d -name '.coverage' -delete

.PHONY: help
help: ## Display this help screen
@echo "Available commands:"
@grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort
@@ -1,9 +1,12 @@
from dbgpt.core.interface.llm import (
ModelInferenceMetrics,
ModelRequest,
ModelOutput,
OpenAILLM,
BaseLLMOperator,
LLMClient,
LLMOperator,
StreamingLLMOperator,
RequestBuildOperator,
ModelMetadata,
)
from dbgpt.core.interface.message import (
ModelMessage,
@@ -37,11 +40,15 @@ from dbgpt.core.interface.storage import (

__ALL__ = [
"ModelInferenceMetrics",
"ModelRequest",
"ModelOutput",
"OpenAILLM",
"BaseLLMOperator",
"Operator",
"RequestBuildOperator",
"ModelMetadata",
"ModelMessage",
"LLMClient",
"LLMOperator",
"StreamingLLMOperator",
"ModelMessageRoleType",
"OnceConversation",
"StorageConversation",
@@ -211,7 +211,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True
)
return out_ctx.current_task_context.task_output.output_stream

def _blocking_call_stream(
@@ -130,8 +130,9 @@ async def _trigger_dag(
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
generator = await end_node.call_stream(call_data={"data": body})
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
generator,
headers=headers,
media_type=media_type,
)
@@ -1,10 +1,10 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Any, Union, AsyncIterator

import time
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
import copy

from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
@@ -12,6 +12,7 @@ from dbgpt.core.awel import MapOperator, StreamifyAbsOperator

@dataclass
@PublicAPI(stability="beta")
class ModelInferenceMetrics:
"""A class to represent metrics for assessing the inference performance of a LLM."""
@@ -97,6 +98,7 @@ class ModelInferenceMetrics:

@dataclass
@PublicAPI(stability="beta")
class ModelOutput:
"""A class to represent the output of a LLM."""
@@ -118,6 +120,7 @@ _ModelMessageType = Union[ModelMessage, Dict[str, Any]]

@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
model: str
"""The name of the model."""
@@ -142,7 +145,7 @@ class ModelRequest:
span_id: Optional[str] = None
"""The span id of the model inference."""

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
new_reqeust = copy.deepcopy(self)
new_reqeust.messages = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
@@ -166,6 +169,110 @@ class ModelRequest:
**kwargs,
)

def to_openai_messages(self) -> List[Dict[str, Any]]:
"""Convert the messages to the format of OpenAI API.

This function moves the last user message to the end of the list.

Returns:
List[Dict[str, Any]]: The messages in the format of OpenAI API.

Examples:
.. code-block:: python
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot."),
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
]
openai_messages = ModelRequest.to_openai_messages(messages)
assert openai_messages == [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hi, I'm a robot."},
{"role": "user", "content": "Who are you"},
]
"""
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in self.messages
]
return ModelMessage.to_openai_messages(messages)


@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
"""A class to represent a LLM model."""

model: str = field(
metadata={"help": "Model name"},
)
context_length: Optional[int] = field(
default=4096,
metadata={"help": "Context length of model"},
)
chat_model: Optional[bool] = field(
default=True,
metadata={"help": "Whether the model is a chat model"},
)
is_function_calling_model: Optional[bool] = field(
default=False,
metadata={"help": "Whether the model is a function calling model"},
)
metadata: Optional[Dict[str, Any]] = field(
default_factory=dict,
metadata={"help": "Model metadata"},
)


@PublicAPI(stability="beta")
class LLMClient(ABC):
"""An abstract class for LLM client."""

@abstractmethod
async def generate(self, request: ModelRequest) -> ModelOutput:
"""Generate a response for a given model request.

Args:
request(ModelRequest): The model request.

Returns:
ModelOutput: The model output.
"""

@abstractmethod
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.

Args:
request(ModelRequest): The model request.

Returns:
AsyncIterator[ModelOutput]: The model output stream.
"""

@abstractmethod
async def models(self) -> List[ModelMetadata]:
"""Get all the models.

Returns:
List[ModelMetadata]: A list of model metadata.
"""

@abstractmethod
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.

Args:
model(str): The model name.
prompt(str): The prompt.

Returns:
int: The number of tokens.
"""


class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
def __init__(self, model: str, **kwargs):
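To make the new contract concrete, here is a minimal sketch of a custom client; it is not part of this commit, and the `EchoLLMClient` name and its trivial logic are illustrative only:

from typing import AsyncIterator, List

from dbgpt.core.interface.llm import (
    LLMClient,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)


class EchoLLMClient(LLMClient):
    async def generate(self, request: ModelRequest) -> ModelOutput:
        # Echo the last message back as the "model" answer.
        last = request.to_openai_messages()[-1]["content"]
        return ModelOutput(text=last, error_code=0)

    async def generate_stream(
        self, request: ModelRequest
    ) -> AsyncIterator[ModelOutput]:
        # Emit the same answer as a one-chunk stream.
        yield await self.generate(request)

    async def models(self) -> List[ModelMetadata]:
        return [ModelMetadata(model="echo", context_length=4096)]

    async def count_token(self, model: str, prompt: str) -> int:
        # Crude whitespace count; a real client should use the model tokenizer.
        return len(prompt.split())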
@@ -176,85 +283,52 @@ class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
return ModelRequest._build(self._model, input_value)


class BaseLLMOperator(
MapOperator[ModelRequest, ModelOutput],
StreamifyAbsOperator[ModelRequest, ModelOutput],
ABC,
):
class BaseLLM:
"""The abstract operator for a LLM."""

def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client

@PublicAPI(stability="beta")
class OpenAILLM(BaseLLMOperator):
"""The operator for OpenAI LLM.
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
if not self._llm_client:
raise ValueError("llm_client is not set")
return self._llm_client

Examples:

.. code-block:: python
llm = OpenAILLM()
model_request = ModelRequest(model="gpt-3.5-turbo", messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")])
model_output = await llm.map(model_request)
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.

Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.

This operator will generate a no streaming response.
"""

def __int__(self):
try:
import openai
except ImportError as e:
raise ImportError("Please install openai package to use OpenAILLM") from e
import importlib.metadata as metadata
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)

if not metadata.version("openai") >= "1.0.0":
raise ImportError("Please upgrade openai package to version 1.0.0 or above")
async def map(self, request: ModelRequest) -> ModelOutput:
return await self.llm_client.generate(request)

async def _send_request(
self, model_request: ModelRequest, stream: Optional[bool] = False
):
import os
from openai import AsyncOpenAI

client = AsyncOpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
base_url=os.environ.get("OPENAI_API_BASE"),
)
messages = ModelMessage.to_openai_messages(model_request._get_messages())
payloads = {
"model": model_request.model,
"stream": stream,
}
if model_request.temperature is not None:
payloads["temperature"] = model_request.temperature
if model_request.max_new_tokens:
payloads["max_tokens"] = model_request.max_new_tokens
class StreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.

return await client.chat.completions.create(messages=messages, **payloads)
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.

async def map(self, model_request: ModelRequest) -> ModelOutput:
try:
chat_completion = await self._send_request(model_request, stream=False)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
This operator will generate streaming response.
"""

async def streamify(
self, model_request: ModelRequest
) -> AsyncIterator[ModelOutput]:
try:
chat_completion = await self._send_request(model_request, stream=True)
text = ""
for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)

async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
async for output in self.llm_client.generate_stream(request):
yield output
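For orientation, a minimal sketch of how the split is meant to be wired after this change; the DAG name is illustrative, and `OpenAILLMClient` is the implementation added later in this commit:

from dbgpt.core import LLMOperator, StreamingLLMOperator
from dbgpt.core.awel import DAG
from dbgpt.model import OpenAILLMClient

with DAG("illustrative_llm_dag"):
    # Non-streaming and streaming operators share the same client object;
    # the operators only forward requests, the client owns the transport.
    client = OpenAILLMClient()
    llm_task = LLMOperator(llm_client=client)
    streaming_llm_task = StreamingLLMOperator(llm_client=client)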
@@ -1,7 +1,6 @@
from __future__ import annotations
import sqlparse
import regex as re
import pandas as pd
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from typing import Any, Iterable, List, Optional, Dict
@@ -383,6 +382,10 @@ class RDBMSDatabase(BaseConnect):
return self.get_simple_fields(table_name)

def run_to_df(self, command: str, fetch: str = "all"):
import pandas as pd

# Pandas has too much dependence and the import time is too long
# TODO: Remove the dependency on pandas
result_lst = self.run(command, fetch)
colunms = result_lst[0]
values = result_lst[1:]
@@ -1,13 +1,8 @@
import re
import sqlparse
import clickhouse_connect
from typing import List, Optional, Any, Iterable, Dict
from sqlalchemy import text
from urllib.parse import quote
from sqlalchemy.schema import CreateTable
from urllib.parse import quote_plus as urlquote
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from clickhouse_connect.driver import httputil
from dbgpt.storage.schema import DBType
from sqlalchemy import (
MetaData,
@@ -56,6 +51,11 @@ class ClickhouseConnect(RDBMSDatabase):
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
import clickhouse_connect
from clickhouse_connect.driver import httputil

# Lazy import

big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12)
client = clickhouse_connect.get_client(
host=host,
@@ -37,7 +37,14 @@ class SQLiteConnect(RDBMSDatabase):
"""Get table indexes about specified table."""
cursor = self.session.execute(text(f"PRAGMA index_list({table_name})"))
indexes = cursor.fetchall()
return [(index[1], index[3]) for index in indexes]
result = []
for idx in indexes:
index_name = idx[1]
cursor = self.session.execute(text(f"PRAGMA index_info({index_name})"))
index_infos = cursor.fetchall()
column_names = [index_info[2] for index_info in index_infos]
result.append({"name": index_name, "column_names": column_names})
return result

def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
@@ -47,7 +47,8 @@ def test_run_no_throw(db):
def test_get_indexes(db):
db.run("CREATE TABLE test (name TEXT);")
db.run("CREATE INDEX idx_name ON test(name);")
assert db.get_indexes("test") == [("idx_name", "c")]
indexes = db.get_indexes("test")
assert indexes == [{"name": "idx_name", "column_names": ["name"]}]


def test_get_indexes_empty(db):
@@ -0,0 +1,4 @@
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient

__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"]
@@ -30,6 +30,15 @@ class EmbeddingsRequest(BaseModel):
span_id: str = None


class CountTokenRequest(BaseModel):
model: str
prompt: str


class ModelMetadataRequest(BaseModel):
model: str


class WorkerApplyRequest(BaseModel):
model: str
apply_type: WorkerApplyType
dbgpt/model/cluster/client.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from typing import AsyncIterator, List
import asyncio
from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata
from dbgpt.model.parameter import WorkerType
from dbgpt.model.cluster.manager_base import WorkerManager


class DefaultLLMClient(LLMClient):
def __init__(self, worker_manager: WorkerManager):
self._worker_manager = worker_manager

async def generate(self, request: ModelRequest) -> ModelOutput:
return await self._worker_manager.generate(request.to_dict())

async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
async for output in self._worker_manager.generate_stream(request.to_dict()):
yield output

async def models(self) -> List[ModelMetadata]:
instances = await self._worker_manager.get_all_model_instances(
WorkerType.LLM.value, healthy_only=True
)
query_metadata_task = []
for instance in instances:
worker_name, _ = WorkerType.parse_worker_key(instance.worker_key)
query_metadata_task.append(
self._worker_manager.get_model_metadata({"model": worker_name})
)
models: List[ModelMetadata] = await asyncio.gather(*query_metadata_task)
model_map = {}
for single_model in models:
model_map[single_model.model] = single_model
return [model_map[model_name] for model_name in sorted(model_map.keys())]

async def count_token(self, model: str, prompt: str) -> int:
return await self._worker_manager.count_token(
{"model": model, "prompt": prompt}
)
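A hypothetical usage sketch inside a running DB-GPT system (not part of this file); it builds the request with `ModelRequest._build`, the same internal helper `RequestBuildOperator` uses, and the model name is only an example:

from dbgpt.component import ComponentType, SystemApp
from dbgpt.core import ModelRequest
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory


async def ask(system_app: SystemApp, text: str):
    # Resolve the shared worker manager from the running system.
    worker_manager = system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    client = DefaultLLMClient(worker_manager)
    request = ModelRequest._build("vicuna-13b-v1.5", text)
    return await client.generate(request)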
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import WorkerSupportedModel, WorkerApplyOutput
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@@ -38,6 +38,11 @@ class WorkerRunData:
port = self.port
return f"model {model_name}@{model_type}({host}:{port})"

@property
def stopped(self):
"""Check if the worker is stopped"""
return self.stop_event.is_set()


class WorkerManager(ABC):
@abstractmethod
@@ -62,6 +67,20 @@ class WorkerManager(ABC):
) -> List[WorkerRunData]:
"""Asynchronous get model instances by worker type and model name"""

@abstractmethod
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Asynchronous get all model instances

Args:
worker_type (str): worker type
healthy_only (bool, optional): only return healthy instances. Defaults to True.

Returns:
List[WorkerRunData]: worker run data list
"""

@abstractmethod
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
@@ -112,6 +131,25 @@ class WorkerManager(ABC):
We must provide a synchronous version.
"""

@abstractmethod
async def count_token(self, params: Dict) -> int:
"""Count token of prompt

Args:
params (Dict): parameters, eg. {"prompt": "hello", "model": "vicuna-13b-v1.5"}

Returns:
int: token count
"""

@abstractmethod
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata

Args:
params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"}
"""

@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
"""Worker apply"""
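A small illustrative sketch of the two new abstract methods, using the parameter shapes from the docstrings above (assumes an existing `WorkerManager` instance and an async context):

async def inspect_model(worker_manager):
    # Count tokens and fetch metadata through the new WorkerManager APIs.
    tokens = await worker_manager.count_token(
        {"model": "vicuna-13b-v1.5", "prompt": "hello"}
    )
    meta = await worker_manager.get_model_metadata({"model": "vicuna-13b-v1.5"})
    return tokens, meta.context_length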
@@ -3,7 +3,7 @@ import pytest_asyncio
from contextlib import contextmanager, asynccontextmanager
from typing import List, Iterator, Dict, Tuple
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.cluster.worker.manager import (
WorkerManager,
@@ -80,6 +80,14 @@ class MockModelWorker(ModelWorker):
output = out
return output

def count_token(self, prompt: str) -> int:
return len(prompt)

def get_model_metadata(self, params: Dict) -> ModelMetadata:
return ModelMetadata(
model=self.model_parameters.model_name,
)

def embeddings(self, params: Dict) -> List[List[float]]:
return self._embeddings
@@ -8,7 +8,7 @@ import traceback
from dbgpt.configs.model_config import get_device
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.core import ModelOutput, ModelInferenceMetrics
from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker
@@ -118,6 +118,8 @@ class DefaultModelWorker(ModelWorker):
f"Parse model max length {model_max_length} from model {self.model_name}."
)
self.context_len = model_max_length
elif hasattr(model_params, "max_context_size"):
self.context_len = model_params.max_context_size

def stop(self) -> None:
if not self.model:
@@ -186,6 +188,22 @@ class DefaultModelWorker(ModelWorker):
output = out
return output

def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)

async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
raise NotImplementedError

def get_model_metadata(self, params: Dict) -> ModelMetadata:
return ModelMetadata(
model=self.model_name,
context_length=self.context_len,
)

async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
return self.get_model_metadata(params)

def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError
@@ -436,6 +454,25 @@ def _new_metrics_from_model_output(
return metrics


def _try_to_count_token(prompt: str, tokenizer) -> int:
"""Try to count token of prompt

Args:
prompt (str): prompt
tokenizer ([type]): tokenizer

Returns:
int: token count, if error return -1

TODO: More implementation
"""
try:
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")
return -1


def _try_import_torch():
global torch
global _torch_imported
@@ -2,6 +2,7 @@ import logging
from typing import Dict, List, Type, Optional

from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
from dbgpt.model.loader import _get_model_real_path
from dbgpt.model.parameter import (
EmbeddingModelParameters,
@@ -89,6 +90,14 @@ class EmbeddingsModelWorker(ModelWorker):
"""Generate non stream result"""
raise NotImplementedError("Not supported generate for embeddings model")

def count_token(self, prompt: str) -> int:
raise NotImplementedError("Not supported count_token for embeddings model")

def get_model_metadata(self, params: Dict) -> ModelMetadata:
raise NotImplementedError(
"Not supported get_model_metadata for embeddings model"
)

def embeddings(self, params: Dict) -> List[List[float]]:
model = params.get("model")
logger.info(f"Receive embeddings request, model: {model}")
@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import (
ModelInstance,
WorkerApplyOutput,
@@ -271,6 +271,18 @@ class LocalWorkerManager(WorkerManager):
) -> List[WorkerRunData]:
return self.sync_get_model_instances(worker_type, model_name, healthy_only)

async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances = list(itertools.chain(*self.workers.values()))
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.worker_key)
if wt != worker_type or (healthy_only and instance.stopped):
continue
result.append(instance)
return result

def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -390,6 +402,43 @@ class LocalWorkerManager(WorkerManager):
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
return worker_run_data.worker.embeddings(params)

async def count_token(self, params: Dict) -> int:
"""Count token of prompt"""
with root_tracer.start_span(
"WorkerManager.count_token", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
prompt = params.get("prompt")
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_count_token(prompt)
else:
return await self.run_blocking_func(
worker_run_data.worker.count_token, prompt
)

async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
with root_tracer.start_span(
"WorkerManager.get_model_metadata", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_get_model_metadata(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.get_model_metadata, params
)

async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
@@ -601,6 +650,13 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)

async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return await self.worker_manager.get_all_model_instances(
worker_type, healthy_only
)

def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -635,6 +691,12 @@ class WorkerManagerAdapter(WorkerManager):
def sync_embeddings(self, params: Dict) -> List[List[float]]:
return self.worker_manager.sync_embeddings(params)

async def count_token(self, params: Dict) -> int:
return await self.worker_manager.count_token(params)

async def get_model_metadata(self, params: Dict) -> ModelMetadata:
return await self.worker_manager.get_model_metadata(params)

async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
@@ -696,6 +758,24 @@ async def api_embeddings(request: EmbeddingsRequest):
return await worker_manager.embeddings(params)


@router.post("/worker/count_token")
async def api_count_token(request: CountTokenRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.count_token(params)


@router.post("/worker/model_metadata")
async def api_get_model_metadata(request: ModelMetadataRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.get_model_metadata(params)


@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
return await worker_manager.worker_apply(request)
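A hypothetical client-side call against the new endpoint; the host, port, and "/api" prefix are assumptions about how this router is mounted and are not taken from the diff:

import asyncio

import httpx


async def count_remote_tokens():
    async with httpx.AsyncClient() as client:
        # Matches the CountTokenRequest body: {"model": ..., "prompt": ...}
        resp = await client.post(
            "http://127.0.0.1:8000/api/worker/count_token",
            json={"model": "vicuna-13b-v1.5", "prompt": "hello"},
        )
        return resp.json()


print(asyncio.run(count_remote_tokens()))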
@@ -133,22 +133,29 @@ class RemoteWorkerManager(LocalWorkerManager):
self, model_name: str, instances: List[ModelInstance]
) -> List[WorkerRunData]:
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
wr = WorkerRunData(
host=ins.host,
port=ins.port,
worker_key=ins.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100),  # Not limit in client
for instance in instances:
worker_instances.append(
self._build_single_worker_instance(model_name, instance)
)
worker_instances.append(wr)
return worker_instances

def _build_single_worker_instance(self, model_name: str, instance: ModelInstance):
worker = RemoteModelWorker()
worker.load_worker(
model_name, model_name, host=instance.host, port=instance.port
)
wr = WorkerRunData(
host=instance.host,
port=instance.port,
worker_key=instance.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100),  # Not limit in client
)
return wr

async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -158,6 +165,20 @@ class RemoteWorkerManager(LocalWorkerManager):
)
return self._build_worker_instances(model_name, instances)

async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances: List[
ModelInstance
] = await self.model_registry.get_all_model_instances(healthy_only=healthy_only)
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.model_name)
if wt != worker_type:
continue
result.append(self._build_single_worker_instance(name, instance))
return result

def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -1,7 +1,7 @@
import json
from typing import Dict, Iterator, List
import logging
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker
@@ -90,6 +90,44 @@ class RemoteModelWorker(ModelWorker):
)
return ModelOutput(**response.json())

def count_token(self, prompt: str) -> int:
raise NotImplementedError

async def async_count_token(self, prompt: str) -> int:
import httpx

async with httpx.AsyncClient() as client:
url = self.worker_addr + "/count_token"
logger.debug(f"Send async_count_token to url {url}, params: {prompt}")
response = await client.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=self.timeout,
)
return response.json()

async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Asynchronously get model metadata"""
import httpx

async with httpx.AsyncClient() as client:
url = self.worker_addr + "/model_metadata"
logger.debug(
f"Send async_get_model_metadata to url {url}, params: {params}"
)
response = await client.post(
url,
headers=self.headers,
json=params,
timeout=self.timeout,
)
return ModelMetadata(**response.json())

def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
raise NotImplementedError

def embeddings(self, params: Dict) -> List[List[float]]:
"""Get embeddings for input"""
import requests
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Type

from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.parameter import ModelParameters, WorkerType
from dbgpt.util.parameter_utils import (
ParameterDescription,
@@ -92,6 +92,42 @@ class ModelWorker(ABC):
"""Asynchronously generate output (non-stream) based on provided parameters."""
raise NotImplementedError

@abstractmethod
def count_token(self, prompt: str) -> int:
"""Count token of prompt
Args:
prompt (str): prompt

Returns:
int: token count
"""

async def async_count_token(self, prompt: str) -> int:
"""Asynchronously count token of prompt
Args:
prompt (str): prompt

Returns:
int: token count
"""
raise NotImplementedError

@abstractmethod
def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata

Args:
params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"}
"""

async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Asynchronously get model metadata

Args:
params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"}
"""
raise NotImplementedError

@abstractmethod
def embeddings(self, params: Dict) -> List[List[float]]:
"""
@@ -70,7 +70,7 @@ def _initialize_openai_v1(params: ProxyModelParameters):
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

base_url = params.proxy_api_base or os.getenv(
"OPENAI_API_TYPE",
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.proxy_api_key or os.getenv(
dbgpt/model/utils/__init__.py (new file, 0 lines)
dbgpt/model/utils/chatgpt_utils.py (new file, 282 lines)
@@ -0,0 +1,282 @@
from __future__ import annotations

import os
import logging
from dataclasses import dataclass
import importlib.metadata as metadata
from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator

from dbgpt.core.interface.llm import ModelMetadata, LLMClient
from dbgpt.core.interface.llm import ModelOutput, ModelRequest

if TYPE_CHECKING:
import httpx
from httpx._types import ProxiesTypes
from openai import AsyncAzureOpenAI
from openai import AsyncOpenAI

ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

logger = logging.getLogger(__name__)


@dataclass
class OpenAIParameters:
"""A class to represent a LLM model."""

api_type: str = "open_ai"
api_base: Optional[str] = None
api_key: Optional[str] = None
api_version: Optional[str] = None
full_url: Optional[str] = None
proxies: Optional["ProxiesTypes"] = None


def _initialize_openai_v1(init_params: OpenAIParameters):
try:
from openai import OpenAI
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
) from exc

if not metadata.version("openai") >= "1.0.0":
raise ImportError("Please upgrade openai package to version 1.0.0 or above")

api_type: Optional[str] = init_params.api_type
api_base: Optional[str] = init_params.api_base
api_key: Optional[str] = init_params.api_key
api_version: Optional[str] = init_params.api_version
full_url: Optional[str] = init_params.full_url

api_type = api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

base_url = api_base or os.getenv(
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = api_version or os.getenv("OPENAI_API_VERSION")

if not base_url and full_url:
base_url = full_url.split("/chat/completions")[0]

if api_key is None:
raise ValueError("api_key is required, please set OPENAI_API_KEY environment")
if base_url is None:
raise ValueError("base_url is required, please set OPENAI_BASE_URL environment")
if base_url.endswith("/"):
base_url = base_url[:-1]

openai_params = {
"api_key": api_key,
"base_url": base_url,
}
return openai_params, api_type, api_version


def _build_openai_client(init_params: OpenAIParameters):
import httpx

openai_params, api_type, api_version = _initialize_openai_v1(init_params)
if api_type == "azure":
from openai import AsyncAzureOpenAI

return AsyncAzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params["base_url"],
http_client=httpx.AsyncClient(proxies=init_params.proxies),
)
else:
from openai import AsyncOpenAI

return AsyncOpenAI(
**openai_params, http_client=httpx.AsyncClient(proxies=init_params.proxies)
)


class OpenAILLMClient(LLMClient):
"""An implementation of LLMClient using OpenAI API.

In order to have as few dependencies as possible, we directly use the http API.
"""

def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = "gpt-3.5-turbo",
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "chatgpt_proxyllm",
context_length: Optional[int] = 8192,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
):
self._init_params = OpenAIParameters(
api_type=api_type,
api_base=api_base,
api_key=api_key,
api_version=api_version,
proxies=proxies,
)

self._model = model
self._proxies = proxies
self._timeout = timeout
self._model_alias = model_alias
self._context_length = context_length
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}

@property
def client(self) -> ClientType:
if self._client is None:
self._client = _build_openai_client(init_params=self._init_params)
return self._client

def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"model": request.model or self._model, "stream": stream}

# Apply openai kwargs
for k, v in self._openai_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
return payload

async def generate(self, request: ModelRequest) -> ModelOutput:
messages = request.to_openai_messages()
payload = self._build_request(request)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)

async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
messages = request.to_openai_messages()
payload = self._build_request(request)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = ""
for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)

async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]

async def get_context_length(self) -> int:
"""Get the context length of the model.

Returns:
int: The context length.
# TODO: This is a temporary solution. We should have a better way to get the context length.
eg. get real context length from the openai api.
"""
return self._context_length

async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.

TODO: Get the real number of tokens from the openai api or tiktoken package
"""
raise NotImplementedError()


async def _to_openai_stream(
model: str, output_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
"""Convert the output_iter to openai stream format.

Args:
model (str): The model name.
output_iter (AsyncIterator[ModelOutput]): The output iterator.
"""
import json
import shortuuid
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)

id = f"chatcmpl-{shortuuid.random()}"

choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

previous_text = ""
finish_stream_events = []
async for model_output in output_iter:
model_output: ModelOutput = model_output
if model_output.error_code != 0:
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
decoded_unicode = model_output.text.replace("\ufffd", "")
delta_text = decoded_unicode[len(previous_text) :]
previous_text = (
decoded_unicode
if len(decoded_unicode) > len(previous_text)
else previous_text
)

if len(delta_text) == 0:
delta_text = None
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=delta_text),
finish_reason=model_output.finish_reason,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
if delta_text is None:
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
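A brief usage sketch of the new client; it is illustrative only, assumes OPENAI_API_KEY (and optionally OPENAI_API_BASE) is set in the environment, and builds the request with the internal `ModelRequest._build` helper that the operators use:

import asyncio

from dbgpt.core import ModelRequest
from dbgpt.model import OpenAILLMClient


async def main():
    client = OpenAILLMClient(model="gpt-3.5-turbo")
    request = ModelRequest._build("gpt-3.5-turbo", "Say hello in one sentence.")
    output = await client.generate(request)
    print(output.text)


asyncio.run(main())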
@@ -1,5 +1,6 @@
import logging

import traceback
from dbgpt.component import SystemApp
from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
@@ -67,8 +68,9 @@ class DBSummaryClient:
try:
self.db_summary_embedding(item["db_name"], item["db_type"])
except Exception as e:
message = traceback.format_exc()
logger.warn(
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, detail: {message}'
)

def init_db_profile(self, db_summary_client, dbname, embeddings):
examples/awel/simple_llm_client_example.py (new file, 149 lines)
@@ -0,0 +1,149 @@
"""AWEL: Simple llm client example

DB-GPT will automatically load and execute the current file after startup.

Example:

.. code-block:: shell

DBGPT_SERVER="http://127.0.0.1:5000"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello"
}'

curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate_stream \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello",
"stream": true
}'

curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello"
}'

"""
from typing import Dict, Any, AsyncIterator, Optional, Union, List
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.component import ComponentType
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, TransformStreamAbsOperator
from dbgpt.core import (
ModelMessage,
LLMClient,
LLMOperator,
StreamingLLMOperator,
ModelOutput,
ModelRequest,
)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory


class TriggerReqBody(BaseModel):
messages: Union[str, List[Dict[str, str]]] = Field(
..., description="User input messages"
)
model: str = Field(..., description="Model name")
stream: Optional[bool] = Field(default=False, description="Whether return stream")


class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]):
def __init__(self, **kwargs):
super().__init__(**kwargs)

async def map(self, input_value: TriggerReqBody) -> ModelRequest:
messages = [ModelMessage.build_human_message(input_value.messages)]
await self.current_dag_context.save_to_share_data(
"request_model_name", input_value.model
)
return ModelRequest(
model=input_value.model,
messages=messages,
echo=False,
)


class LLMMixin:
@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self._llm_client = DefaultLLMClient(worker_manager)
return self._llm_client


class MyLLMOperator(LLMMixin, LLMOperator):
def __init__(self, llm_client: LLMClient = None, **kwargs):
super().__init__(llm_client, **kwargs)


class MyStreamingLLMOperator(LLMMixin, StreamingLLMOperator):
def __init__(self, llm_client: LLMClient = None, **kwargs):
super().__init__(llm_client, **kwargs)


class MyLLMStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
from dbgpt.model.utils.chatgpt_utils import _to_openai_stream

model = await self.current_dag_context.get_share_data("request_model_name")
async for output in _to_openai_stream(model, input_value):
yield output


class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]]):
def __init__(self, llm_client: LLMClient = None, **kwargs):
self._llm_client = llm_client
MapOperator.__init__(self, **kwargs)

async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
prompt_tokens = await self.llm_client.count_token(
input_value.model, input_value.messages
)
available_models = await self.llm_client.models()
return {
"prompt_tokens": prompt_tokens,
"available_models": available_models,
}


with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/generate", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
model_task = MyLLMOperator()
model_parse_task = MapOperator(lambda out: out.to_dict())
trigger >> request_handle_task >> model_task >> model_parse_task

with DAG("dbgpt_awel_simple_llm_client_generate_stream") as client_generate_stream_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/generate_stream",
methods="POST",
request_body=TriggerReqBody,
streaming_response=True,
)
request_handle_task = RequestHandleOperator()
model_task = MyStreamingLLMOperator()
openai_format_stream_task = MyLLMStreamingOperator()
trigger >> request_handle_task >> model_task >> openai_format_stream_task

with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/count_token",
methods="POST",
request_body=TriggerReqBody,
)
model_task = MyModelToolOperator()
trigger >> model_task
@@ -1,13 +1,19 @@
import asyncio
from dbgpt.core.awel import DAG
from dbgpt.core import BaseOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
from dbgpt.core import (
BaseOutputParser,
RequestBuildOperator,
PromptTemplate,
LLMOperator,
)
from dbgpt.model import OpenAILLMClient

with DAG("simple_sdk_llm_example_dag") as dag:
prompt_task = PromptTemplate.from_template(
"Write a SQL of {dialect} to query all data of {table_name}."
)
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
llm_task = OpenAILLM()
llm_task = LLMOperator(OpenAILLMClient())
out_parse_task = BaseOutputParser()
prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
@@ -8,10 +8,16 @@ from dbgpt.core.awel import (
JoinOperator,
MapOperator,
)
from dbgpt.core import SQLOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate
from dbgpt.core import (
SQLOutputParser,
LLMOperator,
RequestBuildOperator,
PromptTemplate,
)
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator
from dbgpt.model import OpenAILLMClient


def _create_temporary_connection():
@@ -115,7 +121,7 @@ with DAG("simple_sdk_llm_sql_example") as dag:
prompt_input_task = JoinOperator(combine_function=_join_func)
prompt_task = PromptTemplate.from_template(_sql_prompt())
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
llm_task = OpenAILLM()
llm_task = LLMOperator(OpenAILLMClient())
out_parse_task = SQLOutputParser()
sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
db_query_task = DatasourceOperator(connection=db_connection)
@@ -10,7 +10,7 @@ pytest-mock
pytest-recording
pytesseract==0.3.10
aioresponses
# python code format, usage `black .`
black
# for git hooks
pre-commit
pre-commit
# Type checking
mypy==0.991
requirements/lint-requirements.txt (new file, 11 lines)
@@ -0,0 +1,11 @@
# python code format, usage `black .`
black==22.8.0
blackdoc==0.3.7
flake8==5.0.4
flake8-bugbear==22.10.25
flake8-comprehensions==3.10.0
flake8-docstrings==1.6.0
flake8-simplify==0.19.3
flake8-tidy-imports==4.8.0
isort==5.10.1
pyupgrade==3.1.0
setup.py (43 lines changed)
@@ -364,23 +364,40 @@ def core_requires():
"prettytable",
"cachetools",
]

setup_spec.extras["framework"] = [
"coloredlogs",
# Just use by DB-GPT internal, we should find the smallest dependency set for run we core unit test.
# The dependency "framework" is too large for now.
setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [
"pydantic<2,>=1",
"httpx",
"fastapi==0.98.0",
"shortuuid",
# change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2
"SQLAlchemy>=1.4,<3",
# for cache
"msgpack",
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
"pympler",
"sqlparse==0.4.4",
"duckdb==0.8.1",
"duckdb-engine",
]
# TODO: remove fschat from simple_framework
if BUILD_FROM_SOURCE:
setup_spec.extras["simple_framework"].append(
f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
)
else:
setup_spec.extras["simple_framework"].append("fschat")

setup_spec.extras["framework"] = setup_spec.extras["simple_framework"] + [
"coloredlogs",
"seaborn",
# https://github.com/eosphoros-ai/DB-GPT/issues/551
"pandas==2.0.3",
"auto-gpt-plugin-template",
"gTTS==2.3.1",
"langchain>=0.0.286",
# change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2
"SQLAlchemy>=1.4,<3",
"fastapi==0.98.0",
"pymysql",
"duckdb==0.8.1",
"duckdb-engine",
"jsonschema",
# TODO move transformers to default
# "transformers>=4.31.0",
@@ -390,20 +407,10 @@ def core_requires():
"openpyxl==3.1.2",
"chardet==5.1.0",
"xlrd==2.0.1",
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
"pympler",
"aiofiles",
# for cache
"msgpack",
# for agent
"GitPython",
]
if BUILD_FROM_SOURCE:
setup_spec.extras["framework"].append(
f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
)
else:
setup_spec.extras["framework"].append("fschat")


def knowledge_requires():