import asyncio
from typing import AsyncIterator, List, Optional

from dbgpt.core.awel import DAGVar
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.core.interface.llm import (
    DefaultMessageConverter,
    LLMClient,
    MessageConverter,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)
from dbgpt.model.cluster.manager_base import WorkerManager
from dbgpt.model.parameter import WorkerType
from dbgpt.util.i18n_utils import _


@register_resource(
    label=_("Default LLM Client"),
    name="default_llm_client",
    category=ResourceCategory.LLM_CLIENT,
    description=_("Default LLM client (connect to your DB-GPT model serving)"),
    parameters=[
        Parameter.build_from(
            _("Auto Convert Message"),
            name="auto_convert_message",
            type=bool,
            optional=True,
            default=True,
            description=_(
                "Whether to automatically convert messages that are not supported "
                "by the LLM to a compatible format"
            ),
        )
    ],
)
class DefaultLLMClient(LLMClient):
    """Default LLM client implementation.

    Connect to the worker manager and forward requests to it.

    Args:
        worker_manager (WorkerManager): worker manager instance.
        auto_convert_message (bool, optional): whether to automatically convert
            unsupported messages to a format the model accepts. Defaults to True.
    """

    def __init__(
        self,
        worker_manager: Optional[WorkerManager] = None,
        auto_convert_message: bool = True,
    ):
        self._worker_manager = worker_manager
        self._auto_convert_message = auto_convert_message

    @property
    def worker_manager(self) -> WorkerManager:
        """Get the worker manager instance.

        If no worker manager was passed to the constructor, resolve one from the
        current system app; raise ValueError if the system app is not initialized.
        """
        if not self._worker_manager:
            system_app = DAGVar.get_current_system_app()
            if not system_app:
                raise ValueError("System app is not initialized")
            from dbgpt.model.cluster import WorkerManagerFactory

            return WorkerManagerFactory.get_instance(system_app).create()
        return self._worker_manager

    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        if not message_converter and self._auto_convert_message:
            message_converter = DefaultMessageConverter()
        # `covert_message` is the conversion hook defined on the LLMClient base class.
        request = await self.covert_message(request, message_converter)
        return await self.worker_manager.generate(request.to_dict())

    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        if not message_converter and self._auto_convert_message:
            message_converter = DefaultMessageConverter()
        request = await self.covert_message(request, message_converter)
        async for output in self.worker_manager.generate_stream(request.to_dict()):
            yield output

    async def models(self) -> List[ModelMetadata]:
        instances = await self.worker_manager.get_all_model_instances(
            WorkerType.LLM.value, healthy_only=True
        )
        # Query the metadata of all healthy LLM workers concurrently.
        query_metadata_task = []
        for instance in instances:
            worker_name, _ = WorkerType.parse_worker_key(instance.worker_key)
            query_metadata_task.append(
                self.worker_manager.get_model_metadata({"model": worker_name})
            )
        models: List[ModelMetadata] = await asyncio.gather(*query_metadata_task)
        # Deduplicate by model name and return the metadata sorted by name.
        model_map = {}
        for single_model in models:
            model_map[single_model.model] = single_model
        return [model_map[model_name] for model_name in sorted(model_map.keys())]

    async def count_token(self, model: str, prompt: str) -> int:
        return await self.worker_manager.count_token({"model": model, "prompt": prompt})

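
# A minimal usage sketch (editorial addition, not part of the original module).
# It assumes a DB-GPT model cluster is already running in-process so that the
# `worker_manager` property can resolve a WorkerManager from the system app,
# and that a model named "vicuna-13b-v1.5" is deployed; both are assumptions,
# and the exact ModelRequest construction may differ between DB-GPT versions.
#
#   import asyncio
#   from dbgpt.core.interface.llm import ModelRequest
#
#   async def main():
#       client = DefaultLLMClient(auto_convert_message=True)
#       request = ModelRequest.build_request(
#           "vicuna-13b-v1.5", messages=[...]  # fill in your messages here
#       )
#       output = await client.generate(request)
#       print(output.text)
#       # Streaming works the same way via client.generate_stream(request).
#
#   asyncio.run(main())
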
@register_resource(
    label=_("Remote LLM Client"),
    name="remote_llm_client",
    category=ResourceCategory.LLM_CLIENT,
    description=_("Remote LLM client (connect to a remote DB-GPT model serving)"),
    parameters=[
        Parameter.build_from(
            _("Controller Address"),
            name="controller_address",
            type=str,
            optional=True,
            default="http://127.0.0.1:8000",
            description=_("Model controller address"),
        ),
        Parameter.build_from(
            _("Auto Convert Message"),
            name="auto_convert_message",
            type=bool,
            optional=True,
            default=True,
            description=_(
                "Whether to automatically convert messages that are not supported "
                "by the LLM to a compatible format"
            ),
        ),
    ],
)
class RemoteLLMClient(DefaultLLMClient):
    """Remote LLM client implementation.

    Connect to a remote worker manager and forward requests to it.

    Args:
        controller_address (str): model controller address.
        auto_convert_message (bool, optional): whether to automatically convert
            unsupported messages to a format the model accepts. Defaults to True.

    If you start a DB-GPT model cluster, the controller address is the address
    of the Model Controller (`dbgpt start controller`; the default port of the
    model controller is 8000).
    Otherwise, if you already have a running DB-GPT webserver (started with
    `dbgpt start webserver --port ${remote_port}`), you can use
    `http://${remote_ip}:${remote_port}` as the controller address.
    """

    def __init__(
        self,
        controller_address: str = "http://127.0.0.1:8000",
        auto_convert_message: bool = True,
    ):
        """Initialize the RemoteLLMClient."""
        from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager

        # Build a remote worker manager backed by the model registry at the
        # given controller address, then reuse DefaultLLMClient's logic.
        model_registry_client = ModelRegistryClient(controller_address)
        worker_manager = RemoteWorkerManager(model_registry_client)
        super().__init__(worker_manager, auto_convert_message)
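
# A minimal, self-contained usage sketch (editorial addition, not part of the
# original module). It assumes a DB-GPT Model Controller is reachable at the
# given address (`dbgpt start controller`, default port 8000); the address is
# a placeholder for your own deployment. It only uses methods defined above.
if __name__ == "__main__":

    async def _demo():
        client = RemoteLLMClient(controller_address="http://127.0.0.1:8000")
        # List the healthy deployed models, then count tokens with the first one.
        models = await client.models()
        print([m.model for m in models])
        if models:
            num_tokens = await client.count_token(models[0].model, "Hello, DB-GPT")
            print(num_tokens)

    asyncio.run(_demo())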