mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
feat(model): Add new LLMClient and new build tools (#967)
This commit is contained in:
@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core import ModelOutput, ModelMetadata
|
||||
from dbgpt.model.base import (
|
||||
ModelInstance,
|
||||
WorkerApplyOutput,
|
||||
@@ -271,6 +271,18 @@ class LocalWorkerManager(WorkerManager):
|
||||
) -> List[WorkerRunData]:
|
||||
return self.sync_get_model_instances(worker_type, model_name, healthy_only)
|
||||
|
||||
async def get_all_model_instances(
|
||||
self, worker_type: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
instances = list(itertools.chain(*self.workers.values()))
|
||||
result = []
|
||||
for instance in instances:
|
||||
name, wt = WorkerType.parse_worker_key(instance.worker_key)
|
||||
if wt != worker_type or (healthy_only and instance.stopped):
|
||||
continue
|
||||
result.append(instance)
|
||||
return result
|
||||
|
||||
def sync_get_model_instances(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
@@ -390,6 +402,43 @@ class LocalWorkerManager(WorkerManager):
|
||||
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
|
||||
return worker_run_data.worker.embeddings(params)
|
||||
|
||||
async def count_token(self, params: Dict) -> int:
|
||||
"""Count token of prompt"""
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager.count_token", params.get("span_id")
|
||||
) as span:
|
||||
params["span_id"] = span.span_id
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
raise e
|
||||
prompt = params.get("prompt")
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_count_token(prompt)
|
||||
else:
|
||||
return await self.run_blocking_func(
|
||||
worker_run_data.worker.count_token, prompt
|
||||
)
|
||||
|
||||
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||
"""Get model metadata"""
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager.get_model_metadata", params.get("span_id")
|
||||
) as span:
|
||||
params["span_id"] = span.span_id
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
raise e
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_get_model_metadata(params)
|
||||
else:
|
||||
return await self.run_blocking_func(
|
||||
worker_run_data.worker.get_model_metadata, params
|
||||
)
|
||||
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None
|
||||
if apply_req.apply_type == WorkerApplyType.START:
|
||||
@@ -601,6 +650,13 @@ class WorkerManagerAdapter(WorkerManager):
|
||||
worker_type, model_name, healthy_only
|
||||
)
|
||||
|
||||
async def get_all_model_instances(
|
||||
self, worker_type: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
return await self.worker_manager.get_all_model_instances(
|
||||
worker_type, healthy_only
|
||||
)
|
||||
|
||||
def sync_get_model_instances(
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
@@ -635,6 +691,12 @@ class WorkerManagerAdapter(WorkerManager):
|
||||
def sync_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
return self.worker_manager.sync_embeddings(params)
|
||||
|
||||
async def count_token(self, params: Dict) -> int:
|
||||
return await self.worker_manager.count_token(params)
|
||||
|
||||
async def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||
return await self.worker_manager.get_model_metadata(params)
|
||||
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
return await self.worker_manager.worker_apply(apply_req)
|
||||
|
||||
@@ -696,6 +758,24 @@ async def api_embeddings(request: EmbeddingsRequest):
|
||||
return await worker_manager.embeddings(params)
|
||||
|
||||
|
||||
@router.post("/worker/count_token")
|
||||
async def api_count_token(request: CountTokenRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
span_id = root_tracer.get_current_span_id()
|
||||
if "span_id" not in params and span_id:
|
||||
params["span_id"] = span_id
|
||||
return await worker_manager.count_token(params)
|
||||
|
||||
|
||||
@router.post("/worker/model_metadata")
|
||||
async def api_get_model_metadata(request: ModelMetadataRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
span_id = root_tracer.get_current_span_id()
|
||||
if "span_id" not in params and span_id:
|
||||
params["span_id"] = span_id
|
||||
return await worker_manager.get_model_metadata(params)
|
||||
|
||||
|
||||
@router.post("/worker/apply")
|
||||
async def api_worker_apply(request: WorkerApplyRequest):
|
||||
return await worker_manager.worker_apply(request)
|
||||
|
Reference in New Issue
Block a user