feat(core): Support cross-service data recording

This commit is contained in:
FangYin Cheng 2023-10-11 04:23:25 +08:00
parent c0219a672e
commit 1e919aeef3
17 changed files with 948 additions and 14 deletions

View File

@ -47,6 +47,8 @@ class ComponentType(str, Enum):
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"
TRACER_SPAN_STORAGE = "dbgpt_tracer_span_storage"
class BaseComponent(LifeCycle, ABC):
@ -70,6 +72,8 @@ class BaseComponent(LifeCycle, ABC):
T = TypeVar("T", bound=BaseComponent)
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components."""
@ -104,13 +108,18 @@ class SystemApp(LifeCycle):
instance.init_app(self)
def get_component(
self, name: Union[str, ComponentType], component_type: Type[T]
self,
name: Union[str, ComponentType],
component_type: Type[T],
default_component=_EMPTY_DEFAULT_COMPONENT,
) -> T:
"""Retrieve a registered component by its name and type."""
if isinstance(name, ComponentType):
name = name.value
component = self.components.get(name)
if not component:
if default_component != _EMPTY_DEFAULT_COMPONENT:
return default_component
raise ValueError(f"No component found with name {name}")
if not isinstance(component, component_type):
raise TypeError(f"Component {name} is not of type {component_type}")

View File

@ -354,7 +354,7 @@ class LlamaCppAdapater(BaseLLMAdaper):
if not path.is_file():
model_paths = list(path.glob("*ggml*.gguf"))
if not model_paths:
return False
return False, None
model_path = str(model_paths[0])
logger.warn(
f"Model path {model_path} is not single file, use first *gglm*.gguf model file: {model_path}"

View File

@ -53,6 +53,9 @@ class ModelOutput:
error_code: int
model_context: Dict = None
def to_dict(self) -> Dict:
return asdict(self)
@dataclass
class WorkerApplyOutput:

View File

@ -10,6 +10,7 @@ from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.tracer import root_tracer
logger = logging.getLogger(__name__)
@ -109,9 +110,16 @@ class DefaultModelWorker(ModelWorker):
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
span = root_tracer.start_span("DefaultModelWorker.generate_stream")
try:
params, model_context, generate_stream_func = self._prepare_generate_stream(
params
(
params,
model_context,
generate_stream_func,
model_span,
) = self._prepare_generate_stream(
params,
span_operation_name="DefaultModelWorker_call.generate_stream_func",
)
previous_response = ""
@ -127,8 +135,12 @@ class DefaultModelWorker(ModelWorker):
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
model_span.end(metadata={"output": previous_response})
span.end()
except Exception as e:
yield self._handle_exception(e)
output = self._handle_exception(e)
yield output
span.end(metadata={"error": output.to_dict()})
def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
@ -141,9 +153,16 @@ class DefaultModelWorker(ModelWorker):
raise NotImplementedError
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
span = root_tracer.start_span("DefaultModelWorker.async_generate_stream")
try:
params, model_context, generate_stream_func = self._prepare_generate_stream(
params
(
params,
model_context,
generate_stream_func,
model_span,
) = self._prepare_generate_stream(
params,
span_operation_name="DefaultModelWorker_call.generate_stream_func",
)
previous_response = ""
@ -159,8 +178,12 @@ class DefaultModelWorker(ModelWorker):
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
model_span.end(metadata={"output": previous_response})
span.end()
except Exception as e:
yield self._handle_exception(e)
output = self._handle_exception(e)
yield output
span.end(metadata={"error": output.to_dict()})
async def async_generate(self, params: Dict) -> ModelOutput:
output = None
@ -168,7 +191,7 @@ class DefaultModelWorker(ModelWorker):
output = out
return output
def _prepare_generate_stream(self, params: Dict):
def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
params, model_context = self.llm_adapter.model_adaptation(
params,
self.model_name,
@ -190,7 +213,30 @@ class DefaultModelWorker(ModelWorker):
)
str_prompt = params.get("prompt")
print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
return params, model_context, generate_stream_func
generate_stream_func_str_name = "{}.{}".format(
generate_stream_func.__module__, generate_stream_func.__name__
)
span_params = {k: v for k, v in params.items()}
if "messages" in span_params:
span_params["messages"] = list(
map(lambda m: m.dict(), span_params["messages"])
)
model_span = root_tracer.start_span(
span_operation_name,
metadata={
"prompt": str_prompt,
"params": span_params,
"is_async_func": self.support_async(),
"llm_adapter": str(self.llm_adapter),
"generate_stream_func": generate_stream_func_str_name,
"model_context": model_context,
},
)
return params, model_context, generate_stream_func, model_span
def _handle_output(self, output, previous_response, model_context):
if isinstance(output, dict):

View File

@ -186,6 +186,13 @@ class OldLLMModelAdaperWrapper(LLMModelAdaper):
def get_generate_stream_function(self, model, model_path: str):
return self._chat_adapter.get_generate_stream_func(model_path)
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping fastchat adapter"""
@ -206,6 +213,13 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
) -> "Conversation":
return self._adapter.get_default_conv_template(model_path)
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
@ -412,6 +426,9 @@ class VLLMModelAdaperWrapper(LLMModelAdaper):
) -> "Conversation":
return _auto_get_conv_template(model_name, model_path)
def __str__(self) -> str:
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
# Covering the configuration of fastchat, we will regularly feed back the code here to fastchat.
# We also recommend that you modify it directly in the fastchat repository.

View File

@ -46,6 +46,7 @@ from pilot.summary.db_summary_client import DBSummaryClient
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
from pilot.utils.tracer import root_tracer
router = APIRouter()
CFG = Config()
@ -366,7 +367,8 @@ async def chat_completions(dialogue: ConversationVo = Body()):
print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
chat: BaseChat = get_chat_instance(dialogue)
with root_tracer.start_span("chat_completions", metadata=dialogue.dict()) as _:
chat: BaseChat = get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
headers = {
@ -417,9 +419,11 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_mana
async def no_stream_generator(chat):
span = root_tracer.start_span("no_stream_generator")
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
span.end()
async def stream_generator(chat, incremental: bool, model_name: str):
@ -436,6 +440,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
Yields:
_type_: streaming responses
"""
span = root_tracer.start_span("stream_generator")
msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
@ -462,6 +467,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)
span.end()
if incremental:
yield "data: [DONE]\n\n"
chat.current_message.add_ai_message(msg)

View File

@ -15,6 +15,7 @@ from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop
from pilot.utils.tracer import root_tracer
from pydantic import Extra
logger = logging.getLogger(__name__)
@ -135,6 +136,12 @@ class BaseChat(ABC):
}
return payload
def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
del metadata["prompt"]
metadata["messages"] = list(map(lambda m: m.dict(), metadata["messages"]))
return metadata
async def stream_call(self):
# TODO Retry when server connection error
payload = self.__call_base()
@ -142,6 +149,9 @@ class BaseChat(ABC):
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
)
try:
from pilot.model.cluster import WorkerManagerFactory
@ -150,6 +160,7 @@ class BaseChat(ABC):
).create()
async for output in worker_manager.generate_stream(payload):
yield output
span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
@ -158,11 +169,15 @@ class BaseChat(ABC):
)
### store current conversation
self.memory.append(self.current_message)
span.end(metadata={"error": str(e)})
async def nostream_call(self):
payload = self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
)
try:
from pilot.model.cluster import WorkerManagerFactory
@ -170,7 +185,8 @@ class BaseChat(ABC):
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload)
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate") as _:
model_output = await worker_manager.generate(payload)
### output parse
ai_response_text = (
@ -185,8 +201,14 @@ class BaseChat(ABC):
ai_response_text
)
)
### run
result = self.do_action(prompt_define_response)
metadata = {
"model_output": model_output.to_dict(),
"ai_response_text": ai_response_text,
"prompt_define_response": prompt_define_response,
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata) as _:
### run
result = self.do_action(prompt_define_response)
### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)
@ -195,12 +217,14 @@ class BaseChat(ABC):
speak_to_user, result
)
self.current_message.add_view_message(view_message)
span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
span.end(metadata={"error": str(e)})
### store dialogue
self.memory.append(self.current_message)
return self.current_ai_response()

View File

@ -2,11 +2,14 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Type
import os
from pilot.component import ComponentType, SystemApp
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
from pilot.configs.model_config import LOGDIR
from pilot.utils.tracer import root_tracer, initialize_tracer
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@ -23,6 +26,8 @@ def initialize_components(
):
from pilot.model.cluster.controller.controller import controller
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)

View File

@ -0,0 +1,28 @@
from pilot.utils.tracer.base import (
Span,
Tracer,
SpanStorage,
SpanStorageType,
TracerContext,
)
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
from pilot.utils.tracer.tracer_impl import (
root_tracer,
initialize_tracer,
DefaultTracer,
TracerManager,
)
__all__ = [
"Span",
"Tracer",
"SpanStorage",
"SpanStorageType",
"TracerContext",
"MemorySpanStorage",
"FileSpanStorage",
"root_tracer",
"initialize_tracer",
"DefaultTracer",
"TracerManager",
]

161
pilot/utils/tracer/base.py Normal file
View File

@ -0,0 +1,161 @@
from __future__ import annotations
from typing import Dict, Callable, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import uuid
from datetime import datetime
from pilot.component import BaseComponent, SystemApp, ComponentType
class Span:
    """A single traced unit of work, such as a function call or a database query.

    A span carries the identifiers that tie it into a trace, start/end
    timestamps, optional metadata, and a list of callbacks fired when it ends.
    """

    span_type: str = "base"

    def __init__(
        self,
        trace_id: str,
        span_id: str,
        parent_span_id: str = None,
        operation_name: str = None,
        metadata: Dict = None,
        end_caller: Callable[["Span"], None] = None,
    ):
        # Identifier shared by every span belonging to the same trace.
        self.trace_id = trace_id
        # Identifier unique to this span within its trace.
        self.span_id = span_id
        # Span this one is nested under, or None for a root span.
        self.parent_span_id = parent_span_id
        # Human-readable name of the operation being traced.
        self.operation_name = operation_name
        # Wall-clock time at which the span was opened.
        self.start_time = datetime.now()
        # Set by end(); None while the span is still open.
        self.end_time = None
        # Arbitrary extra data supplied by the caller.
        self.metadata = metadata
        self._end_callers = [end_caller] if end_caller else []

    def end(self, **kwargs):
        """Close the span, optionally replacing its metadata, then fire callbacks.

        NOTE: a ``metadata=`` keyword replaces (does not merge with) any
        metadata supplied at creation time.
        """
        self.end_time = datetime.now()
        if "metadata" in kwargs:
            self.metadata = kwargs.get("metadata")
        for caller in self._end_callers:
            caller(self)

    def add_end_caller(self, end_caller: Callable[["Span"], None]):
        """Register an additional callback to run when the span ends."""
        if end_caller:
            self._end_callers.append(end_caller)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the with-block closes the span; exceptions still propagate.
        self.end()
        return False

    def to_dict(self) -> Dict:
        """Serialize the span to a plain dict, timestamps as formatted strings."""

        def _fmt(ts):
            # Millisecond precision; None when the span has not ended yet.
            return ts.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] if ts else None

        return {
            "span_type": self.span_type,
            "trace_id": self.trace_id,
            "span_id": self.span_id,
            "parent_span_id": self.parent_span_id,
            "operation_name": self.operation_name,
            "start_time": _fmt(self.start_time),
            "end_time": _fmt(self.end_time),
            "metadata": self.metadata,
        }
class SpanStorageType(str, Enum):
    """When a span is handed to storage, relative to its lifecycle."""

    # Store the span as soon as it is created.
    ON_CREATE = "on_create"
    # Store the span only when it ends.
    ON_END = "on_end"
    # Store the span twice: once on creation and once again on end.
    ON_CREATE_END = "on_create_end"
class SpanStorage(BaseComponent, ABC):
    """Abstract base class for storing spans.

    This allows different storage mechanisms (e.g., in-memory, database) to be
    implemented.
    """

    # Component name used to look this storage up in the SystemApp registry.
    name = ComponentType.TRACER_SPAN_STORAGE.value

    def init_app(self, system_app: SystemApp):
        """Initialize the storage with the given application context."""
        pass

    @abstractmethod
    def append_span(self, span: Span):
        """Store the given span. This needs to be implemented by subclasses."""
class Tracer(BaseComponent, ABC):
    """Abstract base class for tracing operations.

    Provides the core logic for starting, ending, and retrieving spans.
    """

    # Component name used to look the tracer up in the SystemApp registry.
    name = ComponentType.TRACER.value

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)
        self.system_app = system_app  # Application context

    def init_app(self, system_app: SystemApp):
        """Initialize the tracer with the given application context."""
        self.system_app = system_app

    @abstractmethod
    def append_span(self, span: Span):
        """Append the given span to storage. This needs to be implemented by subclasses."""

    @abstractmethod
    def start_span(
        self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
    ) -> Span:
        """Begin a new span for the given operation. If provided, the span will be
        a child of the span with the given parent_span_id.
        """

    @abstractmethod
    def end_span(self, span: Span, **kwargs):
        """
        End the given span. Keyword arguments are forwarded to Span.end().
        """

    @abstractmethod
    def get_current_span(self) -> Optional[Span]:
        """
        Retrieve the span that is currently being traced.
        """

    @abstractmethod
    def _get_current_storage(self) -> SpanStorage:
        """
        Get the storage mechanism currently in use for storing spans.
        This needs to be implemented by subclasses.
        """

    def _new_uuid(self) -> str:
        """
        Generate a new unique identifier.
        """
        return str(uuid.uuid4())
@dataclass
class TracerContext:
    """Ambient tracing context, typically carried in a ContextVar."""

    # Span id used as the default parent for newly started spans.
    span_id: str

View File

@ -0,0 +1,79 @@
import os
import json
import time
import threading
import queue
import logging
from pilot.component import SystemApp
from pilot.utils.tracer.base import Span, SpanStorage
logger = logging.getLogger(__name__)
class MemorySpanStorage(SpanStorage):
    """Span storage that keeps spans in a plain in-memory list (no persistence)."""

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)
        # Spans recorded so far, in arrival order.
        self.spans = []
        # Guards appends from concurrent threads.
        self._lock = threading.Lock()

    def append_span(self, span: Span):
        """Record *span*, thread-safely."""
        with self._lock:
            self.spans.append(span)
class FileSpanStorage(SpanStorage):
    """Span storage that appends spans to a JSON-lines file.

    Spans are queued in memory and written by a background daemon thread,
    either when the queue holds at least ``batch_size`` spans or after
    ``flush_interval`` seconds, whichever comes first.
    """

    def __init__(self, filename: str, batch_size=10, flush_interval=10):
        super().__init__()
        # Target file; one JSON object per line is appended per span.
        self.filename = filename
        # Pending span dicts waiting to be flushed to disk.
        self.queue = queue.Queue()
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.last_flush_time = time.time()
        # Used to wake the flush thread early once a full batch is queued.
        self.flush_signal_queue = queue.Queue()
        if not os.path.exists(filename):
            # Create the file eagerly so readers can open it immediately.
            with open(filename, "w") as _:
                pass
        # Daemon thread: flushes periodically and dies with the process.
        self.flush_thread = threading.Thread(target=self._flush_to_file, daemon=True)
        self.flush_thread.start()

    def append_span(self, span: Span):
        """Queue *span* for writing; wakes the flush thread when a batch is full."""
        span_data = span.to_dict()
        logger.debug(f"append span: {span_data}")
        self.queue.put(span_data)

        if self.queue.qsize() >= self.batch_size:
            try:
                self.flush_signal_queue.put_nowait(True)
            except queue.Full:
                pass  # If the signal queue is full, it's okay. The flush thread will handle it.

    def _write_to_file(self):
        # Drain everything currently queued and append it to the file.
        spans_to_write = []
        while not self.queue.empty():
            spans_to_write.append(self.queue.get())

        with open(self.filename, "a") as file:
            for span_data in spans_to_write:
                try:
                    file.write(json.dumps(span_data, ensure_ascii=False) + "\n")
                except Exception as e:
                    # A single unserializable span must not abort the batch.
                    logger.warning(
                        f"Write span to file failed: {str(e)}, span_data: {span_data}"
                    )

    def _flush_to_file(self):
        # Background loop: wait for either a batch signal or the flush
        # interval to elapse, then write whatever is queued.
        while True:
            interval = time.time() - self.last_flush_time
            if interval < self.flush_interval:
                try:
                    self.flush_signal_queue.get(
                        block=True, timeout=self.flush_interval - interval
                    )
                except Exception:
                    # Timeout
                    pass
            self._write_to_file()
            self.last_flush_time = time.time()

View File

View File

@ -0,0 +1,122 @@
from typing import Dict
from pilot.component import SystemApp
from pilot.utils.tracer import Span, SpanStorage, Tracer
# Mock implementations
class MockSpanStorage(SpanStorage):
    """In-memory storage used by these tests; records spans in a list."""

    def __init__(self):
        self.spans = []

    def append_span(self, span: Span):
        self.spans.append(span)
class MockTracer(Tracer):
    """Minimal Tracer backed by MockSpanStorage.

    Tracks only the single most recently started span (no span stack).
    """

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)
        self.current_span = None
        self.storage = MockSpanStorage()

    def append_span(self, span: Span):
        self.storage.append_span(span)

    def start_span(
        self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
    ) -> Span:
        # Span ids have the form "<trace_id>:<uuid>", so a child reuses the
        # trace-id prefix of its parent's span id.
        trace_id = (
            self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
        )
        span_id = f"{trace_id}:{self._new_uuid()}"
        span = Span(trace_id, span_id, parent_span_id, operation_name, metadata)
        self.current_span = span
        return span

    def end_span(self, span: Span):
        span.end()
        self.append_span(span)

    def get_current_span(self) -> Span:
        return self.current_span

    def _get_current_storage(self) -> SpanStorage:
        return self.storage
# Tests
def test_span_creation():
    """Constructor arguments are stored verbatim on the span."""
    meta = {"key": "value"}
    span = Span("trace_id", "span_id", "parent_span_id", "operation", meta)
    for attr, expected in [
        ("trace_id", "trace_id"),
        ("span_id", "span_id"),
        ("parent_span_id", "parent_span_id"),
        ("operation_name", "operation"),
        ("metadata", meta),
    ]:
        assert getattr(span, attr) == expected
def test_span_end():
    """end() stamps an end timestamp on a previously open span."""
    open_span = Span("trace_id", "span_id")
    assert open_span.end_time is None
    open_span.end()
    assert open_span.end_time is not None
def test_mock_tracer_start_span():
    """start_span names the span and makes it the current one."""
    tracer = MockTracer()
    started = tracer.start_span("operation")
    assert started.operation_name == "operation"
    assert tracer.get_current_span() is started
def test_mock_tracer_end_span():
    """end_span pushes the span into the mock storage."""
    tracer = MockTracer()
    ended = tracer.start_span("operation")
    tracer.end_span(ended)
    assert ended in tracer._get_current_storage().spans
def test_mock_tracer_append_span():
    """append_span stores a span directly, without start/end bookkeeping."""
    tracer = MockTracer()
    orphan = Span("trace_id", "span_id")
    tracer.append_span(orphan)
    assert orphan in tracer._get_current_storage().spans
def test_parent_child_span_relation():
    """A child span records its parent's span id and shares its trace id."""
    tracer = MockTracer()

    parent = tracer.start_span("parent_operation")
    child = tracer.start_span(
        "child_operation", parent_span_id=parent.span_id
    )

    assert child.parent_span_id == parent.span_id
    # Children share the same trace ID as their parent.
    assert child.trace_id == parent.trace_id

    tracer.end_span(child)
    tracer.end_span(parent)

    stored = tracer._get_current_storage().spans
    assert child in stored
    assert parent in stored
# This test checks if unique UUIDs are being generated.
# Note: This is a simple test and doesn't guarantee uniqueness for large numbers of UUIDs.
def test_new_uuid_unique():
    tracer = MockTracer()
    generated = [tracer._new_uuid() for _ in range(1000)]
    assert len(set(generated)) == len(generated)

View File

@ -0,0 +1,124 @@
import os
import pytest
import asyncio
import json
import tempfile
import time
from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span
@pytest.fixture
def storage(request):
    """Yield a FileSpanStorage over a temp file, configurable via request.param.

    Recognized params: ``batch_size``, ``flush_interval`` and
    ``file_does_not_exist`` (point the storage at a path that does not exist
    yet, so it must create the file itself).
    """
    if not request or not hasattr(request, "param"):
        batch_size = 10
        flush_interval = 10
        file_does_not_exist = False
    else:
        batch_size = request.param.get("batch_size", 10)
        flush_interval = request.param.get("flush_interval", 10)
        file_does_not_exist = request.param.get("file_does_not_exist", False)

    if file_does_not_exist:
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename = os.path.join(tmp_dir, "non_existent_file.jsonl")
            storage_instance = FileSpanStorage(
                filename, batch_size=batch_size, flush_interval=flush_interval
            )
            yield storage_instance
    else:
        with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
            filename = tmp_file.name
            storage_instance = FileSpanStorage(
                filename, batch_size=batch_size, flush_interval=flush_interval
            )
            yield storage_instance
def read_spans_from_file(filename):
    """Parse a JSON-lines span file into a list of dicts."""
    with open(filename, "r") as f:
        return [json.loads(line) for line in f]
@pytest.mark.parametrize(
    "storage", [{"batch_size": 1, "flush_interval": 5}], indirect=True
)
def test_write_span(storage: SpanStorage):
    """With batch_size=1, a single appended span reaches the file quickly."""
    span = Span("1", "a", "b", "op1")
    storage.append_span(span)
    # Give the background flush thread time to write.
    time.sleep(0.1)

    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == 1
    assert spans_in_file[0]["trace_id"] == "1"
@pytest.mark.parametrize(
    "storage", [{"batch_size": 1, "flush_interval": 5}], indirect=True
)
def test_incremental_write(storage: SpanStorage):
    """Consecutive appends each end up as their own line in the file."""
    span1 = Span("1", "a", "b", "op1")
    span2 = Span("2", "c", "d", "op2")

    storage.append_span(span1)
    storage.append_span(span2)
    # Give the background flush thread time to write.
    time.sleep(0.1)

    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == 2
@pytest.mark.parametrize(
    "storage", [{"batch_size": 2, "flush_interval": 5}], indirect=True
)
def test_sync_and_async_append(storage: SpanStorage):
    """append_span works identically from sync code and inside an event loop."""
    span = Span("1", "a", "b", "op1")

    storage.append_span(span)

    async def async_append():
        storage.append_span(span)

    asyncio.run(async_append())
    # Two spans fill the batch of 2; wait for the flush thread.
    time.sleep(0.1)

    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == 2
@pytest.mark.asyncio
async def test_flush_policy(storage: SpanStorage):
    """Nothing is written until the queue reaches batch_size spans."""
    span = Span("1", "a", "b", "op1")
    for _ in range(storage.batch_size - 1):
        storage.append_span(span)
    # One short of a batch: the file must still be empty.
    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == 0

    # Trigger batch write
    storage.append_span(span)
    await asyncio.sleep(0.1)

    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == storage.batch_size
@pytest.mark.parametrize(
    "storage", [{"batch_size": 2, "file_does_not_exist": True}], indirect=True
)
def test_non_existent_file(storage: SpanStorage):
    """Storage creates its file on demand and still honors the batch size."""
    span = Span("1", "a", "b", "op1")
    span2 = Span("2", "c", "d", "op2")
    storage.append_span(span)
    time.sleep(0.1)

    # Only one span queued (< batch_size of 2), so nothing is flushed yet.
    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == 0

    storage.append_span(span2)
    time.sleep(0.1)
    spans_in_file = read_spans_from_file(storage.filename)
    assert len(spans_in_file) == 2
    assert spans_in_file[0]["trace_id"] == "1"
    assert spans_in_file[1]["trace_id"] == "2"

View File

@ -0,0 +1,103 @@
import pytest
from pilot.utils.tracer import (
Span,
SpanStorageType,
SpanStorage,
DefaultTracer,
TracerManager,
Tracer,
MemorySpanStorage,
)
from pilot.component import SystemApp
@pytest.fixture
def system_app():
    """Fresh SystemApp per test."""
    return SystemApp()
@pytest.fixture
def storage(system_app: SystemApp):
    """Register an in-memory span storage on the app and return it."""
    ms = MemorySpanStorage(system_app)
    system_app.register_instance(ms)
    return ms
@pytest.fixture
def tracer(request, system_app: SystemApp):
    """DefaultTracer, optionally parametrized with ``span_storage_type``."""
    if not request or not hasattr(request, "param"):
        return DefaultTracer(system_app)
    else:
        span_storage_type = request.param.get(
            "span_storage_type", SpanStorageType.ON_CREATE_END
        )
        return DefaultTracer(system_app, span_storage_type=span_storage_type)
@pytest.fixture
def tracer_manager(system_app: SystemApp, tracer: Tracer):
    """TracerManager initialized against the app, with *tracer* registered."""
    system_app.register_instance(tracer)
    manager = TracerManager()
    manager.initialize(system_app)
    return manager
def test_start_and_end_span(tracer: Tracer):
    """A span round-trips through the tracer and lands in storage."""
    span = tracer.start_span("operation")
    assert isinstance(span, Span)
    assert span.operation_name == "operation"

    tracer.end_span(span)
    assert span.end_time is not None

    stored_span = tracer._get_current_storage().spans[0]
    assert stored_span == span
def test_start_and_end_span_with_tracer_manager(tracer_manager: TracerManager):
    """The TracerManager facade can start and end spans like the tracer itself."""
    span = tracer_manager.start_span("operation")
    assert isinstance(span, Span)
    assert span.operation_name == "operation"

    tracer_manager.end_span(span)
    assert span.end_time is not None
def test_parent_child_span_relation(tracer: Tracer):
    """Child spans share the parent's trace id and reference its span id."""
    parent_span = tracer.start_span("parent_operation")
    child_span = tracer.start_span(
        "child_operation", parent_span_id=parent_span.span_id
    )
    assert child_span.parent_span_id == parent_span.span_id
    assert child_span.trace_id == parent_span.trace_id

    tracer.end_span(child_span)
    tracer.end_span(parent_span)

    assert parent_span in tracer._get_current_storage().spans
    assert child_span in tracer._get_current_storage().spans
@pytest.mark.parametrize(
    "tracer, expected_count, after_create_inc_count",
    [
        ({"span_storage_type": SpanStorageType.ON_CREATE}, 1, 1),
        ({"span_storage_type": SpanStorageType.ON_END}, 1, 0),
        ({"span_storage_type": SpanStorageType.ON_CREATE_END}, 2, 1),
    ],
    indirect=["tracer"],
)
def test_tracer_span_storage_type_and_with(
    tracer: Tracer,
    expected_count: int,
    after_create_inc_count: int,
    storage: SpanStorage,
):
    """Each storage type persists a started-and-ended span the expected number
    of times, both for explicit end() and for the context-manager form.
    """
    span = tracer.start_span("new_span")
    span.end()
    assert len(storage.spans) == expected_count

    with tracer.start_span("with_span") as ws:
        # Inside the with-block only the on-create write (if any) has happened.
        assert len(storage.spans) == expected_count + after_create_inc_count
    # After exit the span has ended, so each span contributed expected_count writes.
    assert len(storage.spans) == expected_count + expected_count

View File

@ -0,0 +1,164 @@
from typing import Dict, Optional
from contextvars import ContextVar
from functools import wraps
from pilot.component import SystemApp, ComponentType
from pilot.utils.tracer.base import (
Span,
Tracer,
SpanStorage,
SpanStorageType,
TracerContext,
)
from pilot.utils.tracer.span_storage import MemorySpanStorage
class DefaultTracer(Tracer):
    """Default Tracer implementation.

    Maintains a per-context stack of open spans in a ContextVar and forwards
    spans to the registered SpanStorage component, falling back to
    ``default_storage`` when no component is registered.
    """

    def __init__(
        self,
        system_app: SystemApp | None = None,
        default_storage: SpanStorage = None,
        span_storage_type: SpanStorageType = SpanStorageType.ON_CREATE_END,
    ):
        super().__init__(system_app)
        # Stack of currently-open spans for the active context.
        self._span_stack_var = ContextVar("span_stack", default=[])

        if not default_storage:
            default_storage = MemorySpanStorage(system_app)
        # Storage used when no SpanStorage component is registered on the app.
        self._default_storage = default_storage
        # Controls whether spans are persisted on creation, on end, or both.
        self._span_storage_type = span_storage_type

    def append_span(self, span: Span):
        """Persist *span* in the current storage."""
        self._get_current_storage().append_span(span)

    def start_span(
        self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
    ) -> Span:
        """Open a new span, optionally as a child of *parent_span_id*.

        The trace id is taken from the parent span id (the part before the
        first ":"), or freshly generated for a root span.
        """
        trace_id = (
            self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
        )
        span_id = f"{trace_id}:{self._new_uuid()}"
        span = Span(
            trace_id, span_id, parent_span_id, operation_name, metadata=metadata
        )

        if self._span_storage_type in [
            SpanStorageType.ON_END,
            SpanStorageType.ON_CREATE_END,
        ]:
            span.add_end_caller(self.append_span)

        if self._span_storage_type in [
            SpanStorageType.ON_CREATE,
            SpanStorageType.ON_CREATE_END,
        ]:
            self.append_span(span)

        # Copy before mutating: the ContextVar's default list is shared by
        # every context that has never set the var, so appending in place
        # would leak spans across unrelated contexts/tasks.
        current_stack = list(self._span_stack_var.get())
        current_stack.append(span)
        self._span_stack_var.set(current_stack)

        span.add_end_caller(self._remove_from_stack_top)
        return span

    def end_span(self, span: Span, **kwargs):
        """Close *span*; storage and stack updates run via its end callbacks."""
        span.end(**kwargs)

    def _remove_from_stack_top(self, span: Span):
        # Pop the innermost span; copy first for the same reason as above.
        current_stack = list(self._span_stack_var.get())
        if current_stack:
            current_stack.pop()
            self._span_stack_var.set(current_stack)

    def get_current_span(self) -> Optional[Span]:
        """Return the innermost open span in this context, if any."""
        current_stack = self._span_stack_var.get()
        return current_stack[-1] if current_stack else None

    def _get_current_storage(self) -> SpanStorage:
        # Falls back to the in-memory default when no component is registered.
        return self.system_app.get_component(
            ComponentType.TRACER_SPAN_STORAGE, SpanStorage, self._default_storage
        )
class TracerManager:
    """Process-wide facade over the registered Tracer component.

    Safe to use before initialization: when no tracer is available, span
    operations degrade to no-ops (returning a detached dummy span).
    """

    def __init__(self) -> None:
        self._system_app: Optional[SystemApp] = None
        # Ambient TracerContext; its span_id is the fallback parent span id.
        self._trace_context_var: ContextVar[TracerContext] = ContextVar(
            "trace_context",
            default=TracerContext(span_id="default_trace_id:default_span_id"),
        )

    def initialize(
        self, system_app: SystemApp, trace_context_var: ContextVar[TracerContext] = None
    ) -> None:
        """Attach the manager to *system_app* and optionally a shared context var."""
        self._system_app = system_app
        if trace_context_var:
            self._trace_context_var = trace_context_var

    def _get_tracer(self) -> Tracer:
        # Returns None when uninitialized or no Tracer component is registered.
        if not self._system_app:
            return None
        return self._system_app.get_component(ComponentType.TRACER, Tracer, None)

    def start_span(
        self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
    ) -> Span:
        """Start a span; the parent defaults to the current span, then to the
        ambient trace context. Returns a dummy span when no tracer is registered.
        """
        tracer = self._get_tracer()
        if not tracer:
            return Span("empty_span", "empty_span")
        if not parent_span_id:
            current_span = self.get_current_span()
            if current_span:
                parent_span_id = current_span.span_id
            else:
                parent_span_id = self._trace_context_var.get().span_id
        return tracer.start_span(operation_name, parent_span_id, metadata)

    def end_span(self, span: Span, **kwargs):
        """End *span* via the tracer; a no-op without a tracer or span."""
        tracer = self._get_tracer()
        if not tracer or not span:
            return
        tracer.end_span(span, **kwargs)

    def get_current_span(self) -> Optional[Span]:
        """Return the tracer's current span, or None when untraced."""
        tracer = self._get_tracer()
        if not tracer:
            return None
        return tracer.get_current_span()
# Global tracer facade; importable immediately, becomes active once
# initialize_tracer() has registered a Tracer component.
root_tracer: TracerManager = TracerManager()
def trace(operation_name: str):
    """Decorator that runs the wrapped function inside a tracing span.

    Supports both async and sync functions: previously a sync function would
    receive an async wrapper that awaited a non-awaitable result; now the
    wrapper kind matches the wrapped function. Signature and return value are
    preserved via functools.wraps.
    """
    import inspect

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                with root_tracer.start_span(operation_name):
                    return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with root_tracer.start_span(operation_name):
                return func(*args, **kwargs)

        return sync_wrapper

    return decorator
def initialize_tracer(system_app: SystemApp, tracer_filename: str):
    """Wire tracing into *system_app*.

    Registers a file-backed span storage writing to *tracer_filename*, a
    DefaultTracer, initializes the global ``root_tracer``, and — when the app
    exposes a web application — installs the trace-id middleware.
    No-op when *system_app* is falsy.
    """
    if not system_app:
        return
    from pilot.utils.tracer.span_storage import FileSpanStorage

    trace_context_var = ContextVar(
        "trace_context",
        default=TracerContext(span_id="default_trace_id:default_span_id"),
    )
    tracer = DefaultTracer(system_app)

    system_app.register_instance(FileSpanStorage(tracer_filename))
    system_app.register_instance(tracer)
    root_tracer.initialize(system_app, trace_context_var)
    if system_app.app:
        from pilot.utils.tracer.tracer_middleware import TraceIDMiddleware

        system_app.app.add_middleware(
            TraceIDMiddleware, trace_context_var=trace_context_var, tracer=tracer
        )

View File

@ -0,0 +1,43 @@
import uuid
from contextvars import ContextVar
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp
from pilot.utils.tracer import TracerContext, Tracer
_DEFAULT_EXCLUDE_PATHS = ["/api/controller/heartbeat"]
class TraceIDMiddleware(BaseHTTPMiddleware):
    """Starlette middleware wrapping matching HTTP requests in a root span.

    Only paths under *include_prefix* are traced, minus *exclude_paths*
    (e.g. the controller heartbeat, which would flood the span log).
    """

    def __init__(
        self,
        app: ASGIApp,
        trace_context_var: ContextVar[TracerContext],
        tracer: Tracer,
        include_prefix: str = "/api",
        exclude_paths=_DEFAULT_EXCLUDE_PATHS,
    ):
        super().__init__(app)
        self.trace_context_var = trace_context_var
        self.tracer = tracer
        self.include_prefix = include_prefix
        self.exclude_paths = exclude_paths

    async def dispatch(self, request: Request, call_next):
        # Pass through excluded and non-API paths untouched.
        if request.url.path in self.exclude_paths or not request.url.path.startswith(
            self.include_prefix
        ):
            return await call_next(request)
        # Optional parent span id propagated by an upstream caller; may be
        # None, in which case the tracer starts a root span.
        span_id = request.headers.get("DBGPT_TRACER_SPAN_ID")
        # if not span_id:
        #     span_id = str(uuid.uuid4())
        # self.trace_context_var.set(TracerContext(span_id=span_id))
        with self.tracer.start_span(
            "DB-GPT-Web-Entry", span_id, metadata={"path": request.url.path}
        ) as _:
            response = await call_next(request)
        return response