feat(core): Support opentelemetry exporter (#1690)

Fangyin Cheng authored 2024-07-05 15:20:21 +08:00, committed by GitHub
parent 84fc1fc7fe
commit bf978d2bf9
39 changed files with 1176 additions and 218 deletions

View File

@@ -33,6 +33,10 @@ class ModelInstance:
     prompt_template: Optional[str] = None
     last_heartbeat: Optional[datetime] = None
 
+    def to_dict(self) -> Dict:
+        """Convert to dict"""
+        return asdict(self)
+
 
 class WorkerApplyType(str, Enum):
     START = "start"
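The new `to_dict` helper exists mainly so callers can hand a `ModelInstance` to the tracer as flat span metadata (see the controller endpoints further down). A minimal sketch of the pattern, using a hypothetical stand-in dataclass:

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class Instance:  # hypothetical stand-in for ModelInstance
    model_name: str
    host: str
    port: int
    weight: Optional[float] = None


# asdict() converts the dataclass to a plain dict, which is safe to
# attach as span metadata or serialize to JSON.
print(asdict(Instance("vicuna-13b", "127.0.0.1", 8001)))
# -> {'model_name': 'vicuna-13b', 'host': '127.0.0.1', 'port': 8001, 'weight': None}
```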

View File

@@ -45,6 +45,7 @@ from dbgpt.model.cluster.registry import ModelRegistry
 from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
 from dbgpt.util.fastapi import create_app
 from dbgpt.util.parameter_utils import EnvArgumentParser
+from dbgpt.util.tracer import initialize_tracer, root_tracer
 from dbgpt.util.utils import setup_logging
 
 logger = logging.getLogger(__name__)
@@ -353,23 +354,34 @@ class APIServer(BaseComponent):
         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
 
     async def embeddings_generate(
-        self, model: str, texts: List[str]
+        self,
+        model: str,
+        texts: List[str],
+        span_id: Optional[str] = None,
     ) -> List[List[float]]:
         """Generate embeddings
 
         Args:
             model (str): Model name
             texts (List[str]): Texts to embed
+            span_id (Optional[str], optional): The span id. Defaults to None.
 
         Returns:
             List[List[float]]: The embeddings of texts
         """
-        worker_manager: WorkerManager = self.get_worker_manager()
-        params = {
-            "input": texts,
-            "model": model,
-        }
-        return await worker_manager.embeddings(params)
+        with root_tracer.start_span(
+            "dbgpt.model.apiserver.generate_embeddings",
+            parent_span_id=span_id,
+            metadata={
+                "model": model,
+            },
+        ):
+            worker_manager: WorkerManager = self.get_worker_manager()
+            params = {
+                "input": texts,
+                "model": model,
+            }
+            return await worker_manager.embeddings(params)
 
     async def relevance_generate(
         self, model: str, query: str, texts: List[str]
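Note how the parent span id is threaded through explicitly rather than read from ambient context, presumably because the embedding batches are scheduled as separate asyncio tasks (see `create_embeddings` below) and explicit parenting is deterministic. A sketch of the call pattern, using only the APIs visible in this diff:

```python
async def embed_with_trace(api_server):
    # Parent the worker-level embeddings span on the current request's span.
    parent_id = root_tracer.get_current_span_id()
    return await api_server.embeddings_generate(
        "text2vec", ["hello", "world"], span_id=parent_id
    )
```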
@@ -438,12 +450,29 @@ async def create_chat_completion(
         params["user"] = request.user
 
     # TODO check token length
+    trace_kwargs = {
+        "operation_name": "dbgpt.model.apiserver.create_chat_completion",
+        "metadata": {
+            "model": request.model,
+            "messages": request.messages,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "max_tokens": request.max_tokens,
+            "stop": request.stop,
+            "user": request.user,
+        },
+    }
     if request.stream:
         generator = api_server.chat_completion_stream_generator(
             request.model, params, request.n
         )
-        return StreamingResponse(generator, media_type="text/event-stream")
-    return await api_server.chat_completion_generate(request.model, params, request.n)
+        trace_generator = root_tracer.wrapper_async_stream(generator, **trace_kwargs)
+        return StreamingResponse(trace_generator, media_type="text/event-stream")
+    else:
+        with root_tracer.start_span(**trace_kwargs):
+            return await api_server.chat_completion_generate(
+                request.model, params, request.n
+            )
 
 
 @router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
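For streaming responses a plain `with` block would close the span as soon as the handler returns, long before the client finishes reading the stream, so the generator itself has to carry the span. A minimal sketch of what a wrapper like `root_tracer.wrapper_async_stream` might look like (hypothetical; the real implementation lives in `dbgpt.util.tracer`):

```python
from typing import Any, AsyncIterator


async def wrap_async_stream(
    stream: AsyncIterator[Any], operation_name: str, **trace_kwargs
) -> AsyncIterator[Any]:
    # Open the span lazily, when iteration starts, and keep it open until
    # the stream is exhausted or the client disconnects.
    with root_tracer.start_span(operation_name, **trace_kwargs):
        async for chunk in stream:
            yield chunk
```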
@@ -462,7 +491,11 @@ async def create_embeddings(
     data = []
     async_tasks = []
     for num_batch, batch in enumerate(batches):
-        async_tasks.append(api_server.embeddings_generate(request.model, batch))
+        async_tasks.append(
+            api_server.embeddings_generate(
+                request.model, batch, span_id=root_tracer.get_current_span_id()
+            )
+        )
 
     # Request all embeddings in parallel
     batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
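The batching path fans each batch out as its own coroutine and then reassembles the results; `asyncio.gather` preserves argument order, so the flattened embeddings still line up with the input texts. Roughly (the batch size is illustrative):

```python
import asyncio


async def embed_all(api_server, model: str, texts: list, batch_size: int = 4):
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    tasks = [
        api_server.embeddings_generate(
            model, batch, span_id=root_tracer.get_current_span_id()
        )
        for batch in batches
    ]
    # gather() returns results in the same order the tasks were passed in.
    per_batch = await asyncio.gather(*tasks)
    return [emb for batch in per_batch for emb in batch]
```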
@@ -486,15 +519,22 @@ async def create_embeddings(
     dependencies=[Depends(check_api_key)],
     response_model=RelevanceResponse,
 )
-async def create_embeddings(
+async def create_relevance(
     request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
 ):
     """Generate relevance scores for a query and a list of documents."""
     await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
 
-    scores = await api_server.relevance_generate(
-        request.model, request.query, request.documents
-    )
+    with root_tracer.start_span(
+        "dbgpt.model.apiserver.generate_relevance",
+        metadata={
+            "model": request.model,
+            "query": request.query,
+        },
+    ):
+        scores = await api_server.relevance_generate(
+            request.model, request.query, request.documents
+        )
     return model_to_dict(
         RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
         exclude_none=True,
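The rename from `create_embeddings` to `create_relevance` is a real fix, not cosmetics: a second `def` with the same name rebinds the module attribute, so the relevance handler was silently shadowing the embeddings handler at module level (FastAPI had already registered both routes at decoration time, but the duplicate name is fragile and confusing). The shadowing behavior in isolation:

```python
def handler():
    return "embeddings"


def handler():  # rebinds the name; the first function is no longer reachable
    return "relevance"


print(handler())  # -> relevance
```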
@@ -534,6 +574,7 @@ def _initialize_all(controller_addr: str, system_app: SystemApp):
 def initialize_apiserver(
     controller_addr: str,
     apiserver_params: Optional[ModelAPIServerParameters] = None,
+    app=None,
     system_app: SystemApp = None,
     host: str = None,
@@ -541,6 +582,10 @@ def initialize_apiserver(
     api_keys: List[str] = None,
     embedding_batch_size: Optional[int] = None,
 ):
+    import os
+
+    from dbgpt.configs.model_config import LOGDIR
+
     global global_system_app
     global api_settings
     embedded_mod = True
@@ -552,6 +597,18 @@ def initialize_apiserver(
         system_app = SystemApp(app)
     global_system_app = system_app
 
+    if apiserver_params:
+        initialize_tracer(
+            os.path.join(LOGDIR, apiserver_params.tracer_file),
+            system_app=system_app,
+            root_operation_name="DB-GPT-APIServer",
+            tracer_storage_cls=apiserver_params.tracer_storage_cls,
+            enable_open_telemetry=apiserver_params.tracer_to_open_telemetry,
+            otlp_endpoint=apiserver_params.otel_exporter_otlp_traces_endpoint,
+            otlp_insecure=apiserver_params.otel_exporter_otlp_traces_insecure,
+            otlp_timeout=apiserver_params.otel_exporter_otlp_traces_timeout,
+        )
+
     if api_keys:
         api_settings.api_keys = api_keys
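`initialize_tracer` now takes the OTLP settings alongside the existing file-based tracer storage. When `tracer_to_open_telemetry` is enabled, the exporter setup presumably boils down to the standard OpenTelemetry SDK wiring, roughly like this (a sketch with illustrative values, not DB-GPT's internal code):

```python
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider(
    resource=Resource.create({"service.name": "DB-GPT-APIServer"})
)
exporter = OTLPSpanExporter(
    endpoint="http://localhost:4317",  # otel_exporter_otlp_traces_endpoint
    insecure=True,                     # otel_exporter_otlp_traces_insecure
    timeout=10,                        # otel_exporter_otlp_traces_timeout, in seconds
)
# BatchSpanProcessor buffers spans and exports them in the background.
provider.add_span_processor(BatchSpanProcessor(exporter))
trace.set_tracer_provider(provider)
```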
@@ -602,6 +659,7 @@ def run_apiserver():
     initialize_apiserver(
         apiserver_params.controller_addr,
         apiserver_params,
+        host=apiserver_params.host,
         port=apiserver_params.port,
         api_keys=api_keys,

View File

@@ -52,7 +52,7 @@ async def client(request, system_app: SystemApp):
     worker_manager, model_registry = cluster
     system_app.register(_DefaultWorkerManagerFactory, worker_manager)
     system_app.register_instance(model_registry)
-    initialize_apiserver(None, app, system_app, api_keys=api_keys)
+    initialize_apiserver(None, None, app, system_app, api_keys=api_keys)
     yield client

View File

@@ -1,4 +1,5 @@
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import List, Literal, Optional
@@ -13,6 +14,7 @@ from dbgpt.util.api_utils import _api_remote as api_remote
 from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
 from dbgpt.util.fastapi import create_app
 from dbgpt.util.parameter_utils import EnvArgumentParser
+from dbgpt.util.tracer.tracer_impl import initialize_tracer, root_tracer
 from dbgpt.util.utils import setup_http_service_logging, setup_logging
 
 logger = logging.getLogger(__name__)
@@ -159,7 +161,10 @@ def initialize_controller(
     host: str = None,
     port: int = None,
     registry: Optional[ModelRegistry] = None,
+    controller_params: Optional[ModelControllerParameters] = None,
+    system_app: Optional[SystemApp] = None,
 ):
     global controller
     if remote_controller_addr:
         controller.backend = _RemoteModelController(remote_controller_addr)
@@ -173,8 +178,25 @@ def initialize_controller(
     else:
         import uvicorn
 
+        from dbgpt.configs.model_config import LOGDIR
+
         setup_http_service_logging()
         app = create_app()
+        if not system_app:
+            system_app = SystemApp(app)
+        if not controller_params:
+            raise ValueError("Controller parameters are required.")
+        initialize_tracer(
+            os.path.join(LOGDIR, controller_params.tracer_file),
+            root_operation_name="DB-GPT-ModelController",
+            system_app=system_app,
+            tracer_storage_cls=controller_params.tracer_storage_cls,
+            enable_open_telemetry=controller_params.tracer_to_open_telemetry,
+            otlp_endpoint=controller_params.otel_exporter_otlp_traces_endpoint,
+            otlp_insecure=controller_params.otel_exporter_otlp_traces_insecure,
+            otlp_timeout=controller_params.otel_exporter_otlp_traces_timeout,
+        )
         app.include_router(router, prefix="/api", tags=["Model"])
         uvicorn.run(app, host=host, port=port, log_level="info")
@@ -187,13 +209,19 @@ async def api_health_check():
 @router.post("/controller/models")
 async def api_register_instance(request: ModelInstance):
-    return await controller.register_instance(request)
+    with root_tracer.start_span(
+        "dbgpt.model.controller.register_instance", metadata=request.to_dict()
+    ):
+        return await controller.register_instance(request)
 
 
 @router.delete("/controller/models")
 async def api_deregister_instance(model_name: str, host: str, port: int):
     instance = ModelInstance(model_name=model_name, host=host, port=port)
-    return await controller.deregister_instance(instance)
+    with root_tracer.start_span(
+        "dbgpt.model.controller.deregister_instance", metadata=instance.to_dict()
+    ):
+        return await controller.deregister_instance(instance)
 
 
 @router.get("/controller/models")
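This is where the `to_dict` helper added to `ModelInstance` pays off: span metadata wants flat, JSON-serializable values, and `request.to_dict()` provides exactly that, so the registration span carries the full instance record. The shape of the call:

```python
instance = ModelInstance(model_name="vicuna-13b", host="127.0.0.1", port=8001)
with root_tracer.start_span(
    "dbgpt.model.controller.register_instance", metadata=instance.to_dict()
):
    ...  # every ModelInstance field shows up as a span attribute
```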
@@ -303,7 +331,10 @@ def run_model_controller():
     registry = _create_registry(controller_params)
 
     initialize_controller(
-        host=controller_params.host, port=controller_params.port, registry=registry
+        host=controller_params.host,
+        port=controller_params.port,
+        registry=registry,
+        controller_params=controller_params,
     )

View File

@@ -320,17 +320,16 @@ class DefaultModelWorker(ModelWorker):
                 map(lambda m: m.dict(), span_params["messages"])
             )
 
-        model_span = root_tracer.start_span(
-            span_operation_name,
-            metadata={
-                "prompt": str_prompt,
-                "params": span_params,
-                "is_async_func": self.support_async(),
-                "llm_adapter": str(self.llm_adapter),
-                "generate_stream_func": generate_stream_func_str_name,
-                "model_context": model_context,
-            },
-        )
+        metadata = {
+            "is_async_func": self.support_async(),
+            "llm_adapter": str(self.llm_adapter),
+            "generate_stream_func": generate_stream_func_str_name,
+        }
+        metadata.update(span_params)
+        metadata.update(model_context)
+        metadata["prompt"] = str_prompt
+        model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
 
         return params, model_context, generate_stream_func, model_span
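Flattening the span metadata (instead of nesting `params` and `model_context` as sub-dicts) matters once spans are exported over OTLP: OpenTelemetry attributes must be primitive values or lists of primitives, not nested mappings. Note also the ordering contract of the merge, with later writes winning (values below are illustrative):

```python
metadata = {"is_async_func": True, "llm_adapter": "vllm"}
metadata.update({"model": "vicuna-13b", "prompt": "truncated..."})
metadata.update({"prompt_echo_len_char": 420})
metadata["prompt"] = "the full prompt"  # set last, so it always wins
print(metadata["prompt"])  # -> the full prompt
```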

View File

@@ -827,12 +827,18 @@ async def api_model_shutdown(request: WorkerStartupRequest):
 def _setup_fastapi(
-    worker_params: ModelWorkerParameters, app=None, ignore_exception: bool = False
+    worker_params: ModelWorkerParameters,
+    app=None,
+    ignore_exception: bool = False,
+    system_app: Optional[SystemApp] = None,
 ):
     if not app:
         app = create_app()
         setup_http_service_logging()
 
+    if system_app:
+        system_app._asgi_app = app
+
     if worker_params.standalone:
-        from dbgpt.model.cluster.controller.controller import initialize_controller
+        from dbgpt.model.cluster.controller.controller import (
@@ -848,7 +854,7 @@ def _setup_fastapi(
         logger.info(
             f"Run WorkerManager with standalone mode, controller_addr: {worker_params.controller_addr}"
         )
-        initialize_controller(app=app)
+        initialize_controller(app=app, system_app=system_app)
         app.include_router(controller_router, prefix="/api")
 
     async def startup_event():
@@ -1074,7 +1080,7 @@ def initialize_worker_manager_in_client(
         worker_params.register = True
         worker_params.port = local_port
         logger.info(f"Worker params: {worker_params}")
-        _setup_fastapi(worker_params, app, ignore_exception=True)
+        _setup_fastapi(worker_params, app, ignore_exception=True, system_app=system_app)
         _start_local_worker(worker_manager, worker_params)
         worker_manager.after_start(start_listener)
         _start_local_embedding_worker(
@@ -1100,7 +1106,9 @@ def initialize_worker_manager_in_client(
         worker_manager.worker_manager = RemoteWorkerManager(client)
         worker_manager.after_start(start_listener)
         initialize_controller(
-            app=app, remote_controller_addr=worker_params.controller_addr
+            app=app,
+            remote_controller_addr=worker_params.controller_addr,
+            system_app=system_app,
         )
         loop = asyncio.get_event_loop()
         loop.run_until_complete(worker_manager.start())
@@ -1140,17 +1148,22 @@ def run_worker_manager(
     embedded_mod = True
     logger.info(f"Worker params: {worker_params}")
+    system_app = SystemApp()
     if not app:
         # Run worker manager independently
         embedded_mod = False
-        app = _setup_fastapi(worker_params)
-        system_app = SystemApp(app)
+        app = _setup_fastapi(worker_params, system_app=system_app)
+        system_app._asgi_app = app
 
     initialize_tracer(
         os.path.join(LOGDIR, worker_params.tracer_file),
         system_app=system_app,
-        root_operation_name="DB-GPT-WorkerManager-Entry",
+        root_operation_name="DB-GPT-ModelWorker",
         tracer_storage_cls=worker_params.tracer_storage_cls,
+        enable_open_telemetry=worker_params.tracer_to_open_telemetry,
+        otlp_endpoint=worker_params.otel_exporter_otlp_traces_endpoint,
+        otlp_insecure=worker_params.otel_exporter_otlp_traces_insecure,
+        otlp_timeout=worker_params.otel_exporter_otlp_traces_timeout,
     )
 
     _start_local_worker(worker_manager, worker_params)

View File

@@ -5,6 +5,7 @@ from typing import Dict, Iterator, List
 from dbgpt.core import ModelMetadata, ModelOutput
 from dbgpt.model.cluster.worker_base import ModelWorker
 from dbgpt.model.parameter import ModelParameters
+from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
 
 logger = logging.getLogger(__name__)
@@ -57,7 +58,7 @@ class RemoteModelWorker(ModelWorker):
             async with client.stream(
                 "POST",
                 url,
-                headers=self.headers,
+                headers=self._get_trace_headers(),
                 json=params,
                 timeout=self.timeout,
             ) as response:
@@ -84,7 +85,7 @@ class RemoteModelWorker(ModelWorker):
             logger.debug(f"Send async_generate to url {url}, params: {params}")
             response = await client.post(
                 url,
-                headers=self.headers,
+                headers=self._get_trace_headers(),
                 json=params,
                 timeout=self.timeout,
             )
@@ -101,7 +102,7 @@ class RemoteModelWorker(ModelWorker):
             logger.debug(f"Send async_count_token to url {url}, params: {prompt}")
             response = await client.post(
                 url,
-                headers=self.headers,
+                headers=self._get_trace_headers(),
                 json={"prompt": prompt},
                 timeout=self.timeout,
             )
@@ -118,7 +119,7 @@ class RemoteModelWorker(ModelWorker):
             )
             response = await client.post(
                 url,
-                headers=self.headers,
+                headers=self._get_trace_headers(),
                 json=params,
                 timeout=self.timeout,
             )
@@ -136,7 +137,7 @@ class RemoteModelWorker(ModelWorker):
         logger.debug(f"Send embeddings to url {url}, params: {params}")
         response = requests.post(
             url,
-            headers=self.headers,
+            headers=self._get_trace_headers(),
             json=params,
             timeout=self.timeout,
         )
@@ -151,8 +152,14 @@ class RemoteModelWorker(ModelWorker):
             logger.debug(f"Send async_embeddings to url {url}")
             response = await client.post(
                 url,
-                headers=self.headers,
+                headers=self._get_trace_headers(),
                 json=params,
                 timeout=self.timeout,
             )
             return response.json()
 
+    def _get_trace_headers(self):
+        span_id = root_tracer.get_current_span_id()
+        headers = self.headers.copy()
+        headers.update({DBGPT_TRACER_SPAN_ID: span_id})
+        return headers
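`_get_trace_headers` is the propagation half of distributed tracing here: every outbound worker request carries the current span id in a dedicated header, so worker-side spans can be parented onto the caller's trace. On the receiving side one would expect something like the following (a hypothetical sketch; it assumes `DBGPT_TRACER_SPAN_ID` holds the header name and that the worker parents its span on the propagated id):

```python
from fastapi import Request


async def api_generate(request: Request):
    parent_span_id = request.headers.get(DBGPT_TRACER_SPAN_ID)
    with root_tracer.start_span(
        "dbgpt.model.worker.generate", parent_span_id=parent_span_id
    ):
        ...  # run the model and return the response
```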

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, Optional, Tuple, Union
 
-from dbgpt.util.parameter_utils import BaseParameters
+from dbgpt.util.parameter_utils import BaseParameters, BaseServerParameters
 
 
 class WorkerType(str, Enum):
@@ -48,10 +48,7 @@ class WorkerType(str, Enum):
 @dataclass
-class ModelControllerParameters(BaseParameters):
-    host: Optional[str] = field(
-        default="0.0.0.0", metadata={"help": "Model Controller deploy host"}
-    )
+class ModelControllerParameters(BaseServerParameters):
     port: Optional[int] = field(
         default=8000, metadata={"help": "Model Controller deploy port"}
     )
@@ -133,24 +130,6 @@ class ModelControllerParameters(BaseParameters):
         },
     )
 
-    daemon: Optional[bool] = field(
-        default=False, metadata={"help": "Run Model Controller in background"}
-    )
-    log_level: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Logging level",
-            "valid_values": [
-                "FATAL",
-                "ERROR",
-                "WARNING",
-                "WARNING",
-                "INFO",
-                "DEBUG",
-                "NOTSET",
-            ],
-        },
-    )
     log_file: Optional[str] = field(
         default="dbgpt_model_controller.log",
         metadata={
@@ -172,16 +151,10 @@ class ModelControllerParameters(BaseParameters):
 
 @dataclass
-class ModelAPIServerParameters(BaseParameters):
-    host: Optional[str] = field(
-        default="0.0.0.0", metadata={"help": "Model API server deploy host"}
-    )
+class ModelAPIServerParameters(BaseServerParameters):
     port: Optional[int] = field(
         default=8100, metadata={"help": "Model API server deploy port"}
     )
-    daemon: Optional[bool] = field(
-        default=False, metadata={"help": "Run Model API server in background"}
-    )
     controller_addr: Optional[str] = field(
         default="http://127.0.0.1:8000",
         metadata={"help": "The Model controller address to connect"},
@@ -195,21 +168,6 @@ class ModelAPIServerParameters(BaseParameters):
         default=None, metadata={"help": "Embedding batch size"}
     )
 
-    log_level: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Logging level",
-            "valid_values": [
-                "FATAL",
-                "ERROR",
-                "WARNING",
-                "WARNING",
-                "INFO",
-                "DEBUG",
-                "NOTSET",
-            ],
-        },
-    )
     log_file: Optional[str] = field(
         default="dbgpt_model_apiserver.log",
         metadata={
@@ -237,7 +195,7 @@ class BaseModelParameters(BaseParameters):
 @dataclass
-class ModelWorkerParameters(BaseModelParameters):
+class ModelWorkerParameters(BaseServerParameters, BaseModelParameters):
     worker_type: Optional[str] = field(
         default=None,
         metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
@@ -257,16 +215,10 @@ class ModelWorkerParameters(BaseModelParameters):
             "tags": "fixed",
         },
     )
-    host: Optional[str] = field(
-        default="0.0.0.0", metadata={"help": "Model worker deploy host"}
-    )
     port: Optional[int] = field(
         default=8001, metadata={"help": "Model worker deploy port"}
     )
-    daemon: Optional[bool] = field(
-        default=False, metadata={"help": "Run Model Worker in background"}
-    )
     limit_model_concurrency: Optional[int] = field(
         default=5, metadata={"help": "Model concurrency limit"}
     )
@@ -280,7 +232,8 @@ class ModelWorkerParameters(BaseModelParameters):
     worker_register_host: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The ip address of current worker to register to ModelController. If None, the address is automatically determined"
+            "help": "The ip address of current worker to register to ModelController. "
+            "If None, the address is automatically determined"
         },
     )
     controller_addr: Optional[str] = field(
@@ -293,21 +246,6 @@ class ModelWorkerParameters(BaseModelParameters):
         default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
     )
 
-    log_level: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Logging level",
-            "valid_values": [
-                "FATAL",
-                "ERROR",
-                "WARNING",
-                "WARNING",
-                "INFO",
-                "DEBUG",
-                "NOTSET",
-            ],
-        },
-    )
     log_file: Optional[str] = field(
         default="dbgpt_model_worker_manager.log",
         metadata={
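The repeated `host`/`daemon`/`log_level` fields deleted from all three parameter classes presumably move into the shared `BaseServerParameters`, together with the new `tracer_to_open_telemetry` and `otel_exporter_otlp_traces_*` fields that the tracer calls above read from each parameter object. An illustrative shape for the base class (not the actual definition in `dbgpt.util.parameter_utils`):

```python
from dataclasses import dataclass, field
from typing import Optional

from dbgpt.util.parameter_utils import BaseParameters


@dataclass
class BaseServerParameters(BaseParameters):  # hypothetical sketch
    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "The host to listen on"}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run the server in the background"}
    )
    log_level: Optional[str] = field(
        default=None,
        metadata={
            "help": "Logging level",
            "valid_values": ["FATAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"],
        },
    )
    tracer_to_open_telemetry: Optional[bool] = field(
        default=False,
        metadata={"help": "Export DB-GPT traces to OpenTelemetry via OTLP"},
    )
    otel_exporter_otlp_traces_endpoint: Optional[str] = field(
        default=None,
        metadata={"help": "OTLP traces endpoint, e.g. http://localhost:4317"},
    )
```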