feat(model): Add new LLMClient and new build tools (#967)

This commit is contained in:
Fangyin Cheng
2023-12-23 16:33:01 +08:00
committed by GitHub
parent 12234ae258
commit 0c46c339ca
30 changed files with 1072 additions and 133 deletions

View File

@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelOutput
from dbgpt.core import ModelOutput, ModelMetadata
from dbgpt.model.base import (
ModelInstance,
WorkerApplyOutput,
@@ -271,6 +271,18 @@ class LocalWorkerManager(WorkerManager):
) -> List[WorkerRunData]:
return self.sync_get_model_instances(worker_type, model_name, healthy_only)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
instances = list(itertools.chain(*self.workers.values()))
result = []
for instance in instances:
name, wt = WorkerType.parse_worker_key(instance.worker_key)
if wt != worker_type or (healthy_only and instance.stopped):
continue
result.append(instance)
return result
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -390,6 +402,43 @@ class LocalWorkerManager(WorkerManager):
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
return worker_run_data.worker.embeddings(params)
async def count_token(self, params: Dict) -> int:
"""Count token of prompt"""
with root_tracer.start_span(
"WorkerManager.count_token", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
prompt = params.get("prompt")
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_count_token(prompt)
else:
return await self.run_blocking_func(
worker_run_data.worker.count_token, prompt
)
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
"""Get model metadata"""
with root_tracer.start_span(
"WorkerManager.get_model_metadata", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_get_model_metadata(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.get_model_metadata, params
)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
if apply_req.apply_type == WorkerApplyType.START:
@@ -601,6 +650,13 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
async def get_all_model_instances(
self, worker_type: str, healthy_only: bool = True
) -> List[WorkerRunData]:
return await self.worker_manager.get_all_model_instances(
worker_type, healthy_only
)
def sync_get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -635,6 +691,12 @@ class WorkerManagerAdapter(WorkerManager):
def sync_embeddings(self, params: Dict) -> List[List[float]]:
return self.worker_manager.sync_embeddings(params)
async def count_token(self, params: Dict) -> int:
return await self.worker_manager.count_token(params)
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
return await self.worker_manager.get_model_metadata(params)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
return await self.worker_manager.worker_apply(apply_req)
@@ -696,6 +758,24 @@ async def api_embeddings(request: EmbeddingsRequest):
return await worker_manager.embeddings(params)
@router.post("/worker/count_token")
async def api_count_token(request: CountTokenRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.count_token(params)
@router.post("/worker/model_metadata")
async def api_get_model_metadata(request: ModelMetadataRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.get_model_metadata(params)
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
return await worker_manager.worker_apply(request)