Native data AI application framework based on AWEL+AGENT (#1152)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com>
Co-authored-by: licunxing <864255598@qq.com>
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: xuyuan23 <643854343@qq.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: hzh97 <2976151305@qq.com>
This commit is contained in:
明天
2024-02-07 17:43:27 +08:00
committed by GitHub
parent dbb9ac83b1
commit d5afa6e206
328 changed files with 22606 additions and 3282 deletions

View File

@@ -1,6 +1,8 @@
import asyncio
from typing import AsyncIterator, List, Optional
from dbgpt.core.awel import DAGVar
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.core.interface.llm import (
DefaultMessageConverter,
LLMClient,
@@ -13,6 +15,23 @@ from dbgpt.model.cluster.manager_base import WorkerManager
from dbgpt.model.parameter import WorkerType
@register_resource(
label="Default LLM Client",
name="default_llm_client",
category=ResourceCategory.LLM_CLIENT,
description="Default LLM client(Connect to your DB-GPT model serving)",
parameters=[
Parameter.build_from(
"Auto Convert Message",
name="auto_convert_message",
type=bool,
optional=True,
default=False,
description="Whether to auto convert the messages that are not supported "
"by the LLM to a compatible format",
)
],
)
class DefaultLLMClient(LLMClient):
"""Default LLM client implementation.
@@ -24,11 +43,28 @@ class DefaultLLMClient(LLMClient):
"""
def __init__(
    self,
    worker_manager: Optional[WorkerManager] = None,
    auto_convert_message: bool = False,
):
    """Create a new DefaultLLMClient.

    Args:
        worker_manager (Optional[WorkerManager]): Worker manager used to route
            model requests. If None, one is resolved lazily from the current
            system app via the ``worker_manager`` property.
        auto_convert_message (bool): Whether to auto convert messages that are
            not supported by the LLM to a compatible format.
    """
    self._worker_manager = worker_manager
    # NOTE: "covert" (sic) matches the attribute name read by the other
    # methods of this class; keep the existing spelling for consistency.
    self._auto_covert_message = auto_convert_message
@property
def worker_manager(self) -> WorkerManager:
    """Return the worker manager to use for model requests.

    Prefers the instance supplied at construction time; otherwise resolves
    one from the system app attached to the current DAG context.

    Raises:
        ValueError: If no worker manager was supplied and no system app is
            available to resolve one from.
    """
    if self._worker_manager:
        return self._worker_manager
    current_app = DAGVar.get_current_system_app()
    if not current_app:
        raise ValueError("System app is not initialized")
    # Imported here to avoid a module-level import cycle with the cluster package.
    from dbgpt.model.cluster import WorkerManagerFactory

    factory = WorkerManagerFactory.get_instance(current_app)
    return factory.create()
async def generate(
self,
request: ModelRequest,
@@ -37,7 +73,7 @@ class DefaultLLMClient(LLMClient):
if not message_converter and self._auto_covert_message:
message_converter = DefaultMessageConverter()
request = await self.covert_message(request, message_converter)
return await self._worker_manager.generate(request.to_dict())
return await self.worker_manager.generate(request.to_dict())
async def generate_stream(
self,
@@ -47,18 +83,18 @@ class DefaultLLMClient(LLMClient):
if not message_converter and self._auto_covert_message:
message_converter = DefaultMessageConverter()
request = await self.covert_message(request, message_converter)
async for output in self._worker_manager.generate_stream(request.to_dict()):
async for output in self.worker_manager.generate_stream(request.to_dict()):
yield output
async def models(self) -> List[ModelMetadata]:
instances = await self._worker_manager.get_all_model_instances(
instances = await self.worker_manager.get_all_model_instances(
WorkerType.LLM.value, healthy_only=True
)
query_metadata_task = []
for instance in instances:
worker_name, _ = WorkerType.parse_worker_key(instance.worker_key)
query_metadata_task.append(
self._worker_manager.get_model_metadata({"model": worker_name})
self.worker_manager.get_model_metadata({"model": worker_name})
)
models: List[ModelMetadata] = await asyncio.gather(*query_metadata_task)
model_map = {}
@@ -67,6 +103,4 @@ class DefaultLLMClient(LLMClient):
return [model_map[model_name] for model_name in sorted(model_map.keys())]
async def count_token(self, model: str, prompt: str) -> int:
    """Count the number of tokens *prompt* occupies for *model*.

    Args:
        model (str): Name of the model to count tokens for.
        prompt (str): The prompt text to measure.

    Returns:
        int: The token count reported by the worker manager.
    """
    # Go through the ``worker_manager`` property (not ``self._worker_manager``)
    # so a lazily-resolved manager is used when none was passed to __init__.
    return await self.worker_manager.count_token({"model": model, "prompt": prompt})