refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

View File

@@ -0,0 +1,39 @@
from dbgpt.util.tracer.base import (
SpanType,
Span,
SpanTypeRunName,
Tracer,
SpanStorage,
SpanStorageType,
TracerContext,
)
from dbgpt.util.tracer.span_storage import (
MemorySpanStorage,
FileSpanStorage,
SpanStorageContainer,
)
from dbgpt.util.tracer.tracer_impl import (
root_tracer,
trace,
initialize_tracer,
DefaultTracer,
TracerManager,
)
# Public API of the dbgpt.util.tracer package, re-exported from the
# base, span_storage and tracer_impl submodules above.
__all__ = [
    "SpanType",
    "Span",
    "SpanTypeRunName",
    "Tracer",
    "SpanStorage",
    "SpanStorageType",
    "TracerContext",
    "MemorySpanStorage",
    "FileSpanStorage",
    "SpanStorageContainer",
    "root_tracer",
    "trace",
    "initialize_tracer",
    "DefaultTracer",
    "TracerManager",
]

189
dbgpt/util/tracer/base.py Normal file
View File

@@ -0,0 +1,189 @@
from __future__ import annotations
from typing import Dict, Callable, Optional, List
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import uuid
from datetime import datetime
from dbgpt.component import BaseComponent, SystemApp, ComponentType
class SpanType(str, Enum):
    """Coarse classification of a span; serialized via ``Span.to_dict``."""

    BASE = "base"
    RUN = "run"
    CHAT = "chat"
class SpanTypeRunName(str, Enum):
    """Names of the services that emit ``SpanType.RUN`` spans."""

    WEBSERVER = "Webserver"
    WORKER_MANAGER = "WorkerManager"
    MODEL_WORKER = "ModelWorker"
    EMBEDDING_MODEL = "EmbeddingModel"

    @staticmethod
    def values():
        """Return the string values of all members, in declaration order."""
        return [item.value for item in SpanTypeRunName]
class Span:
    """Represents a unit of work that is being traced.

    This can be any operation like a function call or a database query.
    A span may also be used as a context manager; ``__exit__`` calls
    :meth:`end` and does not suppress exceptions.
    """

    def __init__(
        self,
        trace_id: str,
        span_id: str,
        span_type: Optional[SpanType] = None,
        parent_span_id: Optional[str] = None,
        operation_name: Optional[str] = None,
        metadata: Optional[Dict] = None,
        end_caller: Optional[Callable[[Span], None]] = None,
    ):
        # Default to the generic BASE type when none is given.
        if not span_type:
            span_type = SpanType.BASE
        self.span_type = span_type
        # The unique identifier for the entire trace
        self.trace_id = trace_id
        # Unique identifier for this span within the trace
        self.span_id = span_id
        # Identifier of the parent span, if this is a child span
        self.parent_span_id = parent_span_id
        # Descriptive name for the operation being traced
        self.operation_name = operation_name
        # Timestamp when this span started
        self.start_time = datetime.now()
        # Timestamp when this span ended, initially None
        self.end_time = None
        # Additional metadata associated with the span
        self.metadata = metadata
        # Callbacks invoked, in registration order, when end() is called.
        self._end_callers = []
        if end_caller:
            self._end_callers.append(end_caller)

    def end(self, **kwargs):
        """Mark the end of this span by recording the current time.

        If ``metadata=...`` is passed it replaces the span's metadata.
        Every registered end-callback is then invoked with this span.
        """
        self.end_time = datetime.now()
        if "metadata" in kwargs:
            self.metadata = kwargs.get("metadata")
        for caller in self._end_callers:
            caller(self)

    def add_end_caller(self, end_caller: Callable[[Span], None]):
        """Register an additional callback to run when :meth:`end` is called."""
        if end_caller:
            self._end_callers.append(end_caller)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # End the span unconditionally; returning False propagates exceptions.
        self.end()
        return False

    def to_dict(self) -> Dict:
        """Serialize the span to a plain dict.

        Timestamps are formatted as ``YYYY-MM-DD HH:MM:SS.mmm``; the ``[:-3]``
        truncates strftime's microseconds to milliseconds.
        """
        return {
            "span_type": self.span_type.value,
            "trace_id": self.trace_id,
            "span_id": self.span_id,
            "parent_span_id": self.parent_span_id,
            "operation_name": self.operation_name,
            "start_time": None
            if not self.start_time
            else self.start_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
            "end_time": None
            if not self.end_time
            else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
            "metadata": self.metadata,
        }
class SpanStorageType(str, Enum):
    """When a tracer persists a span: on creation, on end, or both."""

    ON_CREATE = "on_create"
    ON_END = "on_end"
    ON_CREATE_END = "on_create_end"
class SpanStorage(BaseComponent, ABC):
    """Abstract base class for storing spans.

    This allows different storage mechanisms (e.g., in-memory, database) to be implemented.
    """

    # Component registry key used by SystemApp.
    name = ComponentType.TRACER_SPAN_STORAGE.value

    def init_app(self, system_app: SystemApp):
        """Initialize the storage with the given application context."""
        pass

    @abstractmethod
    def append_span(self, span: Span):
        """Store the given span. This needs to be implemented by subclasses."""

    def append_span_batch(self, spans: List[Span]):
        """Store the span batch.

        Default implementation appends one span at a time; subclasses may
        override with a more efficient bulk write.
        """
        for span in spans:
            self.append_span(span)
class Tracer(BaseComponent, ABC):
    """Abstract base class for tracing operations.

    Provides the core logic for starting, ending, and retrieving spans.
    """

    # Component registry key used by SystemApp.
    name = ComponentType.TRACER.value

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)
        self.system_app = system_app  # Application context

    def init_app(self, system_app: SystemApp):
        """Initialize the tracer with the given application context."""
        self.system_app = system_app

    @abstractmethod
    def append_span(self, span: Span):
        """Append the given span to storage. This needs to be implemented by subclasses."""

    @abstractmethod
    def start_span(
        self,
        operation_name: str,
        parent_span_id: Optional[str] = None,
        span_type: Optional[SpanType] = None,
        metadata: Optional[Dict] = None,
    ) -> Span:
        """Begin a new span for the given operation. If provided, the span will be
        a child of the span with the given parent_span_id.
        """

    @abstractmethod
    def end_span(self, span: Span, **kwargs):
        """End the given span.

        Extra keyword arguments (e.g. ``metadata``) are forwarded to
        implementations; see ``Span.end``.
        """

    @abstractmethod
    def get_current_span(self) -> Optional[Span]:
        """Retrieve the span that is currently being traced, if any."""

    @abstractmethod
    def _get_current_storage(self) -> SpanStorage:
        """Get the storage mechanism currently in use for storing spans.

        This needs to be implemented by subclasses.
        """

    def _new_uuid(self) -> str:
        """Generate a new unique identifier (UUID4 string)."""
        return str(uuid.uuid4())
@dataclass
class TracerContext:
    """Carrier for the id of the currently-active span.

    NOTE(review): presumably propagated per-task/per-request by the tracer
    implementation — confirm against ``tracer_impl``.
    """

    # Id of the active span; None when no span is active.
    span_id: Optional[str] = None

View File

@@ -0,0 +1,150 @@
import os
import json
import time
import datetime
import threading
import queue
import logging
from typing import Optional, List
from concurrent.futures import Executor, ThreadPoolExecutor
from dbgpt.component import SystemApp
from dbgpt.util.tracer.base import Span, SpanStorage
logger = logging.getLogger(__name__)
class MemorySpanStorage(SpanStorage):
    """Span storage that keeps every span in an in-memory list."""

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)
        # Inspected directly by tests; access is serialized by the lock below.
        self.spans = []
        self._lock = threading.Lock()

    def append_span(self, span: Span):
        """Record ``span`` in the shared list, thread-safely."""
        with self._lock:
            self.spans.append(span)
class SpanStorageContainer(SpanStorage):
    """Fan-out storage that buffers spans and flushes them to child storages.

    Spans appended via :meth:`append_span` are queued; a daemon thread drains
    the queue and dispatches batches to every registered storage through an
    executor, either when ``batch_size`` spans have accumulated or after
    ``flush_interval`` seconds, whichever comes first.
    """

    def __init__(
        self,
        system_app: SystemApp | None = None,
        batch_size=10,
        flush_interval=10,
        executor: Executor = None,
    ):
        super().__init__(system_app)
        if not executor:
            executor = ThreadPoolExecutor(thread_name_prefix="trace_storage_sync_")
        self.executor = executor
        # Child storages that receive each flushed batch.
        self.storages: List[SpanStorage] = []
        self.last_date = (
            datetime.datetime.now().date()
        )  # Store the current date for checking date changes
        self.queue = queue.Queue()
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.last_flush_time = time.time()
        # Signal queue used only to wake the flush thread early when a batch fills.
        self.flush_signal_queue = queue.Queue()
        self.flush_thread = threading.Thread(target=self._flush_to_storages, daemon=True)
        self.flush_thread.start()

    def append_storage(self, storage: SpanStorage):
        """Append storage to container

        Args:
            storage ([`SpanStorage`]): The storage to be appended to current container
        """
        self.storages.append(storage)

    def append_span(self, span: Span):
        """Queue ``span`` for asynchronous flushing; wakes the flush thread
        early once a full batch is buffered."""
        self.queue.put(span)
        if self.queue.qsize() >= self.batch_size:
            try:
                self.flush_signal_queue.put_nowait(True)
            except queue.Full:
                pass  # If the signal queue is full, it's okay. The flush thread will handle it.

    def _flush_to_storages(self):
        """Background loop: wait for a signal or timeout, drain the queue,
        and dispatch the batch to every child storage."""

        # Defined once (not per loop iteration); receives its arguments
        # explicitly, so there is no late-binding closure issue.
        def append_and_ignore_error(storage: SpanStorage, spans_to_write: List[Span]):
            # A failure in one storage must not affect the others or the loop.
            try:
                storage.append_span_batch(spans_to_write)
            except Exception as e:
                # logger.warn is a deprecated alias; use warning.
                logger.warning(
                    f"Append spans to storage {str(storage)} failed: {str(e)}, span_data: {spans_to_write}"
                )

        while True:
            interval = time.time() - self.last_flush_time
            if interval < self.flush_interval:
                try:
                    self.flush_signal_queue.get(
                        block=True, timeout=self.flush_interval - interval
                    )
                except Exception:
                    # Timeout: fall through and flush whatever is queued.
                    pass
            spans_to_write = []
            while not self.queue.empty():
                spans_to_write.append(self.queue.get())
            # Skip dispatch entirely for empty wake-ups (timer fired with
            # nothing queued) to avoid submitting no-op tasks.
            if spans_to_write:
                for s in self.storages:
                    self.executor.submit(append_and_ignore_error, s, spans_to_write)
            self.last_flush_time = time.time()
class FileSpanStorage(SpanStorage):
    """Span storage that appends spans as JSON lines to a file.

    The file rolls over daily: on the first write after a date change, the
    current file is renamed with a ``_YYYY-MM-DD`` suffix and a fresh file
    is started under the original name.
    """

    def __init__(self, filename: str):
        super().__init__()
        self.filename = filename
        # Split filename into prefix and suffix
        self.filename_prefix, self.filename_suffix = os.path.splitext(filename)
        if not self.filename_suffix:
            self.filename_suffix = ".log"
        self.last_date = (
            datetime.datetime.now().date()
        )  # Store the current date for checking date changes
        # NOTE(review): this queue is never used by the methods below —
        # looks like leftover state; confirm before removing.
        self.queue = queue.Queue()
        if not os.path.exists(filename):
            # Create the file (and parent dirs) if it does not exist yet.
            # Guard the makedirs call: for a bare filename in the current
            # directory, os.path.dirname() is "" and os.makedirs("") raises
            # FileNotFoundError.
            dir_path = os.path.dirname(filename)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
            with open(filename, "a"):
                pass

    def append_span(self, span: Span):
        self._write_to_file([span])

    def append_span_batch(self, spans: List[Span]):
        self._write_to_file(spans)

    def _get_dated_filename(self, date: datetime.date) -> str:
        """Return the filename based on a specific date."""
        date_str = date.strftime("%Y-%m-%d")
        return f"{self.filename_prefix}_{date_str}{self.filename_suffix}"

    def _roll_over_if_needed(self):
        """Checks if a day has changed since the last write, and if so, renames the current file."""
        current_date = datetime.datetime.now().date()
        if current_date != self.last_date:
            if os.path.exists(self.filename):
                os.rename(self.filename, self._get_dated_filename(self.last_date))
            self.last_date = current_date

    def _write_to_file(self, spans: List[Span]):
        """Append each span as one JSON line; a failed serialization only
        skips that span."""
        self._roll_over_if_needed()
        with open(self.filename, "a") as file:
            for span in spans:
                span_data = span.to_dict()
                try:
                    file.write(json.dumps(span_data, ensure_ascii=False) + "\n")
                except Exception as e:
                    logger.warning(
                        f"Write span to file failed: {str(e)}, span_data: {span_data}"
                    )

View File

View File

@@ -0,0 +1,131 @@
from typing import Dict
from dbgpt.component import SystemApp
from dbgpt.util.tracer import Span, SpanType, SpanStorage, Tracer
# Mock implementations
class MockSpanStorage(SpanStorage):
    """Trivial storage stub that records appended spans in a plain list."""

    def __init__(self):
        self.spans = []

    def append_span(self, span: Span):
        self.spans.append(span)
class MockTracer(Tracer):
    """Minimal in-memory Tracer used by the tests in this module."""

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)
        self.current_span = None
        self.storage = MockSpanStorage()

    def append_span(self, span: Span):
        self.storage.append_span(span)

    def start_span(
        self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
    ) -> Span:
        # A child reuses the trace id encoded before ":" in its parent's id;
        # a root span gets a brand-new trace id.
        if parent_span_id is None:
            trace_id = self._new_uuid()
        else:
            trace_id = parent_span_id.split(":")[0]
        new_span = Span(
            trace_id,
            f"{trace_id}:{self._new_uuid()}",
            SpanType.BASE,
            parent_span_id,
            operation_name,
            metadata,
        )
        self.current_span = new_span
        return new_span

    def end_span(self, span: Span):
        span.end()
        self.append_span(span)

    def get_current_span(self) -> Span:
        return self.current_span

    def _get_current_storage(self) -> SpanStorage:
        return self.storage
# Tests
def test_span_creation():
    """A span must retain every constructor argument verbatim."""
    created = Span(
        "trace_id",
        "span_id",
        SpanType.BASE,
        "parent_span_id",
        "operation",
        {"key": "value"},
    )
    expected_attrs = {
        "trace_id": "trace_id",
        "span_id": "span_id",
        "parent_span_id": "parent_span_id",
        "operation_name": "operation",
        "metadata": {"key": "value"},
    }
    for attr, value in expected_attrs.items():
        assert getattr(created, attr) == value
def test_span_end():
    """end() must stamp an end time; before that it is None."""
    pending = Span("trace_id", "span_id")
    assert pending.end_time is None
    pending.end()
    assert pending.end_time is not None
def test_mock_tracer_start_span():
    """start_span records the operation name and becomes the current span."""
    mock_tracer = MockTracer()
    started = mock_tracer.start_span("operation")
    assert started.operation_name == "operation"
    assert mock_tracer.get_current_span() is started
def test_mock_tracer_end_span():
    """Ending a span must persist it into the tracer's storage."""
    mock_tracer = MockTracer()
    finished = mock_tracer.start_span("operation")
    mock_tracer.end_span(finished)
    assert finished in mock_tracer._get_current_storage().spans
def test_mock_tracer_append_span():
    """append_span stores a span directly, without starting/ending it."""
    mock_tracer = MockTracer()
    standalone = Span("trace_id", "span_id")
    mock_tracer.append_span(standalone)
    assert standalone in mock_tracer._get_current_storage().spans
def test_parent_child_span_relation():
    """Children must link to the parent's span id and share its trace id."""
    mock_tracer = MockTracer()
    parent = mock_tracer.start_span("parent_operation")
    child = mock_tracer.start_span(
        "child_operation", parent_span_id=parent.span_id
    )
    assert child.parent_span_id == parent.span_id
    # Children are expected to share the parent's trace ID.
    assert child.trace_id == parent.trace_id
    # Ending each span must persist it into the storage.
    for finished in (child, parent):
        mock_tracer.end_span(finished)
        assert finished in mock_tracer._get_current_storage().spans
# This test checks if unique UUIDs are being generated.
# Note: This is a simple test and doesn't guarantee uniqueness for large numbers of UUIDs.
def test_new_uuid_unique():
    mock_tracer = MockTracer()
    generated = [mock_tracer._new_uuid() for _ in range(1000)]
    assert len(set(generated)) == 1000

View File

@@ -0,0 +1,174 @@
import os
import pytest
import asyncio
import json
import tempfile
import time
from unittest.mock import patch
from datetime import datetime, timedelta
from dbgpt.util.tracer import (
SpanStorage,
FileSpanStorage,
Span,
SpanType,
SpanStorageContainer,
)
@pytest.fixture
def storage(request):
    """Yield a FileSpanStorage over a temp file.

    Indirect-parametrize with ``{"file_does_not_exist": True}`` to get a
    path inside a temp directory that does not exist yet.
    """
    file_missing = False
    if request and hasattr(request, "param"):
        file_missing = request.param.get("file_does_not_exist", False)
    if file_missing:
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield FileSpanStorage(os.path.join(tmp_dir, "non_existent_file.jsonl"))
    else:
        with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
            yield FileSpanStorage(tmp_file.name)
@pytest.fixture
def storage_container(request):
    """Yield a SpanStorageContainer, configurable via indirect params."""
    batch_size, flush_interval = 10, 10
    if request and hasattr(request, "param"):
        batch_size = request.param.get("batch_size", 10)
        flush_interval = request.param.get("flush_interval", 10)
    yield SpanStorageContainer(batch_size=batch_size, flush_interval=flush_interval)
def read_spans_from_file(filename):
    """Parse one JSON span per line from ``filename`` and return them all."""
    with open(filename, "r") as span_file:
        return [json.loads(line) for line in span_file]
def test_write_span(storage: SpanStorage):
    """A single appended span must appear in the backing file."""
    storage.append_span(Span("1", "a", SpanType.BASE, "b", "op1"))
    time.sleep(0.1)
    written = read_spans_from_file(storage.filename)
    assert len(written) == 1
    assert written[0]["trace_id"] == "1"
def test_incremental_write(storage: SpanStorage):
    """Consecutive appends must accumulate lines in the same file."""
    for args in (
        ("1", "a", SpanType.BASE, "b", "op1"),
        ("2", "c", SpanType.BASE, "d", "op2"),
    ):
        storage.append_span(Span(*args))
    time.sleep(0.1)
    assert len(read_spans_from_file(storage.filename)) == 2
def test_sync_and_async_append(storage: SpanStorage):
    """Appending from sync and async contexts must both reach the file."""
    shared_span = Span("1", "a", SpanType.BASE, "b", "op1")
    storage.append_span(shared_span)

    async def _append_again():
        storage.append_span(shared_span)

    asyncio.run(_append_again())
    time.sleep(0.1)
    assert len(read_spans_from_file(storage.filename)) == 2
@pytest.mark.parametrize("storage", [{"file_does_not_exist": True}], indirect=True)
def test_non_existent_file(storage: SpanStorage):
    """Writing must create the missing file on demand, then append to it."""
    first = Span("1", "a", SpanType.BASE, "b", "op1")
    second = Span("2", "c", SpanType.BASE, "d", "op2")
    storage.append_span(first)
    time.sleep(0.1)
    assert len(read_spans_from_file(storage.filename)) == 1
    storage.append_span(second)
    time.sleep(0.1)
    contents = read_spans_from_file(storage.filename)
    assert len(contents) == 2
    assert [sp["trace_id"] for sp in contents] == ["1", "2"]
@pytest.mark.parametrize("storage", [{"file_does_not_exist": True}], indirect=True)
def test_log_rollover(storage: SpanStorage):
    """A date change between writes must rename the old file with a date suffix.

    NOTE(review): patching ``"datetime.datetime"`` patches the attribute on the
    stdlib module object itself; verify this actually affects the clock that
    ``FileSpanStorage._roll_over_if_needed`` reads.
    """
    # mock start date
    mock_start_date = datetime(2023, 10, 18, 23, 59)
    with patch("datetime.datetime") as mock_datetime:
        mock_datetime.now.return_value = mock_start_date
        span1 = Span("1", "a", SpanType.BASE, "b", "op1")
        storage.append_span(span1)
        time.sleep(0.1)
        # mock new day
        mock_datetime.now.return_value = mock_start_date + timedelta(minutes=1)
        span2 = Span("2", "c", SpanType.BASE, "d", "op2")
        storage.append_span(span2)
        time.sleep(0.1)
        # origin filename need exists
        assert os.path.exists(storage.filename)
        # get roll over filename
        dated_filename = os.path.join(
            os.path.dirname(storage.filename),
            f"{os.path.basename(storage.filename).split('.')[0]}_2023-10-18.jsonl",
        )
        assert os.path.exists(dated_filename)
        # check origin filename just include the second span
        spans_in_original_file = read_spans_from_file(storage.filename)
        assert len(spans_in_original_file) == 1
        assert spans_in_original_file[0]["trace_id"] == "2"
        # check the roll over filename just include the first span
        spans_in_dated_file = read_spans_from_file(dated_filename)
        assert len(spans_in_dated_file) == 1
        assert spans_in_dated_file[0]["trace_id"] == "1"
@pytest.mark.asyncio
@pytest.mark.parametrize("storage_container", [{"batch_size": 5}], indirect=True)
async def test_container_flush_policy(
    storage_container: SpanStorageContainer, storage: FileSpanStorage
):
    """Spans must stay buffered until ``batch_size`` is reached, then flush."""
    storage_container.append_storage(storage)
    span = Span("1", "a", SpanType.BASE, "b", "op1")
    filename = storage.filename
    # One fewer than a full batch: nothing should have reached the file yet.
    for _ in range(storage_container.batch_size - 1):
        storage_container.append_span(span)
    spans_in_file = read_spans_from_file(filename)
    assert len(spans_in_file) == 0
    # Trigger batch write
    storage_container.append_span(span)
    await asyncio.sleep(0.1)
    spans_in_file = read_spans_from_file(filename)
    assert len(spans_in_file) == storage_container.batch_size

View File

@@ -0,0 +1,103 @@
import pytest
from dbgpt.util.tracer import (
Span,
SpanStorageType,
SpanStorage,
DefaultTracer,
TracerManager,
Tracer,
MemorySpanStorage,
)
from dbgpt.component import SystemApp
@pytest.fixture
def system_app():
    """A fresh SystemApp for each test."""
    app = SystemApp()
    return app
@pytest.fixture
def storage(system_app: SystemApp):
    """Register an in-memory span storage on the app and return it."""
    memory_storage = MemorySpanStorage(system_app)
    system_app.register_instance(memory_storage)
    return memory_storage
@pytest.fixture
def tracer(request, system_app: SystemApp):
    """Build a DefaultTracer, honoring an optional ``span_storage_type`` param."""
    if request and hasattr(request, "param"):
        storage_type = request.param.get(
            "span_storage_type", SpanStorageType.ON_CREATE_END
        )
        return DefaultTracer(system_app, span_storage_type=storage_type)
    return DefaultTracer(system_app)
@pytest.fixture
def tracer_manager(system_app: SystemApp, tracer: Tracer):
    """Register the tracer and return an initialized TracerManager."""
    system_app.register_instance(tracer)
    tm = TracerManager()
    tm.initialize(system_app)
    return tm
def test_start_and_end_span(tracer: Tracer):
    """Starting and ending a span must stamp it and persist it to storage."""
    new_span = tracer.start_span("operation")
    assert isinstance(new_span, Span)
    assert new_span.operation_name == "operation"
    tracer.end_span(new_span)
    assert new_span.end_time is not None
    assert tracer._get_current_storage().spans[0] is new_span
def test_start_and_end_span_with_tracer_manager(tracer_manager: TracerManager):
    """The manager facade must start and end spans like a tracer."""
    managed = tracer_manager.start_span("operation")
    assert isinstance(managed, Span)
    assert managed.operation_name == "operation"
    tracer_manager.end_span(managed)
    assert managed.end_time is not None
def test_parent_child_span_relation(tracer: Tracer):
    """A child span links to its parent and shares the parent's trace id."""
    parent = tracer.start_span("parent_operation")
    child = tracer.start_span("child_operation", parent_span_id=parent.span_id)
    assert child.parent_span_id == parent.span_id
    assert child.trace_id == parent.trace_id
    tracer.end_span(child)
    tracer.end_span(parent)
    stored = tracer._get_current_storage().spans
    assert parent in stored
    assert child in stored
@pytest.mark.parametrize(
    "tracer, expected_count, after_create_inc_count",
    [
        ({"span_storage_type": SpanStorageType.ON_CREATE}, 1, 1),
        ({"span_storage_type": SpanStorageType.ON_END}, 1, 0),
        ({"span_storage_type": SpanStorageType.ON_CREATE_END}, 2, 1),
    ],
    indirect=["tracer"],
)
def test_tracer_span_storage_type_and_with(
    tracer: Tracer,
    expected_count: int,
    after_create_inc_count: int,
    storage: SpanStorage,
):
    """Stored-record counts per storage policy.

    ``expected_count`` is how many records one full start+end cycle produces;
    ``after_create_inc_count`` is how many appear immediately on create.
    """
    span = tracer.start_span("new_span")
    span.end()
    assert len(storage.spans) == expected_count
    with tracer.start_span("with_span") as ws:
        # Entering the ``with`` starts a second span.
        assert len(storage.spans) == expected_count + after_create_inc_count
    # After __exit__ ends the second span, a second full cycle has completed,
    # so the count is exactly double.
    assert len(storage.spans) == expected_count + expected_count

View File

@@ -0,0 +1,597 @@
import os
import click
import logging
import glob
import json
from datetime import datetime
from typing import Iterable, Dict, Callable
from dbgpt.configs.model_config import LOGDIR
from dbgpt.util.tracer import SpanType, SpanTypeRunName
logger = logging.getLogger("dbgpt_cli")
_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl")
# Root of the `trace` CLI subcommands (list / tree / chat) defined below.
@click.group("trace")
def trace_cli_group():
    """Analyze and visualize trace spans."""
    pass
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the trace ID to list",
)
@click.option(
    "--span_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Span ID to list.",
)
@click.option(
    "--span_type",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Span Type to list.",
)
@click.option(
    "--parent_span_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Parent Span ID to list.",
)
@click.option(
    "--search",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.",
)
@click.option(
    "-l",
    "--limit",
    type=int,
    default=20,
    help="Limit the number of recent span displayed.",
)
@click.option(
    "--start_time",
    type=str,
    help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"',
)
@click.option(
    "--end_time", type=str, help='Filter by end time. Format: "YYYY-MM-DD HH:MM:SS.mmm"'
)
@click.option(
    "--desc",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Whether to use reverse sorting. By default, sorting is based on start time.",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
# Note: the function is deliberately named `list` (it becomes the CLI command
# name); the builtin is shadowed only inside this module scope.
def list(
    trace_id: str,
    span_id: str,
    span_type: str,
    parent_span_id: str,
    search: str,
    limit: int,
    start_time: str,
    end_time: str,
    desc: bool,
    output: str,
    files=None,
):
    """List your trace spans

    Spans are filtered lazily (chained ``filter`` generators), then sorted by
    start time and truncated to ``limit`` before rendering with PrettyTable.
    """
    from prettytable import PrettyTable

    # If no files are explicitly specified, use the default pattern to get them
    spans = read_spans_from_files(files)
    if trace_id:
        spans = filter(lambda s: s["trace_id"] == trace_id, spans)
    if span_id:
        spans = filter(lambda s: s["span_id"] == span_id, spans)
    if span_type:
        spans = filter(lambda s: s["span_type"] == span_type, spans)
    if parent_span_id:
        spans = filter(lambda s: s["parent_span_id"] == parent_span_id, spans)
    # Filter spans based on the start and end times
    if start_time:
        start_dt = _parse_datetime(start_time)
        spans = filter(
            lambda span: _parse_datetime(span["start_time"]) >= start_dt, spans
        )
    if end_time:
        end_dt = _parse_datetime(end_time)
        # Both bounds compare against each span's start_time.
        spans = filter(
            lambda span: _parse_datetime(span["start_time"]) <= end_dt, spans
        )
    if search:
        spans = filter(_new_search_span_func(search), spans)
    # Sort spans based on the start time
    spans = sorted(
        spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc
    )[:limit]
    table = PrettyTable(
        ["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
    )
    for sp in spans:
        conv_uid = None
        # Fixed: was `if "metadata" in sp and sp:` — the trailing `and sp`
        # was always true (sp is a non-empty dict here) and looked like a
        # typo for sp["metadata"]; the isinstance check below already
        # handles a None/non-dict metadata value.
        if "metadata" in sp:
            metadata = sp["metadata"]
            if isinstance(metadata, dict):
                conv_uid = metadata.get("conv_uid")
        table.add_row(
            [
                sp.get("trace_id"),
                sp.get("span_id"),
                # sp.get("parent_span_id"),
                sp.get("operation_name"),
                conv_uid,
            ]
        )
    out_kwargs = {"ensure_ascii": False} if output == "json" else {}
    print(table.get_formatted_string(out_format=output, **out_kwargs))
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=True,
    type=str,
    help="Specify the trace ID to list",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def tree(trace_id: str, files):
    """Display trace links as a tree"""
    hierarchy = _view_trace_hierarchy(trace_id, files)
    if not hierarchy:
        # No spans with this trace id were found in the given files.
        _print_empty_message(files)
        return
    _print_trace_hierarchy(hierarchy)
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    help="Specify the trace ID to analyze. If None, show latest conversation details",
)
@click.option(
    "--tree",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Display trace spans as a tree",
)
@click.option(
    "--hide_conv",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Hide your conversation details",
)
@click.option(
    "--hide_run_params",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Hide run params",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True))
def chat(
    trace_id: str,
    tree: bool,
    hide_conv: bool,
    hide_run_params: bool,
    output: str,
    files,
):
    """Show conversation details

    Scans the span files for (a) one RUN span per service, used to print the
    service/run parameters, and (b) the CHAT span that identifies the target
    conversation (the latest one when --trace_id is omitted), whose span tree
    is then rendered row by row.
    """
    from prettytable import PrettyTable

    spans = read_spans_from_files(files)
    # Sort by start time
    spans = sorted(
        spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=True
    )
    # Materialize the generator so it can be iterated more than once.
    spans = [sp for sp in spans]
    if not spans:
        _print_empty_message(files)
        return
    service_spans = {}
    service_names = set(SpanTypeRunName.values())
    found_trace_id = None
    for sp in spans:
        span_type = sp["span_type"]
        metadata = sp.get("metadata")
        if span_type == SpanType.RUN:
            # NOTE(review): assumes RUN spans always carry a dict metadata
            # with "run_service"; a missing key would raise here — confirm.
            service_name = metadata["run_service"]
            service_spans[service_name] = sp.copy()
            if set(service_spans.keys()) == service_names and found_trace_id:
                break
        elif span_type == SpanType.CHAT and not found_trace_id:
            if not trace_id:
                # No explicit id: latest CHAT span wins (spans sorted desc).
                found_trace_id = sp["trace_id"]
            if trace_id and trace_id == sp["trace_id"]:
                found_trace_id = trace_id
    service_tables = {}
    system_infos_table = {}
    out_kwargs = {"ensure_ascii": False} if output == "json" else {}
    # Build one config table per service from its RUN span metadata.
    for service_name, sp in service_spans.items():
        metadata = sp["metadata"]
        table = PrettyTable(["Config Key", "Config Value"], title=service_name)
        for k, v in metadata["params"].items():
            table.add_row([k, v])
        service_tables[service_name] = table
        sys_infos = metadata.get("sys_infos")
        if sys_infos and isinstance(sys_infos, dict):
            sys_table = PrettyTable(
                ["System Config Key", "System Config Value"],
                title=f"{service_name} System information",
            )
            for k, v in sys_infos.items():
                sys_table.add_row([k, v])
            system_infos_table[service_name] = sys_table
    if not hide_run_params:
        merged_table1 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.WEBSERVER.value),
                service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value),
            ]
        )
        merged_table2 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
                service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
            ]
        )
        sys_table = system_infos_table.get(SpanTypeRunName.WORKER_MANAGER.value)
        # NOTE(review): this loop overwrites sys_table with the FIRST entry of
        # system_infos_table, discarding the WORKER_MANAGER lookup above —
        # looks intentional-but-odd; confirm which table should win.
        if system_infos_table:
            for k, v in system_infos_table.items():
                sys_table = v
                break
        if output == "text":
            print(merged_table1)
            print(merged_table2)
        else:
            # Non-text formats cannot render the horizontally merged tables,
            # so print each service table individually.
            for service_name, table in service_tables.items():
                print(table.get_formatted_string(out_format=output, **out_kwargs))
        if sys_table:
            print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
    if not found_trace_id:
        # Fixed message grammar ("Can't found" -> "Can't find").
        print(f"Can't find conversation with trace_id: {trace_id}")
        return
    trace_id = found_trace_id
    trace_spans = [span for span in spans if span["trace_id"] == trace_id]
    # Restore chronological (ascending) order for hierarchy building.
    trace_spans = [s for s in reversed(trace_spans)]
    hierarchy = _build_trace_hierarchy(trace_spans)
    if tree:
        print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
        _print_trace_hierarchy(hierarchy)
    if hide_conv:
        return
    trace_spans = _get_ordered_trace_from(hierarchy)
    # Fixed duplicated column header ("Value Value" -> "Value").
    table = PrettyTable(["Key", "Value"], title="Chat Trace Details")
    # Only wrap long values for terminal output.
    split_long_text = output == "text"
    for sp in trace_spans:
        op = sp["operation_name"]
        metadata = sp.get("metadata")
        # "Start" records (end_time is None) carry the request parameters;
        # "end" records carry outputs/errors.
        if op == "get_chat_instance" and not sp["end_time"]:
            table.add_row(["trace_id", trace_id])
            table.add_row(["span_id", sp["span_id"]])
            table.add_row(["conv_uid", metadata.get("conv_uid")])
            table.add_row(["user_input", metadata.get("user_input")])
            table.add_row(["chat_mode", metadata.get("chat_mode")])
            table.add_row(["select_param", metadata.get("select_param")])
            table.add_row(["model_name", metadata.get("model_name")])
        if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
            if not sp["end_time"]:
                table.add_row(["temperature", metadata.get("temperature")])
                table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
                table.add_row(["echo", metadata.get("echo")])
            elif "error" in metadata:
                table.add_row(["BaseChat Error", metadata.get("error")])
        if op == "BaseChat.do_action" and not sp["end_time"]:
            if "model_output" in metadata:
                table.add_row(
                    [
                        "BaseChat model_output",
                        split_string_by_terminal_width(
                            metadata.get("model_output").get("text"),
                            split=split_long_text,
                        ),
                    ]
                )
            if "ai_response_text" in metadata:
                table.add_row(
                    [
                        "BaseChat ai_response_text",
                        split_string_by_terminal_width(
                            metadata.get("ai_response_text"), split=split_long_text
                        ),
                    ]
                )
            if "prompt_define_response" in metadata:
                prompt_define_response = metadata.get("prompt_define_response") or ""
                # Structured responses are rendered as JSON text.
                if isinstance(prompt_define_response, dict) or isinstance(
                    prompt_define_response, type([])
                ):
                    prompt_define_response = json.dumps(
                        prompt_define_response, ensure_ascii=False
                    )
                table.add_row(
                    [
                        "BaseChat prompt_define_response",
                        split_string_by_terminal_width(
                            prompt_define_response,
                            split=split_long_text,
                        ),
                    ]
                )
        if op == "DefaultModelWorker_call.generate_stream_func":
            if not sp["end_time"]:
                table.add_row(["llm_adapter", metadata.get("llm_adapter")])
                table.add_row(
                    [
                        "User prompt",
                        split_string_by_terminal_width(
                            metadata.get("prompt"), split=split_long_text
                        ),
                    ]
                )
            else:
                table.add_row(
                    [
                        "Model output",
                        split_string_by_terminal_width(metadata.get("output")),
                    ]
                )
        if (
            op
            in [
                "DefaultModelWorker.async_generate_stream",
                "DefaultModelWorker.generate_stream",
            ]
            and metadata
            and "error" in metadata
        ):
            table.add_row(["Model Error", metadata.get("error")])
    print(table.get_formatted_string(out_format=output, **out_kwargs))
def read_spans_from_files(files=None) -> Iterable[Dict]:
    """Lazily yield span dicts parsed from every matching trace file.

    Each element of ``files`` may be a glob pattern; when ``files`` is
    empty or None, the default pattern under the log directory is used.
    """
    patterns = files if files else [_DEFAULT_FILE_PATTERN]
    for pattern in patterns:
        for path in glob.glob(pattern):
            with open(path, "r") as span_file:
                for line in span_file:
                    yield json.loads(line)
def _print_empty_message(files=None):
if not files:
files = [_DEFAULT_FILE_PATTERN]
file_names = ",".join(files)
print(f"No trace span records found in your tracer files: {file_names}")
def _new_search_span_func(search: str):
def func(span: Dict) -> bool:
items = [span["trace_id"], span["span_id"], span["parent_span_id"]]
if "operation_name" in span:
items.append(span["operation_name"])
if "metadata" in span:
metadata = span["metadata"]
if isinstance(metadata, dict):
for k, v in metadata.items():
items.append(k)
items.append(v)
return any(search in str(item) for item in items if item)
return func
def _parse_datetime(dt_str):
"""Parse a datetime string to a datetime object."""
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f")
def _build_trace_hierarchy(spans, parent_span_id=None, indent=0):
# Current spans
current_level_spans = [
span
for span in spans
if span["parent_span_id"] == parent_span_id and span["end_time"] is None
]
hierarchy = []
for start_span in current_level_spans:
# Find end span
end_span = next(
(
span
for span in spans
if span["span_id"] == start_span["span_id"]
and span["end_time"] is not None
),
None,
)
entry = {
"operation_name": start_span["operation_name"],
"parent_span_id": start_span["parent_span_id"],
"span_id": start_span["span_id"],
"start_time": start_span["start_time"],
"end_time": start_span["end_time"],
"metadata": start_span["metadata"],
"children": _build_trace_hierarchy(
spans, start_span["span_id"], indent + 1
),
}
hierarchy.append(entry)
# Append end span
if end_span:
entry_end = {
"operation_name": end_span["operation_name"],
"parent_span_id": end_span["parent_span_id"],
"span_id": end_span["span_id"],
"start_time": end_span["start_time"],
"end_time": end_span["end_time"],
"metadata": end_span["metadata"],
"children": [],
}
hierarchy.append(entry_end)
return hierarchy
def _view_trace_hierarchy(trace_id, files=None):
    """Load spans from the tracer files and build the call tree for *trace_id*.

    Returns ``None`` when no span of that trace is found.
    """
    matched = [
        span for span in read_spans_from_files(files) if span["trace_id"] == trace_id
    ]
    if not matched:
        return None
    return _build_trace_hierarchy(matched)
def _print_trace_hierarchy(hierarchy, indent=0):
"""Print link hierarchy"""
for entry in hierarchy:
print(
" " * indent
+ f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})"
)
_print_trace_hierarchy(entry["children"], indent + 1)
def _get_ordered_trace_from(hierarchy):
traces = []
def func(items):
for item in items:
traces.append(item)
func(item["children"])
func(hierarchy)
return traces
def _print(service_spans: Dict):
    # NOTE(review): the loop body is a bare ``pass``, so this function is
    # currently a no-op — presumably a placeholder for printing spans grouped
    # per service; confirm whether the implementation was lost or deferred.
    # NOTE(review): the pairs mix ``.name`` (a string like "WEBSERVER") with
    # raw enum members — confirm whether all four were meant to use ``.name``.
    for names in [
        [SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL],
        [SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER],
    ]:
        pass
def merge_tables_horizontally(tables):
    """Combine several prettytable tables side by side into one wide table.

    Rows are aligned by index; shorter tables are padded with empty cells so
    every merged row has full width.  Column names are suffixed with the
    source table's title when one is set.  Returns ``None`` when no non-empty
    table is given.
    """
    from prettytable import PrettyTable

    if not tables:
        return None
    non_empty = [table for table in tables if table]
    if not non_empty:
        return None

    merged = PrettyTable()
    headers = []
    for table in non_empty:
        title = table.title
        for name in table.field_names:
            headers.append(f"{name} ({title})" if title else f"{name}")
    merged.field_names = headers

    row_count = max(len(table._rows) for table in non_empty)
    for index in range(row_count):
        combined = []
        for table in non_empty:
            rows = table._rows
            if index < len(rows):
                combined.extend(rows[index])
            else:
                # Pad shorter tables with empty cells.
                combined.extend([""] * len(table.field_names))
        merged.add_row(combined)
    return merged
def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"):
    """Split a string into fixed-width chunks joined by a separator.

    Parameters:
    - s: the input string
    - split: when False, return ``s`` unchanged
    - max_len: maximum chunk length; defaults to 80% of the current terminal
      width, or 100 when the terminal size cannot be determined
    - sp: separator inserted between consecutive chunks
    """
    if not split:
        return s
    if not max_len:
        try:
            max_len = int(os.get_terminal_size().columns * 0.8)
        except OSError:
            # Terminal size unavailable (e.g. output is piped); fall back to
            # a fixed width of 100 characters.
            max_len = 100
    return sp.join([s[i : i + max_len] for i in range(0, len(s), max_len)])

View File

@@ -0,0 +1,235 @@
from typing import Dict, Optional
from contextvars import ContextVar
from functools import wraps
import asyncio
import inspect
import logging
from dbgpt.component import SystemApp, ComponentType
from dbgpt.util.tracer.base import (
SpanType,
Span,
Tracer,
SpanStorage,
SpanStorageType,
TracerContext,
)
from dbgpt.util.tracer.span_storage import MemorySpanStorage
from dbgpt.util.module_utils import import_from_checked_string
logger = logging.getLogger(__name__)
class DefaultTracer(Tracer):
    """Default :class:`Tracer` implementation.

    Keeps the stack of active spans in a ``ContextVar`` so concurrent asyncio
    tasks each see their own span stack, and writes spans to a
    :class:`SpanStorage` according to ``span_storage_type``.
    """

    def __init__(
        self,
        system_app: SystemApp | None = None,
        default_storage: SpanStorage = None,
        span_storage_type: SpanStorageType = SpanStorageType.ON_CREATE_END,
    ):
        super().__init__(system_app)
        # The ContextVar default is a single shared list; we therefore never
        # mutate the retrieved value in place (see copy-on-write below),
        # otherwise spans would leak between unrelated contexts.
        self._span_stack_var = ContextVar("span_stack", default=[])
        if not default_storage:
            default_storage = MemorySpanStorage(system_app)
        self._default_storage = default_storage
        self._span_storage_type = span_storage_type

    def append_span(self, span: Span):
        """Persist *span* to the currently registered span storage."""
        self._get_current_storage().append_span(span)

    def start_span(
        self,
        operation_name: str,
        parent_span_id: str = None,
        span_type: SpanType = None,
        metadata: Dict = None,
    ) -> Span:
        """Create (and, depending on the storage type, store) a new span.

        The trace id is inherited from ``parent_span_id`` (its part before the
        first ``:``) or freshly generated for a root span.
        """
        trace_id = (
            self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
        )
        span_id = f"{trace_id}:{self._new_uuid()}"
        span = Span(
            trace_id,
            span_id,
            span_type,
            parent_span_id,
            operation_name,
            metadata=metadata,
        )

        if self._span_storage_type in [
            SpanStorageType.ON_END,
            SpanStorageType.ON_CREATE_END,
        ]:
            span.add_end_caller(self.append_span)

        if self._span_storage_type in [
            SpanStorageType.ON_CREATE,
            SpanStorageType.ON_CREATE_END,
        ]:
            self.append_span(span)

        # Copy-on-write: build a new list instead of mutating the current one,
        # so the ContextVar's shared default list is never modified in place
        # (the original code appended to it, leaking spans across contexts).
        current_stack = list(self._span_stack_var.get())
        current_stack.append(span)
        self._span_stack_var.set(current_stack)

        span.add_end_caller(self._remove_from_stack_top)
        return span

    def end_span(self, span: Span, **kwargs):
        """End *span*; storage writes and stack cleanup run via end callers."""
        span.end(**kwargs)

    def _remove_from_stack_top(self, span: Span):
        # End-caller that pops the finished span off this context's stack
        # (copy-on-write, same reasoning as in start_span).
        current_stack = list(self._span_stack_var.get())
        if current_stack:
            current_stack.pop()
        self._span_stack_var.set(current_stack)

    def get_current_span(self) -> Optional[Span]:
        """Return the innermost active span in this context, if any."""
        current_stack = self._span_stack_var.get()
        return current_stack[-1] if current_stack else None

    def _get_current_storage(self) -> SpanStorage:
        # Resolve the registered storage component, falling back to the
        # storage passed (or created) at construction time.
        return self.system_app.get_component(
            ComponentType.TRACER_SPAN_STORAGE, SpanStorage, self._default_storage
        )
class TracerManager:
    """Facade over the currently registered tracer.

    All methods are safe to call before initialization: they degrade to
    no-ops or empty results when no tracer component is available.
    """

    def __init__(self) -> None:
        self._system_app: Optional[SystemApp] = None
        self._trace_context_var: ContextVar[TracerContext] = ContextVar(
            "trace_context",
            default=TracerContext(),
        )

    def initialize(
        self, system_app: SystemApp, trace_context_var: ContextVar[TracerContext] = None
    ) -> None:
        """Bind this manager to *system_app* and, optionally, a context var."""
        self._system_app = system_app
        if trace_context_var:
            self._trace_context_var = trace_context_var

    def _get_tracer(self) -> Tracer:
        # No system app registered yet means no tracer is available.
        if not self._system_app:
            return None
        return self._system_app.get_component(ComponentType.TRACER, Tracer, None)

    def start_span(
        self,
        operation_name: str,
        parent_span_id: str = None,
        span_type: SpanType = None,
        metadata: Dict = None,
    ) -> Span:
        """Start a new span with operation_name
        This method must not throw an exception under any case and try not to block as much as possible
        """
        tracer = self._get_tracer()
        if not tracer:
            # Placeholder span so callers always get a usable object.
            return Span("empty_span", "empty_span")
        parent = parent_span_id or self.get_current_span_id()
        return tracer.start_span(
            operation_name, parent, span_type=span_type, metadata=metadata
        )

    def end_span(self, span: Span, **kwargs):
        """End *span* through the active tracer; no-op when either is missing."""
        tracer = self._get_tracer()
        if tracer and span:
            tracer.end_span(span, **kwargs)

    def get_current_span(self) -> Optional[Span]:
        """Return the innermost active span, or ``None`` without a tracer."""
        tracer = self._get_tracer()
        return tracer.get_current_span() if tracer else None

    def get_current_span_id(self) -> Optional[str]:
        """Return the current span id, falling back to the trace context."""
        active = self.get_current_span()
        if active:
            return active.span_id
        ctx = self._trace_context_var.get()
        return ctx.span_id if ctx else None
root_tracer: TracerManager = TracerManager()
def trace(operation_name: Optional[str] = None, **trace_kwargs):
    """Decorator that runs the wrapped sync or async callable inside a span.

    When *operation_name* is not given, the span name is derived from the
    callable (``Class.method`` or plain function name) at call time.
    """

    def decorator(func):
        def _span_name(args):
            # Explicit name wins; otherwise derive one from the callable.
            return operation_name or _parse_operation_name(func, *args)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with root_tracer.start_span(_span_name(args), **trace_kwargs):
                return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with root_tracer.start_span(_span_name(args), **trace_kwargs):
                return await func(*args, **kwargs)

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator
def _parse_operation_name(func, *args):
self_name = None
if inspect.signature(func).parameters.get("self"):
self_name = args[0].__class__.__name__
func_name = func.__name__
if self_name:
return f"{self_name}.{func_name}"
return func_name
def initialize_tracer(
    system_app: SystemApp,
    tracer_filename: str,
    root_operation_name: str = "DB-GPT-Web-Entry",
    tracer_storage_cls: str = None,
):
    """Wire tracing into *system_app*.

    Registers a ``SpanStorageContainer`` (writing to *tracer_filename*, plus
    an optional extra storage class given as a dotted-path string) and a
    ``DefaultTracer``, initializes the global ``root_tracer``, and — when the
    system app wraps a web app — installs the trace-id middleware.
    No-op when *system_app* is falsy.
    """
    if not system_app:
        return
    from dbgpt.util.tracer.span_storage import FileSpanStorage, SpanStorageContainer

    trace_context_var = ContextVar(
        "trace_context",
        default=TracerContext(),
    )
    tracer = DefaultTracer(system_app)
    storage_container = SpanStorageContainer(system_app)
    storage_container.append_storage(FileSpanStorage(tracer_filename))
    if tracer_storage_cls:
        logger.info(f"Begin parse storage class {tracer_storage_cls}")
        # import_from_checked_string validates the class against SpanStorage
        # before it is instantiated and appended.
        storage = import_from_checked_string(tracer_storage_cls, SpanStorage)
        storage_container.append_storage(storage())
    # Registration order: storage first, then the tracer that uses it.
    system_app.register_instance(storage_container)
    system_app.register_instance(tracer)
    root_tracer.initialize(system_app, trace_context_var)
    if system_app.app:
        from dbgpt.util.tracer.tracer_middleware import TraceIDMiddleware

        system_app.app.add_middleware(
            TraceIDMiddleware,
            trace_context_var=trace_context_var,
            tracer=tracer,
            root_operation_name=root_operation_name,
        )

View File

@@ -0,0 +1,45 @@
import uuid
from contextvars import ContextVar
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp
from dbgpt.util.tracer import TracerContext, Tracer
_DEFAULT_EXCLUDE_PATHS = ["/api/controller/heartbeat"]
class TraceIDMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that wraps matching requests in a root trace span.

    Requests whose path starts with *include_prefix* (and is not listed in
    *exclude_paths*) run inside a span whose parent id is taken from the
    ``DBGPT_TRACER_SPAN_ID`` request header, allowing traces to continue
    across service boundaries.
    """

    def __init__(
        self,
        app: ASGIApp,
        trace_context_var: ContextVar[TracerContext],
        tracer: Tracer,
        root_operation_name: str = "DB-GPT-Web-Entry",
        include_prefix: str = "/api",
        exclude_paths=None,
    ):
        super().__init__(app)
        self.trace_context_var = trace_context_var
        self.tracer = tracer
        self.root_operation_name = root_operation_name
        self.include_prefix = include_prefix
        # Avoid a mutable default argument; fall back to the module default.
        self.exclude_paths = (
            exclude_paths if exclude_paths is not None else _DEFAULT_EXCLUDE_PATHS
        )

    async def dispatch(self, request: Request, call_next):
        """Run the request inside a tracing span when its path is traced."""
        if request.url.path in self.exclude_paths or not request.url.path.startswith(
            self.include_prefix
        ):
            return await call_next(request)

        # When the header is absent, span_id is None and the tracer starts a
        # brand-new trace for this request.
        span_id = request.headers.get("DBGPT_TRACER_SPAN_ID")
        with self.tracer.start_span(
            self.root_operation_name, span_id, metadata={"path": request.url.path}
        ):
            response = await call_next(request)
        return response