mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 19:40:13 +00:00
feat(core): Support opentelemetry exporter (#1690)
This commit is contained in:
@@ -45,6 +45,7 @@ from dbgpt.model.cluster.registry import ModelRegistry
|
||||
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
|
||||
from dbgpt.util.fastapi import create_app
|
||||
from dbgpt.util.parameter_utils import EnvArgumentParser
|
||||
from dbgpt.util.tracer import initialize_tracer, root_tracer
|
||||
from dbgpt.util.utils import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -353,23 +354,34 @@ class APIServer(BaseComponent):
|
||||
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
|
||||
|
||||
async def embeddings_generate(
    self,
    model: str,
    texts: List[str],
    span_id: Optional[str] = None,
) -> List[List[float]]:
    """Generate embeddings for a batch of texts inside a tracing span.

    The worker call is wrapped in a span named
    ``dbgpt.model.apiserver.generate_embeddings`` whose parent is ``span_id``
    and whose metadata records the model name.

    Args:
        model (str): Model name
        texts (List[str]): Texts to embed
        span_id (Optional[str], optional): The span id used as the parent of
            the span opened here. Defaults to None.

    Returns:
        List[List[float]]: The embeddings of texts
    """
    # The parent span id is passed in explicitly; the embeddings endpoint
    # captures the current span id when it creates each batch call and
    # hands it over through ``span_id``.
    with root_tracer.start_span(
        "dbgpt.model.apiserver.generate_embeddings",
        parent_span_id=span_id,
        metadata={
            "model": model,
        },
    ):
        worker_manager: WorkerManager = self.get_worker_manager()
        params = {
            "input": texts,
            "model": model,
        }
        return await worker_manager.embeddings(params)
|
||||
|
||||
async def relevance_generate(
|
||||
self, model: str, query: str, texts: List[str]
|
||||
@@ -438,12 +450,29 @@ async def create_chat_completion(
|
||||
params["user"] = request.user
|
||||
|
||||
# TODO check token length
|
||||
trace_kwargs = {
|
||||
"operation_name": "dbgpt.model.apiserver.create_chat_completion",
|
||||
"metadata": {
|
||||
"model": request.model,
|
||||
"messages": request.messages,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"max_tokens": request.max_tokens,
|
||||
"stop": request.stop,
|
||||
"user": request.user,
|
||||
},
|
||||
}
|
||||
if request.stream:
|
||||
generator = api_server.chat_completion_stream_generator(
|
||||
request.model, params, request.n
|
||||
)
|
||||
return StreamingResponse(generator, media_type="text/event-stream")
|
||||
return await api_server.chat_completion_generate(request.model, params, request.n)
|
||||
trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
|
||||
return StreamingResponse(trace_generator, media_type="text/event-stream")
|
||||
else:
|
||||
with root_tracer.start_span(**trace_kwargs):
|
||||
return await api_server.chat_completion_generate(
|
||||
request.model, params, request.n
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
|
||||
@@ -462,7 +491,11 @@ async def create_embeddings(
|
||||
data = []
|
||||
async_tasks = []
|
||||
for num_batch, batch in enumerate(batches):
|
||||
async_tasks.append(api_server.embeddings_generate(request.model, batch))
|
||||
async_tasks.append(
|
||||
api_server.embeddings_generate(
|
||||
request.model, batch, span_id=root_tracer.get_current_span_id()
|
||||
)
|
||||
)
|
||||
|
||||
# Request all embeddings in parallel
|
||||
batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
|
||||
@@ -486,15 +519,22 @@ async def create_embeddings(
|
||||
dependencies=[Depends(check_api_key)],
|
||||
response_model=RelevanceResponse,
|
||||
)
|
||||
async def create_embeddings(
|
||||
async def create_relevance(
|
||||
request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
|
||||
):
|
||||
"""Generate relevance scores for a query and a list of documents."""
|
||||
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
|
||||
|
||||
scores = await api_server.relevance_generate(
|
||||
request.model, request.query, request.documents
|
||||
)
|
||||
with root_tracer.start_span(
|
||||
"dbgpt.model.apiserver.generate_relevance",
|
||||
metadata={
|
||||
"model": request.model,
|
||||
"query": request.query,
|
||||
},
|
||||
):
|
||||
scores = await api_server.relevance_generate(
|
||||
request.model, request.query, request.documents
|
||||
)
|
||||
return model_to_dict(
|
||||
RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
|
||||
exclude_none=True,
|
||||
@@ -534,6 +574,7 @@ def _initialize_all(controller_addr: str, system_app: SystemApp):
|
||||
|
||||
def initialize_apiserver(
|
||||
controller_addr: str,
|
||||
apiserver_params: Optional[ModelAPIServerParameters] = None,
|
||||
app=None,
|
||||
system_app: SystemApp = None,
|
||||
host: str = None,
|
||||
@@ -541,6 +582,10 @@ def initialize_apiserver(
|
||||
api_keys: List[str] = None,
|
||||
embedding_batch_size: Optional[int] = None,
|
||||
):
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
|
||||
global global_system_app
|
||||
global api_settings
|
||||
embedded_mod = True
|
||||
@@ -552,6 +597,18 @@ def initialize_apiserver(
|
||||
system_app = SystemApp(app)
|
||||
global_system_app = system_app
|
||||
|
||||
if apiserver_params:
|
||||
initialize_tracer(
|
||||
os.path.join(LOGDIR, apiserver_params.tracer_file),
|
||||
system_app=system_app,
|
||||
root_operation_name="DB-GPT-APIServer",
|
||||
tracer_storage_cls=apiserver_params.tracer_storage_cls,
|
||||
enable_open_telemetry=apiserver_params.tracer_to_open_telemetry,
|
||||
otlp_endpoint=apiserver_params.otel_exporter_otlp_traces_endpoint,
|
||||
otlp_insecure=apiserver_params.otel_exporter_otlp_traces_insecure,
|
||||
otlp_timeout=apiserver_params.otel_exporter_otlp_traces_timeout,
|
||||
)
|
||||
|
||||
if api_keys:
|
||||
api_settings.api_keys = api_keys
|
||||
|
||||
@@ -602,6 +659,7 @@ def run_apiserver():
|
||||
|
||||
initialize_apiserver(
|
||||
apiserver_params.controller_addr,
|
||||
apiserver_params,
|
||||
host=apiserver_params.host,
|
||||
port=apiserver_params.port,
|
||||
api_keys=api_keys,
|
||||
|
Reference in New Issue
Block a user