refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

8
dbgpt/util/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
from .utils import (
get_gpu_memory,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,
server_error_msg,
get_or_create_event_loop,
)

67
dbgpt/util/annotations.py Normal file
View File

@@ -0,0 +1,67 @@
def PublicAPI(*args, **kwargs):
"""Decorator to mark a function or class as a public API.
Args:
stability: The stability of the API. Can be "alpha", "beta" or "stable".
If "alpha", the API is in alpha may come breaking changes before becoming beta.
If "beta", the API is in beta and may change before becoming stable.
If "stable", the API will remain backwards compatible with the current major version.
Defaults to "stable".
Examples:
>>> from dbgpt.util.annotations import PublicAPI
>>> @PublicAPI
... def foo():
... pass
>>> @PublicAPI(stability="beta")
... def bar():
... pass
"""
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return PublicAPI(stability="stable")(args[0])
stability = None
if "stability" in kwargs:
stability = kwargs["stability"]
if not stability:
stability = "stable"
assert stability in ["alpha", "beta", "stable"]
def decorator(obj):
if stability in ["alpha", "beta"]:
_modify_docstring(
obj,
f"**PublicAPI ({stability}):** This API is in {stability} and may change before becoming stable.",
)
_modify_annotation(obj, stability)
return obj
return decorator
def _modify_docstring(obj, message: str = None):
if not message:
return
if not obj.__doc__:
obj.__doc__ = ""
original_doc = obj.__doc__
lines = original_doc.splitlines()
min_indent = float("inf")
for line in lines[1:]:
stripped = line.lstrip()
if stripped:
min_indent = min(min_indent, len(line) - len(stripped))
if min_indent == float("inf"):
min_indent = 0
indented_message = message.rstrip() + "\n" + (" " * min_indent)
obj.__doc__ = indented_message + original_doc
def _modify_annotation(obj, stability) -> None:
if stability:
obj._public_stability = stability
if hasattr(obj, "__name__"):
obj._annotated = obj.__name__
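A minimal usage sketch of the decorator above (the assertions reflect the implementation shown; nothing here is part of the commit):

    from dbgpt.util.annotations import PublicAPI

    @PublicAPI(stability="beta")
    def bar():
        """Do something."""

    # The decorator records the stability on the object and prepends a note to the docstring.
    assert bar._public_stability == "beta"
    assert bar._annotated == "bar"
    assert bar.__doc__.startswith("**PublicAPI (beta):**")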

127
dbgpt/util/api_utils.py Normal file
View File

@@ -0,0 +1,127 @@
from inspect import signature
import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict
T = TypeVar("T")
logger = logging.getLogger(__name__)
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
import typing_inspect
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
return typing_inspect.get_args(type_hint)[0]
return None
def _build_request(self, func, path, method, *args, **kwargs):
return_type = get_type_hints(func).get("return")
if return_type is None:
raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type)
logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
formatted_url = base_url + path.format(**bound.arguments)
# Extract args names from signature, except "self"
arg_names = list(sig.parameters.keys())[1:]
# Combine args and kwargs into a single dictionary
combined_args = dict(zip(arg_names, args))
combined_args.update(kwargs)
request_data = {}
for key, value in combined_args.items():
if is_dataclass(value):
# Here, instead of adding it as a nested dictionary,
# we set request_data directly to its dictionary representation.
request_data = asdict(value)
else:
request_data[key] = value
request_params = {"method": method, "url": formatted_url}
if method in ["POST", "PUT", "PATCH"]:
request_params["json"] = request_data
else: # For GET, DELETE, etc.
request_params["params"] = request_data
logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
return return_type, actual_dataclass, request_params
def _api_remote(path, method="GET"):
def decorator(func):
async def wrapper(self, *args, **kwargs):
import httpx
return_type, actual_dataclass, request_params = _build_request(
self, func, path, method, *args, **kwargs
)
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
if response.status_code == 200:
return _parse_response(
response.json(), return_type, actual_dataclass
)
else:
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
raise Exception(error_msg)
return wrapper
return decorator
def _sync_api_remote(path, method="GET"):
def decorator(func):
def wrapper(self, *args, **kwargs):
import requests
return_type, actual_dataclass, request_params = _build_request(
self, func, path, method, *args, **kwargs
)
response = requests.request(**request_params)
if response.status_code == 200:
return _parse_response(response.json(), return_type, actual_dataclass)
else:
error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
raise Exception(error_msg)
return wrapper
return decorator
def _parse_response(json_response, return_type, actual_dataclass):
# print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
if is_dataclass(actual_dataclass):
if return_type.__origin__ is list: # for List[dataclass]
if isinstance(json_response, list):
return [actual_dataclass(**item) for item in json_response]
else:
raise TypeError(
f"Expected list in response but got {type(json_response)}"
)
else:
if isinstance(json_response, dict):
return actual_dataclass(**json_response)
else:
raise TypeError(
f"Expected dictionary in response but got {type(json_response)}"
)
else:
return json_response
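A hedged sketch of how the decorators above are meant to be used; the client class, endpoint path, and dataclass below are hypothetical and only illustrate that the return-type annotation drives response parsing:

    from dataclasses import dataclass
    from typing import List

    from dbgpt.util.api_utils import _sync_api_remote


    @dataclass
    class WorkerInfo:  # hypothetical response payload
        host: str
        port: int


    class ControllerClient:
        def __init__(self, base_url: str):
            self.base_url = base_url  # _build_request reads base_url from the instance

        @_sync_api_remote(path="/api/workers", method="GET")  # hypothetical endpoint
        def list_workers(self) -> List[WorkerInfo]:
            """The return annotation tells _parse_response to build WorkerInfo objects."""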

View File

View File

View File

@@ -0,0 +1,296 @@
"""
Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py.
For benchmarks.
"""
import gc
from typing import Iterable, Dict
import torch
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
# TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op, so we skip those two cases.
if temperature >= 1e-5 and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if repetition_penalty > 1.0:
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
if 1e-8 <= top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
if top_k > 0:
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
@torch.inference_mode()
def generate_stream(
model,
tokenizer,
params: Dict,
device: str,
context_len: int,
stream_interval: int = 1,
judge_sent_end: bool = False,
):
if hasattr(model, "device"):
device = model.device
# Read parameters
prompt = params["prompt"]
len_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 256))
logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
if tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(tokenizer.eos_token_id)
logits_processor = prepare_logits_processor(
temperature, repetition_penalty, top_p, top_k
)
input_ids = tokenizer(prompt).input_ids
if model.config.is_encoder_decoder:
max_src_len = context_len
else: # truncate
max_src_len = context_len - max_new_tokens - 1
input_ids = input_ids[-max_src_len:]
output_ids = list(input_ids)
input_echo_len = len(input_ids)
# Don't stop generating until max_new_tokens is reached.
stop_token_ids = []
stop_str = None
if model.config.is_encoder_decoder:
if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
raise NotImplementedError
encoder_output = model.encoder(
input_ids=torch.as_tensor([input_ids], device=device)
)[0]
start_ids = torch.as_tensor(
[[model.generation_config.decoder_start_token_id]],
dtype=torch.int64,
device=device,
)
else:
start_ids = torch.as_tensor([input_ids], device=device)
past_key_values = out = None
token_logprobs = [None] # The first token has no logprobs.
sent_interrupt = False
finish_reason = None
for i in range(max_new_tokens):
if i == 0: # prefill
if model.config.is_encoder_decoder:
out = model.decoder(
input_ids=start_ids,
encoder_hidden_states=encoder_output,
use_cache=True,
)
logits = model.lm_head(out[0])
else:
out = model(input_ids=start_ids, use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
if logprobs is not None:
# Prefill logprobs for the prompt.
shift_input_ids = start_ids[..., 1:].contiguous()
shift_logits = logits[..., :-1, :].contiguous()
shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
for label_id, logit in zip(
shift_input_ids[0].tolist(), shift_logits[0]
):
token_logprobs.append(logit[label_id])
else: # decoding
if model.config.is_encoder_decoder:
out = model.decoder(
input_ids=torch.as_tensor(
[[token] if not sent_interrupt else output_ids],
device=device,
),
encoder_hidden_states=encoder_output,
use_cache=True,
past_key_values=past_key_values if not sent_interrupt else None,
)
sent_interrupt = False
logits = model.lm_head(out[0])
else:
out = model(
input_ids=torch.as_tensor(
[[token] if not sent_interrupt else output_ids],
device=device,
),
use_cache=True,
past_key_values=past_key_values if not sent_interrupt else None,
)
sent_interrupt = False
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
if repetition_penalty > 1.0:
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
else:
last_token_logits = logits[0, -1, :]
if device == "mps":
# Switch to CPU to avoid some bugs in the mps backend.
last_token_logits = last_token_logits.float().to("cpu")
if temperature < 1e-5 or top_p < 1e-8: # greedy
_, indices = torch.topk(last_token_logits, 2)
tokens = [int(index) for index in indices.tolist()]
else:
probs = torch.softmax(last_token_logits, dim=-1)
indices = torch.multinomial(probs, num_samples=2)
tokens = [int(token) for token in indices.tolist()]
token = tokens[0]
output_ids.append(token)
if logprobs is not None:
# Cannot use last_token_logits because logprobs is based on raw logits.
token_logprobs.append(
torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
)
if token in stop_token_ids:
stopped = True
else:
stopped = False
# Yield the output tokens
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
if echo:
tmp_output_ids = output_ids
rfind_start = len_prompt
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = tokenizer.decode(
tmp_output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
ret_logprobs = None
if logprobs is not None:
ret_logprobs = {
"text_offset": [],
"tokens": [
tokenizer.decode(token)
for token in (
output_ids if echo else output_ids[input_echo_len:]
)
],
"token_logprobs": token_logprobs
if echo
else token_logprobs[input_echo_len:],
"top_logprobs": [{}]
* len(token_logprobs if echo else token_logprobs[input_echo_len:]),
}
# Compute text_offset
curr_pos = 0
for text in ret_logprobs["tokens"]:
ret_logprobs["text_offset"].append(curr_pos)
curr_pos += len(text)
# TODO: Patch for incomplete sentences interrupting output; this can be reworked into a more elegant solution.
if judge_sent_end and stopped and not is_sentence_complete(output):
if len(tokens) > 1:
token = tokens[1]
output_ids[-1] = token
else:
output_ids.pop()
stopped = False
sent_interrupt = True
partially_stopped = False
if stop_str:
if isinstance(stop_str, str):
pos = output.rfind(stop_str, rfind_start)
if pos != -1:
output = output[:pos]
stopped = True
else:
partially_stopped = is_partial_stop(output, stop_str)
elif isinstance(stop_str, Iterable):
for each_stop in stop_str:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output = output[:pos]
stopped = True
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped:
break
else:
raise ValueError("Invalid stop field type.")
# Prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"logprobs": ret_logprobs,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
if stopped:
break
# Finish stream event, which contains finish reason
else:
finish_reason = "length"
if stopped:
finish_reason = "stop"
yield {
"text": output,
"logprobs": ret_logprobs,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}
# Clean
del past_key_values, out
gc.collect()
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
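A small, self-contained sketch of prepare_logits_processor from the hunk above, applied to random logits (the values are chosen only for illustration):

    import torch

    processors = prepare_logits_processor(
        temperature=0.7, repetition_penalty=1.1, top_p=0.9, top_k=40
    )
    input_ids = torch.tensor([[1, 2, 3]])   # fake prompt token ids
    logits = torch.randn(1, 10)             # fake logits over a 10-token vocabulary
    warped = processors(input_ids, logits)  # same shape, adjusted scores
    print(warped.shape)                     # torch.Size([1, 10])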

View File

@@ -0,0 +1,282 @@
from typing import Dict, List
import asyncio
import os
import sys
import time
import csv
import argparse
import logging
import traceback
from dbgpt.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
from datetime import datetime
from dbgpt.model.cluster.worker.manager import (
run_worker_manager,
initialize_worker_manager_in_client,
WorkerManager,
)
from dbgpt.core import ModelOutput, ModelInferenceMetrics
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
model_name = "vicuna-7b-v1.5"
model_path = LLM_MODEL_CONFIG[model_name]
# or vllm
model_type = "huggingface"
controller_addr = "http://127.0.0.1:5000"
result_csv_file = None
parallel_nums = [1, 2, 4, 16, 32]
# parallel_nums = [1, 2, 4]
def get_result_csv_file() -> str:
return os.path.join(
ROOT_PATH, f"pilot/data/{model_name}_{model_type}_benchmarks_llm.csv"
)
input_lens = [64, 64]
output_lens = [256, 512]
prompt_file_map = {
"11k": os.path.join(
ROOT_PATH, "docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt"
)
}
METRICS_HEADERS = [
# Params
"model_name",
"gpu_nums",
"parallel_nums",
"input_length",
"output_length",
# Merge parallel result
"test_time_cost_ms",
"test_total_tokens",
# avg_test_speed_per_second: (tokens / s), test_total_tokens / (test_time_cost_ms / 1000.0)
"avg_test_speed_per_second(tokens/s)",
# avg_first_token_latency_ms: sum(first_token_time_ms) / parallel_nums
"avg_first_token_latency_ms",
# avg_latency_ms: sum(end_time_ms - start_time_ms) / parallel_nums
"avg_latency_ms",
"gpu_mem(GiB)",
# Detail for each task
"start_time_ms",
"end_time_ms",
"current_time_ms",
"first_token_time_ms",
"first_completion_time_ms",
"first_completion_tokens",
"prompt_tokens",
"completion_tokens",
"total_tokens",
"speed_per_second",
]
def read_prompt_from_file(file_key: str) -> str:
full_path = prompt_file_map[file_key]
with open(full_path, "r+", encoding="utf-8") as f:
return f.read()
def build_param(
input_len: int,
output_len: int,
user_input: str,
system_prompt: str = None,
) -> Dict:
hist = []
if system_prompt is not None:
hist.append(
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
)
hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
hist = list(h.dict() for h in hist)
context_len = input_len + output_len + 2
params = {
"prompt": user_input,
"messages": hist,
"model": model_name,
"echo": False,
"max_new_tokens": output_len,
"context_len": context_len,
}
return params
async def run_batch(
wh: WorkerManager,
input_len: int,
output_len: int,
parallel_num: int,
output_file: str,
):
tasks = []
prompt = read_prompt_from_file("11k")
if model_type == "vllm":
max_input_str_len = input_len
if "baichuan" in model_name:
# TODO handle the prompt first
max_input_str_len *= 2
prompt = prompt[-max_input_str_len:]
# Warmup first
params = build_param(input_len, output_len, prompt, system_prompt="")
await wh.generate(params)
for _ in range(parallel_num):
params = build_param(input_len, output_len, prompt, system_prompt="")
tasks.append(wh.generate(params))
print(
f"Begin run benchmarks, model name: {model_name}, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
)
start_time_ms = time.time_ns() // 1_000_000
results: List[ModelOutput] = await asyncio.gather(*tasks)
end_time_ms = time.time_ns() // 1_000_000
test_time_cost_ms = end_time_ms - start_time_ms
test_total_tokens = 0
first_token_latency_ms = 0
latency_ms = 0
gpu_nums = 0
avg_gpu_mem = 0
rows = []
for r in results:
metrics = r.metrics
if isinstance(metrics, dict):
metrics = ModelInferenceMetrics(**metrics)
print(r)
test_total_tokens += metrics.total_tokens
first_token_latency_ms += metrics.first_token_time_ms - metrics.start_time_ms
latency_ms += metrics.end_time_ms - metrics.start_time_ms
row_data = metrics.to_dict()
del row_data["collect_index"]
if "avg_gpu_infos" in row_data:
avg_gpu_infos = row_data["avg_gpu_infos"]
gpu_nums = len(avg_gpu_infos)
avg_gpu_mem = (
sum(i["allocated_memory_gb"] for i in avg_gpu_infos) / gpu_nums
)
del row_data["avg_gpu_infos"]
del row_data["current_gpu_infos"]
rows.append(row_data)
avg_test_speed_per_second = test_total_tokens / (test_time_cost_ms / 1000.0)
avg_first_token_latency_ms = first_token_latency_ms / len(results)
avg_latency_ms = latency_ms / len(results)
with open(output_file, "a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=METRICS_HEADERS)
if f.tell() == 0:
# First time: write the header
writer.writeheader()
for row in rows:
row["model_name"] = model_name
row["parallel_nums"] = parallel_num
row["input_length"] = input_len
row["output_length"] = output_len
row["test_time_cost_ms"] = test_time_cost_ms
row["test_total_tokens"] = test_total_tokens
row["avg_test_speed_per_second(tokens/s)"] = avg_test_speed_per_second
row["avg_first_token_latency_ms"] = avg_first_token_latency_ms
row["avg_latency_ms"] = avg_latency_ms
row["gpu_nums"] = gpu_nums
row["gpu_mem(GiB)"] = avg_gpu_mem
writer.writerow(row)
print(
f"input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
)
async def run_model(wh: WorkerManager) -> None:
global result_csv_file
if not result_csv_file:
result_csv_file = get_result_csv_file()
if os.path.exists(result_csv_file):
now = datetime.now()
now_str = now.strftime("%Y-%m-%d")
os.rename(result_csv_file, f"{result_csv_file}.bak_{now_str}.csv")
for parallel_num in parallel_nums:
for input_len, output_len in zip(input_lens, output_lens):
try:
await run_batch(
wh, input_len, output_len, parallel_num, result_csv_file
)
except Exception:
msg = traceback.format_exc()
logging.error(
f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
)
if "torch.cuda.OutOfMemoryError" in msg:
return
sys.exit(0)
def startup_llm_env():
from fastapi import FastAPI
app = FastAPI()
initialize_worker_manager_in_client(
app=app,
model_name=model_name,
model_path=model_path,
run_locally=False,
controller_addr=controller_addr,
local_port=6000,
start_listener=run_model,
)
def connect_to_remote_model():
startup_llm_env()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default=model_name)
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--model_type", type=str, default="huggingface")
parser.add_argument("--result_csv_file", type=str, default=None)
parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
parser.add_argument("--output_lens", type=str, default="256,512,1024,1024")
parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
parser.add_argument(
"--remote_model", type=bool, default=False, help="Connect to remote model"
)
parser.add_argument("--controller_addr", type=str, default="http://127.0.0.1:8000")
parser.add_argument("--limit_model_concurrency", type=int, default=200)
args = parser.parse_args()
print(f"args: {args}")
model_name = args.model_name
model_path = args.model_path or LLM_MODEL_CONFIG[model_name]
result_csv_file = args.result_csv_file
input_lens = [int(i) for i in args.input_lens.strip().split(",")]
output_lens = [int(i) for i in args.output_lens.strip().split(",")]
parallel_nums = [int(i) for i in args.parallel_nums.strip().split(",")]
remote_model = args.remote_model
controller_addr = args.controller_addr
limit_model_concurrency = args.limit_model_concurrency
model_type = args.model_type
if len(input_lens) != len(output_lens):
raise ValueError("input_lens size must equal output_lens size")
if remote_model:
# Connect to remote model and run benchmarks
connect_to_remote_model()
else:
# Start worker manager and run benchmarks
run_worker_manager(
model_name=model_name,
model_path=model_path,
start_listener=run_model,
limit_model_concurrency=limit_model_concurrency,
model_type=model_type,
)
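A quick sketch of the request parameters build_param produces (an in-module call; the prompt text is illustrative):

    params = build_param(
        input_len=64,
        output_len=256,
        user_input="Explain what DB-GPT is in one sentence.",
        system_prompt="You are a helpful assistant.",
    )
    print(params["max_new_tokens"])  # 256
    print(params["context_len"])     # 64 + 256 + 2 = 322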

162
dbgpt/util/command_utils.py Normal file
View File

@@ -0,0 +1,162 @@
import sys
import os
import subprocess
from typing import List, Dict
import psutil
import platform
from functools import lru_cache
def _get_abspath_of_current_command(command_path: str):
if not command_path.endswith(".py"):
return command_path
# This implementation is very ugly
command_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"scripts",
"cli_scripts.py",
)
return command_path
def _run_current_with_daemon(name: str, log_file: str):
# Get all arguments except for --daemon
args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
args[0] = _get_abspath_of_current_command(args[0])
daemon_cmd = [sys.executable] + args
daemon_cmd = " ".join(daemon_cmd)
daemon_cmd += f" > {log_file} 2>&1"
print(f"daemon cmd: {daemon_cmd}")
# Check the platform and set the appropriate flags or functions
if "windows" in platform.system().lower():
process = subprocess.Popen(
daemon_cmd,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
shell=True,
)
else: # macOS, Linux, and other Unix-like systems
process = subprocess.Popen(
daemon_cmd,
preexec_fn=os.setsid,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
shell=True,
)
print(f"Started {name} in background with pid: {process.pid}")
def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
try:
import gunicorn
except ImportError as e:
raise ValueError(
"Could not import python package: gunicorn"
"Daemon mode need install gunicorn, please install `pip install gunicorn`"
) from e
from dbgpt.util.parameter_utils import EnvArgumentParser
env_to_app = {}
env_to_app.update(os.environ)
app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
env_to_app.update(app_env)
cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
if "windows" in platform.system().lower():
raise Exception("Not support on windows")
else: # macOS, Linux, and other Unix-like systems
process = subprocess.Popen(cmd, shell=True, env=env_to_app)
print(f"Started {app} with gunicorn in background with pid: {process.pid}")
def _stop_service(
key: str, fullname: str, service_keys: List[str] = None, port: int = None
):
if not service_keys:
service_keys = [sys.argv[0], "start", key]
not_found = True
for process in psutil.process_iter(attrs=["pid", "datasource", "cmdline"]):
try:
cmdline = " ".join(process.info["cmdline"])
# Check if all key fragments are in the cmdline
if all(fragment in cmdline for fragment in service_keys):
if port:
for conn in process.info["datasource"]:
if (
conn.status == psutil.CONN_LISTEN
and conn.laddr.port == port
):
psutil.Process(process.info["pid"]).terminate()
print(
f"Terminated the {fullname} with PID: {process.info['pid']} listening on port: {port}"
)
not_found = False
else:
psutil.Process(process.info["pid"]).terminate()
print(f"Terminated the {fullname} with PID: {process.info['pid']}")
not_found = False
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
if not_found:
print(f"{fullname} process not found.")
def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
"""
Return a list of ports that are associated with processes that have all the service_keys in their cmdline.
Args:
service_keys (List[str]): List of strings that should all be present in the process's cmdline.
Returns:
List[int]: List of ports sorted with preference for 8000 and 5000, and then in ascending order.
"""
ports = []
for process in psutil.process_iter(attrs=["pid", "name", "cmdline", "connections"]):
try:
# Convert the cmdline list to a single string for easier checking
cmdline = ""
if process.info.get("cmdline"):
cmdline = " ".join(process.info["cmdline"])
# Check if all the service keys are present in the cmdline
if cmdline and all(fragment in cmdline for fragment in service_keys):
connections = process.info.get("connections")
if connections is not None and len(ports) == 0:
for connection in connections:
if connection.status == psutil.CONN_LISTEN:
ports.append(connection.laddr.port)
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
# Sort ports with preference for 8000 and 5000
ports.sort(key=lambda x: (x != 8000, x != 5000, x))
return ports
@lru_cache()
def _detect_controller_address() -> str:
controller_addr = os.getenv("CONTROLLER_ADDRESS")
if controller_addr:
return controller_addr
cmdline_fragments = [
["python", "start", "controller"],
["python", "controller"],
["python", "start", "webserver"],
["python", "dbgpt_server"],
]
for fragments in cmdline_fragments:
ports = _get_ports_by_cmdline_part(fragments)
if ports:
return f"http://127.0.0.1:{ports[0]}"
return f"http://127.0.0.1:8000"

View File

@@ -0,0 +1,34 @@
from collections import OrderedDict
from collections import deque
class FixedSizeDict(OrderedDict):
def __init__(self, max_size):
super().__init__()
self.max_size = max_size
def __setitem__(self, key, value):
if len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
class FixedSizeList:
def __init__(self, max_size):
self.max_size = max_size
self.list = deque(maxlen=max_size)
def append(self, value):
self.list.append(value)
def __getitem__(self, index):
return self.list[index]
def __setitem__(self, index, value):
self.list[index] = value
def __len__(self):
return len(self.list)
def __str__(self):
return str(list(self.list))
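A minimal sketch of the eviction behavior of the two containers above:

    d = FixedSizeDict(max_size=2)
    d["a"] = 1
    d["b"] = 2
    d["c"] = 3            # the oldest key ("a") is evicted first
    print(list(d))        # ['b', 'c']

    lst = FixedSizeList(max_size=2)
    for i in range(3):
        lst.append(i)     # deque(maxlen=2) silently drops the oldest element
    print(lst)            # [1, 2]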

View File

@@ -0,0 +1,87 @@
from typing import Callable, Awaitable, Any
import asyncio
import contextvars
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from dbgpt.component import BaseComponent, ComponentType, SystemApp
class ExecutorFactory(BaseComponent, ABC):
name = ComponentType.EXECUTOR_DEFAULT.value
@abstractmethod
def create(self) -> "Executor":
"""Create executor"""
class DefaultExecutorFactory(ExecutorFactory):
def __init__(self, system_app: SystemApp | None = None, max_workers=None):
super().__init__(system_app)
self._executor = ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix=self.name
)
def init_app(self, system_app: SystemApp):
pass
def create(self) -> Executor:
return self._executor
BlockingFunction = Callable[..., Any]
async def blocking_func_to_async(
executor: Executor, func: BlockingFunction, *args, **kwargs
):
"""Run a potentially blocking function within an executor.
Args:
executor (Executor): The concurrent.futures.Executor to run the function within.
func (BlockingFunction): The callable function, which should be a synchronous function.
It should accept any number and type of arguments and return its result synchronously.
*args (Any): Any additional arguments to pass to the function.
**kwargs (Any): Other arguments to pass to the function
Returns:
Any: The result of the function's execution.
Raises:
ValueError: If the provided function 'func' is an asynchronous coroutine function.
This function allows you to execute a potentially blocking function within an executor.
It expects 'func' to be a synchronous function and will raise an error if 'func' is an asynchronous coroutine.
"""
if asyncio.iscoroutinefunction(func):
raise ValueError(f"The function {func} is not blocking function")
# This function will be called within the new thread, capturing the current context
ctx = contextvars.copy_context()
def run_with_context():
return ctx.run(partial(func, *args, **kwargs))
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, run_with_context)
class AsyncToSyncIterator:
def __init__(self, async_iterable, loop: asyncio.BaseEventLoop):
self.async_iterable = async_iterable
self.async_iterator = None
self._loop = loop
def __iter__(self):
self.async_iterator = self.async_iterable.__aiter__()
return self
def __next__(self):
if self.async_iterator is None:
raise StopIteration
try:
return self._loop.run_until_complete(self.async_iterator.__anext__())
except StopAsyncIteration:
raise StopIteration
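A minimal sketch of blocking_func_to_async: it offloads a synchronous call to the executor so the event loop is not blocked (the sleeping function below stands in for blocking work):

    import asyncio
    import time
    from concurrent.futures import ThreadPoolExecutor


    def slow_add(a: int, b: int) -> int:
        time.sleep(0.1)  # placeholder for blocking I/O or CPU-bound work
        return a + b


    async def main():
        executor = ThreadPoolExecutor(max_workers=2)
        result = await blocking_func_to_async(executor, slow_add, 1, 2)
        print(result)  # 3


    asyncio.run(main())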

61
dbgpt/util/formatting.py Normal file
View File

@@ -0,0 +1,61 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
class NoStrictFormatter(StrictFormatter):
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Not check unused args"""
pass
formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()
class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
elif hasattr(obj, "__dict__"):
return obj.__dict__
else:
return json.JSONEncoder.default(self, obj)
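A short sketch contrasting the strict and non-strict formatters above:

    print(formatter.format("Hello {name}", name="DB-GPT"))  # Hello DB-GPT

    # The non-strict variant silently ignores extra keyword arguments.
    print(no_strict_formatter.format("Hello {name}", name="DB-GPT", extra=1))

    # The strict variant raises on unused keys.
    try:
        formatter.format("Hello {name}", name="DB-GPT", extra=1)
    except KeyError as e:
        print(e)  # {'extra'}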

448
dbgpt/util/global_helper.py Normal file
View File

@@ -0,0 +1,448 @@
"""General util functions."""
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Type,
Union,
cast,
)
class GlobalsHelper:
"""Helper to retrieve globals.
Helpful for global caching of certain variables that can be expensive to load.
(e.g. tokenization)
"""
_tokenizer: Optional[Callable[[str], List]] = None
_stopwords: Optional[List[str]] = None
@property
def tokenizer(self) -> Callable[[str], List]:
"""Get tokenizer."""
if self._tokenizer is None:
tiktoken_import_err = (
"`tiktoken` package not found, please run `pip install tiktoken`"
)
try:
import tiktoken
except ImportError:
raise ImportError(tiktoken_import_err)
enc = tiktoken.get_encoding("gpt2")
self._tokenizer = cast(Callable[[str], List], enc.encode)
self._tokenizer = partial(self._tokenizer, allowed_special="all")
return self._tokenizer # type: ignore
@property
def stopwords(self) -> List[str]:
"""Get stopwords."""
if self._stopwords is None:
try:
import nltk
from nltk.corpus import stopwords
except ImportError:
raise ImportError(
"`nltk` package not found, please run `pip install nltk`"
)
from llama_index.utils import get_cache_dir
cache_dir = get_cache_dir()
nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
# update nltk path for nltk so that it finds the data
if nltk_data_dir not in nltk.data.path:
nltk.data.path.append(nltk_data_dir)
try:
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download("stopwords", download_dir=nltk_data_dir)
self._stopwords = stopwords.words("english")
return self._stopwords
globals_helper = GlobalsHelper()
def get_new_id(d: Set) -> str:
"""Get a new ID."""
while True:
new_id = str(uuid.uuid4())
if new_id not in d:
break
return new_id
def get_new_int_id(d: Set) -> int:
"""Get a new integer ID."""
while True:
new_id = random.randint(0, sys.maxsize)
if new_id not in d:
break
return new_id
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
"""Temporary setter.
Utility class for setting a temporary value for an attribute on a class.
Taken from: https://tinyurl.com/2p89xymh
"""
prev_values = {k: getattr(obj, k) for k in kwargs}
for k, v in kwargs.items():
setattr(obj, k, v)
try:
yield
finally:
for k, v in prev_values.items():
setattr(obj, k, v)
@dataclass
class ErrorToRetry:
"""Exception types that should be retried.
Args:
exception_cls (Type[Exception]): Class of exception.
check_fn (Optional[Callable[[Any], bool]]):
A function that takes an exception instance as input and returns
whether to retry.
"""
exception_cls: Type[Exception]
check_fn: Optional[Callable[[Any], bool]] = None
def retry_on_exceptions_with_backoff(
lambda_fn: Callable,
errors_to_retry: List[ErrorToRetry],
max_tries: int = 10,
min_backoff_secs: float = 0.5,
max_backoff_secs: float = 60.0,
) -> Any:
"""Execute lambda function with retries and exponential backoff.
Args:
lambda_fn (Callable): Function to be called and output we want.
errors_to_retry (List[ErrorToRetry]): List of errors to retry.
At least one needs to be provided.
max_tries (int): Maximum number of tries, including the first. Defaults to 10.
min_backoff_secs (float): Minimum amount of backoff time between attempts.
Defaults to 0.5.
max_backoff_secs (float): Maximum amount of backoff time between attempts.
Defaults to 60.
"""
if not errors_to_retry:
raise ValueError("At least one error to retry needs to be provided")
error_checks = {
error_to_retry.exception_cls: error_to_retry.check_fn
for error_to_retry in errors_to_retry
}
exception_class_tuples = tuple(error_checks.keys())
backoff_secs = min_backoff_secs
tries = 0
while True:
try:
return lambda_fn()
except exception_class_tuples as e:
traceback.print_exc()
tries += 1
if tries >= max_tries:
raise
check_fn = error_checks.get(e.__class__)
if check_fn and not check_fn(e):
raise
time.sleep(backoff_secs)
backoff_secs = min(backoff_secs * 2, max_backoff_secs)
def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a maximum length."""
if len(text) <= max_length:
return text
return text[: max_length - 3] + "..."
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""Iterate over an iterable in batches.
>>> list(iter_batch([1,2,3,4,5], 3))
[[1, 2, 3], [4, 5]]
"""
source_iter = iter(iterable)
while source_iter:
b = list(islice(source_iter, size))
if len(b) == 0:
break
yield b
def concat_dirs(dirname: str, basename: str) -> str:
"""
Append basename to dirname, avoiding backslashes when running on windows.
os.path.join(dirname, basename) will add a backslash before dirname if
basename does not end with a slash, so we make sure it does.
"""
dirname += "/" if dirname[-1] != "/" else ""
return os.path.join(dirname, basename)
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
"""
Optionally get a tqdm iterable. Ensures tqdm.auto is used.
"""
_iterator = items
if show_progress:
try:
from tqdm.auto import tqdm
return tqdm(items, desc=desc)
except ImportError:
pass
return _iterator
def count_tokens(text: str) -> int:
tokens = globals_helper.tokenizer(text)
return len(tokens)
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
"""
Args:
model_name(str): the model name of the tokenizer.
For instance, fxmarty/tiny-llama-fast-tokenizer.
"""
try:
from transformers import AutoTokenizer
except ImportError:
raise ValueError(
"`transformers` package not found, please run `pip install transformers`"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer.tokenize
def get_cache_dir() -> str:
"""Locate a platform-appropriate cache directory for llama_index,
and create it if it doesn't yet exist.
"""
# User override
if "LLAMA_INDEX_CACHE_DIR" in os.environ:
path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])
# Linux, Unix, AIX, etc.
elif os.name == "posix" and sys.platform != "darwin":
path = Path("/tmp/llama_index")
# Mac OS
elif sys.platform == "darwin":
path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")
# Windows (hopefully)
else:
local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
"~\\AppData\\Local"
)
path = Path(local, "llama_index")
if not os.path.exists(path):
os.makedirs(
path, exist_ok=True
) # prevents https://github.com/jerryjliu/llama_index/issues/7362
return str(path)
def add_sync_version(func: Any) -> Any:
"""Decorator for adding sync version of an async function. The sync version
is added as a function attribute to the original function, func.
Args:
func(Any): the async function for which a sync variant will be built.
"""
assert asyncio.iscoroutinefunction(func)
@wraps(func)
def _wrapper(*args: Any, **kwds: Any) -> Any:
return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))
func.sync = _wrapper
return func
# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.
Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:
Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""
_LLAMA_INDEX_COLORS = {
"llama_pink": "38;2;237;90;200",
"llama_blue": "38;2;90;149;237",
"llama_turquoise": "38;2;11;159;203",
"llama_lavender": "38;2;155;135;227",
}
_ANSI_COLORS = {
"red": "31",
"green": "32",
"yellow": "33",
"blue": "34",
"magenta": "35",
"cyan": "36",
"pink": "38;5;200",
}
def get_color_mapping(
items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
"""
Get a mapping of items to colors.
Args:
items (List[str]): List of items to be mapped to colors.
use_llama_index_colors (bool, optional): Flag to indicate
whether to use LlamaIndex colors or ANSI colors.
Defaults to True.
Returns:
Dict[str, str]: Mapping of items to colors.
"""
if use_llama_index_colors:
color_palette = _LLAMA_INDEX_COLORS
else:
color_palette = _ANSI_COLORS
colors = list(color_palette.keys())
return {item: colors[i % len(colors)] for i, item in enumerate(items)}
def _get_colored_text(text: str, color: str) -> str:
"""
Get the colored version of the input text.
Args:
text (str): Input text.
color (str): Color to be applied to the text.
Returns:
str: Colored version of the input text.
"""
all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}
if color not in all_colors:
return f"\033[1;3m{text}\033[0m" # just bolded and italicized
color = all_colors[color]
return f"\033[1;3;{color}m{text}\033[0m"
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""
Print the text with the specified color.
Args:
text (str): Text to be printed.
color (str, optional): Color to be applied to the text. Supported colors are:
llama_pink, llama_blue, llama_turquoise, llama_lavender,
red, green, yellow, blue, magenta, cyan, pink.
end (str, optional): String appended after the last character of the text.
Returns:
None
"""
text_to_print = _get_colored_text(text, color) if color is not None else text
print(text_to_print, end=end)
def infer_torch_device() -> str:
"""Infer the input to torch.device."""
try:
has_cuda = torch.cuda.is_available()
except NameError:
import torch
has_cuda = torch.cuda.is_available()
if has_cuda:
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def unit_generator(x: Any) -> Generator[Any, None, None]:
"""A function that returns a generator of a single element.
Args:
x (Any): the element to yield
Yields:
Any: the single element
"""
yield x
async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
"""A function that returns a generator of a single element.
Args:
x (Any): the element to yield
Yields:
Any: the single element
"""
yield x
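A small sketch of retry_on_exceptions_with_backoff using ErrorToRetry from the hunk above (note that the implementation prints a traceback on every failed attempt):

    attempts = {"n": 0}


    def flaky():
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise ConnectionError("transient failure")
        return "ok"


    result = retry_on_exceptions_with_backoff(
        flaky,
        [ErrorToRetry(ConnectionError)],
        max_tries=5,
        min_backoff_secs=0.01,
    )
    print(result, attempts["n"])  # ok 3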

14
dbgpt/util/json_utils.py Normal file
View File

@@ -0,0 +1,14 @@
import json
from datetime import date, datetime
def serialize(obj):
if isinstance(obj, date):
return obj.isoformat()
class DateTimeEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
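A sketch of serializing datetimes with the encoder above:

    import json
    from datetime import datetime

    data = {"created_at": datetime(2023, 12, 8, 14, 45, 59)}
    print(json.dumps(data, cls=DateTimeEncoder))
    # {"created_at": "2023-12-08T14:45:59"}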

View File

@@ -0,0 +1,11 @@
from typing import Any
from pympler import asizeof
def _get_object_bytes(obj: Any) -> int:
"""Get the bytes of a object in memory
Args:
obj (Any): The object to return the bytes
"""
return asizeof.asizeof(obj)

84
dbgpt/util/model_utils.py Normal file
View File

@@ -0,0 +1,84 @@
from typing import List, Tuple
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
def _clear_model_cache(device="cuda"):
try:
# clear torch cache
import torch
_clear_torch_cache(device)
except ImportError:
logger.warn("Torch not installed, skip clear torch cache")
# TODO clear other cache
def _clear_torch_cache(device="cuda"):
import torch
import gc
gc.collect()
if device != "cpu":
if torch.has_mps:
try:
from torch.mps import empty_cache
empty_cache()
except Exception as e:
logger.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logger.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logger.info("No cuda or mps, not support clear torch cache yet")
@dataclass
class GPUInfo:
total_memory_gb: float
allocated_memory_gb: float
cached_memory_gb: float
available_memory_gb: float
def _get_current_cuda_memory() -> List[GPUInfo]:
try:
import torch
except ImportError:
logger.warn("Torch not installed")
return []
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
gpu_infos = []
for gpu_id in range(num_gpus):
with torch.cuda.device(gpu_id):
device = torch.cuda.current_device()
gpu_properties = torch.cuda.get_device_properties(device)
total_memory = round(gpu_properties.total_memory / (1.0 * 1024**3), 2)
allocated_memory = round(
torch.cuda.memory_allocated() / (1.0 * 1024**3), 2
)
cached_memory = round(
torch.cuda.memory_reserved() / (1.0 * 1024**3), 2
)
available_memory = total_memory - allocated_memory
gpu_infos.append(
GPUInfo(
total_memory_gb=total_memory,
allocated_memory_gb=allocated_memory,
cached_memory_gb=cached_memory,
available_memory_gb=available_memory,
)
)
return gpu_infos
else:
logger.warn("CUDA is not available.")
return []

View File

@@ -0,0 +1,28 @@
from typing import Type
from importlib import import_module
def import_from_string(module_path: str, ignore_import_error: bool = False):
try:
module_path, class_name = module_path.rsplit(".", 1)
except ValueError:
raise ImportError(f"{module_path} doesn't look like a module path")
module = import_module(module_path)
try:
return getattr(module, class_name)
except AttributeError:
if ignore_import_error:
return None
raise ImportError(
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
)
def import_from_checked_string(module_path: str, supper_cls: Type):
cls = import_from_string(module_path)
if not issubclass(cls, supper_cls):
raise ImportError(
f'Module "{module_path}" does not the subclass of {str(supper_cls)}'
)
return cls
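A sketch of the dynamic-import helpers above against a standard-library path:

    cls = import_from_string("collections.OrderedDict")
    print(cls)  # <class 'collections.OrderedDict'>

    # The checked variant additionally verifies the subclass relationship.
    checked = import_from_checked_string("collections.OrderedDict", dict)
    print(checked is cls)  # True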

24
dbgpt/util/net_utils.py Normal file
View File

@@ -0,0 +1,24 @@
import socket
import errno
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
ip, port = address.split(":")
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
curr_address = "127.0.0.1"
try:
# doesn't even have to be reachable
s.connect((ip, int(port)))
curr_address = s.getsockname()[0]
except OSError as e:
IP = "127.0.0.1"
if e.errno == errno.ENETUNREACH:
try:
hostname = socket.getfqdn(socket.gethostname())
curr_address = socket.gethostbyname(hostname)
except Exception:
pass
finally:
s.close()
return curr_address
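A sketch of _get_ip_address: the probe address is never actually contacted; the UDP socket only asks the OS which local interface would route to it (output depends on the host's network setup):

    print(_get_ip_address())              # e.g. 192.168.1.23 on a typical LAN
    print(_get_ip_address("8.8.8.8:80"))  # same idea with a different probe target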

View File

@@ -0,0 +1,99 @@
from typing import Dict, Any, AsyncIterator, Awaitable, Callable, Optional
import httpx
import asyncio
import logging
import json
logger = logging.getLogger(__name__)
MessageCaller = Callable[[str], Awaitable[None]]
async def _do_chat_completion(
url: str,
chat_data: Dict[str, Any],
client: httpx.AsyncClient,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> AsyncIterator[str]:
async with client.stream(
"POST",
url,
headers=headers,
json=chat_data,
timeout=timeout,
) as res:
if res.status_code != 200:
error_message = await res.aread()
if error_message:
error_message = error_message.decode("utf-8")
logger.error(
f"Request failed with status {res.status_code}. Error: {error_message}"
)
raise httpx.RequestError(
f"Request failed with status {res.status_code}",
request=res.request,
)
async for line in res.aiter_lines():
if line:
if not line.startswith("data: "):
if caller:
await caller(line)
yield line
else:
decoded_line = line.split("data: ", 1)[1]
if decoded_line.lower().strip() != "[DONE]".lower():
obj = json.loads(decoded_line)
if obj["choices"][0]["delta"].get("content") is not None:
text = obj["choices"][0]["delta"].get("content")
if caller:
await caller(text)
yield text
await asyncio.sleep(0.02)
async def chat_completion_stream(
url: str,
chat_data: Dict[str, Any],
client: Optional[httpx.AsyncClient] = None,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> AsyncIterator[str]:
if client:
async for text in _do_chat_completion(
url,
chat_data,
client=client,
headers=headers,
timeout=timeout,
caller=caller,
):
yield text
else:
async with httpx.AsyncClient() as client:
async for text in _do_chat_completion(
url,
chat_data,
client=client,
headers=headers,
timeout=timeout,
caller=caller,
):
yield text
async def chat_completion(
url: str,
chat_data: Dict[str, Any],
client: Optional[httpx.AsyncClient] = None,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> str:
full_text = ""
async for text in chat_completion_stream(
url, chat_data, client, headers=headers, timeout=timeout, caller=caller
):
full_text += text
return full_text
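A hypothetical call against an OpenAI-compatible streaming endpoint; the URL, model name, and the fact that a server is running there are assumptions made only for illustration:

    import asyncio


    async def main():
        chat_data = {
            "model": "vicuna-7b-v1.5",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
        }
        # chat_completion collects the streamed deltas into a single string.
        text = await chat_completion(
            "http://127.0.0.1:8000/api/v1/chat/completions", chat_data  # hypothetical URL
        )
        print(text)


    asyncio.run(main())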

View File

@@ -0,0 +1,740 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
from collections import OrderedDict
if TYPE_CHECKING:
from dbgpt._private.pydantic import BaseModel
MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
@dataclass
class ParameterDescription:
param_class: str
param_name: str
param_type: str
default_value: Optional[Any]
description: str
required: Optional[bool]
valid_values: Optional[List[Any]]
ext_metadata: Dict
@dataclass
class BaseParameters:
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "BaseParameters":
"""Create an instance of the dataclass from a dictionary.
Args:
data: A dictionary containing values for the dataclass fields.
ignore_extra_fields: If True, any extra fields in the data dictionary that are
not part of the dataclass will be ignored.
If False, extra fields will raise an error. Defaults to False.
Returns:
An instance of the dataclass with values populated from the given dictionary.
Raises:
TypeError: If `ignore_extra_fields` is False and there are fields in the
dictionary that aren't present in the dataclass.
"""
all_field_names = {f.name for f in fields(cls)}
if ignore_extra_fields:
data = {key: value for key, value in data.items() if key in all_field_names}
else:
extra_fields = set(data.keys()) - all_field_names
if extra_fields:
raise TypeError(f"Unexpected fields: {', '.join(extra_fields)}")
return cls(**data)
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
"""
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
Args:
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
Returns:
bool: True if at least one field was updated, otherwise False.
"""
updated = False # Flag to indicate whether any field was updated
if isinstance(source, (BaseParameters, dict)):
for field_info in fields(self):
# Check if the field has a "fixed" tag in metadata
tags = field_info.metadata.get("tags")
tags = [] if not tags else tags.split(",")
if tags and "fixed" in tags:
continue # skip this field
# Get the new value from source (either another BaseParameters object or a dict)
new_value = (
getattr(source, field_info.name)
if isinstance(source, BaseParameters)
else source.get(field_info.name, None)
)
# If the new value is not None and different from the current value, update the field and set the flag
if new_value is not None and new_value != getattr(
self, field_info.name
):
setattr(self, field_info.name, new_value)
updated = True
else:
raise ValueError(
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
)
return updated
def __str__(self) -> str:
return _get_dataclass_print_str(self)
def to_command_args(self, args_prefix: str = "--") -> List[str]:
"""Convert the fields of the dataclass to a list of command line arguments.
Args:
args_prefix: args prefix
Returns:
A list of strings where each field is represented by two items:
one for the field name prefixed by args_prefix, and one for its value.
"""
return _dict_to_command_args(asdict(self), args_prefix=args_prefix)
def _get_dataclass_print_str(obj):
class_name = obj.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(obj):
value = _get_simple_privacy_field_value(obj, field_info)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
"""Convert dict to a list of command line arguments
Args:
obj: dict
Returns:
A list of strings where each field is represented by two items:
one for the field name prefixed by args_prefix, and one for its value.
"""
args = []
for key, value in obj.items():
if value is None:
continue
args.append(f"{args_prefix}{key}")
args.append(str(value))
return args
def _get_simple_privacy_field_value(obj, field_info):
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
This function reads the metadata of a field to check if it's tagged with 'privacy'.
If the 'privacy' tag is present, then it modifies the value based on its type
for privacy concerns:
- int: returns -999
- float: returns -999.0
- bool: returns False
- str: if length > 5, masks the middle part and returns first and last char;
otherwise, returns "******"
Args:
obj: The dataclass instance.
field_info: A Field object that contains information about the dataclass field.
Returns:
The original or modified value of the field based on the privacy rules.
Example usage:
@dataclass
class Person:
name: str
age: int
ssn: str = field(metadata={"tags": "privacy"})
p = Person("Alice", 30, "123-45-6789")
print(_get_simple_privacy_field_value(p, Person.ssn)) # A******9
"""
tags = field_info.metadata.get("tags")
tags = [] if not tags else tags.split(",")
is_privacy = False
if tags and "privacy" in tags:
is_privacy = True
value = getattr(obj, field_info.name)
if not is_privacy or not value:
return value
field_type = EnvArgumentParser._get_argparse_type(field_info.type)
if field_type is int:
return -999
if field_type is float:
return -999.0
if field_type is bool:
return False
# str
if len(value) > 5:
return value[0] + "******" + value[-1]
return "******"
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
"""Get the value from the environment variable, ignoring the case of the key"""
if env_prefix:
env_key = env_prefix + env_key
return os.getenv(
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
)
def _genenv_ignoring_key_case_with_prefixes(
env_key: str, env_prefixes: List[str] = None, default_value=None
) -> str:
if env_prefixes:
for env_prefix in env_prefixes:
env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
if env_var_value:
return env_var_value
return _genenv_ignoring_key_case(env_key, default_value=default_value)
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
if not env_key:
return None
env_key = env_key.replace("-", "_")
return env_key + "_"
def parse_args_into_dataclass(
self,
dataclass_type: Type,
env_prefixes: List[str] = None,
command_args: List[str] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
env_var_value = _genenv_ignoring_key_case_with_prefixes(
field.name, env_prefixes
)
if env_var_value:
env_var_value = env_var_value.strip()
if field.type is int or field.type == Optional[int]:
env_var_value = int(env_var_value)
elif field.type is float or field.type == Optional[float]:
env_var_value = float(env_var_value)
elif field.type is bool or field.type == Optional[bool]:
env_var_value = env_var_value.lower() == "true"
elif field.type is str or field.type == Optional[str]:
pass
else:
raise ValueError(f"Unsupported parameter type {field.type}")
if not env_var_value:
env_var_value = kwargs.get(field.name)
# print(f"env_var_value: {env_var_value} for {field.name}")
# Add a command-line argument for this field
EnvArgumentParser._build_single_argparse_option(
parser, field, env_var_value
)
# Parse the command-line arguments
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
# cmd_args = parser.parse_args(args=command_args)
# print(f"cmd_args: {cmd_args}")
for field in fields(dataclass_type):
# cmd_line_value = getattr(cmd_args, field.name)
if field.name in cmd_args:
cmd_line_value = getattr(cmd_args, field.name)
if cmd_line_value is not None:
kwargs[field.name] = cmd_line_value
return dataclass_type(**kwargs)
@staticmethod
def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
for field in fields(dataclass_type):
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
argument_kwargs = {
"type": EnvArgumentParser._get_argparse_type(field.type),
"help": help_text,
"choices": valid_values,
"required": EnvArgumentParser._is_require_type(field.type),
}
if field.default != MISSING:
argument_kwargs["default"] = field.default
argument_kwargs["required"] = False
parser.add_argument(f"--{field.name}", **argument_kwargs)
return parser
@staticmethod
def _create_click_option_from_field(field_name: str, field: Type, is_func=True):
import click
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
cli_params = {
"default": None if field.default is MISSING else field.default,
"help": help_text,
"show_default": True,
"required": field.default is MISSING,
}
if valid_values:
cli_params["type"] = click.Choice(valid_values)
real_type = EnvArgumentParser._get_argparse_type(field.type)
if real_type is int:
cli_params["type"] = click.INT
elif real_type is float:
cli_params["type"] = click.FLOAT
elif real_type is str:
cli_params["type"] = click.STRING
elif real_type is bool:
cli_params["is_flag"] = True
name = f"--{field_name}"
if is_func:
return click.option(
name,
**cli_params,
)
else:
return click.Option([name], **cli_params)
@staticmethod
def create_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[], List[Type]] = None
):
import functools
from collections import OrderedDict
combined_fields = OrderedDict()
if _dynamic_factory:
_types = _dynamic_factory()
if _types:
dataclass_types = list(_types)
for dataclass_type in dataclass_types:
for field in fields(dataclass_type):
if field.name not in combined_fields:
combined_fields[field.name] = field
def decorator(func):
for field_name, field in reversed(combined_fields.items()):
option_decorator = EnvArgumentParser._create_click_option_from_field(
field_name, field
)
func = option_decorator(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
@staticmethod
def _create_raw_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[], List[Type]] = None
):
combined_fields = _merge_dataclass_types(
*dataclass_types, _dynamic_factory=_dynamic_factory
)
options = []
for field_name, field in reversed(combined_fields.items()):
options.append(
EnvArgumentParser._create_click_option_from_field(
field_name, field, is_func=False
)
)
return options
@staticmethod
def create_argparse_option(
*dataclass_types: Type, _dynamic_factory: Callable[[], List[Type]] = None
) -> argparse.ArgumentParser:
combined_fields = _merge_dataclass_types(
*dataclass_types, _dynamic_factory=_dynamic_factory
)
parser = argparse.ArgumentParser()
for _, field in reversed(combined_fields.items()):
EnvArgumentParser._build_single_argparse_option(parser, field)
return parser
@staticmethod
def _build_single_argparse_option(
parser: argparse.ArgumentParser, field, default_value=None
):
# Add a command-line argument for this field
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
short_name = field.metadata.get("short", None)
argument_kwargs = {
"type": EnvArgumentParser._get_argparse_type(field.type),
"help": help_text,
"choices": valid_values,
"required": EnvArgumentParser._is_require_type(field.type),
}
if field.default != MISSING:
argument_kwargs["default"] = field.default
argument_kwargs["required"] = False
if default_value:
argument_kwargs["default"] = default_value
argument_kwargs["required"] = False
if field.type is bool or field.type == Optional[bool]:
argument_kwargs["action"] = "store_true"
del argument_kwargs["type"]
del argument_kwargs["choices"]
names = []
if short_name:
names.append(f"-{short_name}")
names.append(f"--{field.name}")
parser.add_argument(*names, **argument_kwargs)
@staticmethod
def _get_argparse_type(field_type: Type) -> Type:
# Return the appropriate type for argparse to use based on the field type
if field_type is int or field_type == Optional[int]:
return int
elif field_type is float or field_type == Optional[float]:
return float
elif field_type is bool or field_type == Optional[bool]:
return bool
elif field_type is str or field_type == Optional[str]:
return str
else:
raise ValueError(f"Unsupported parameter type {field_type}")
@staticmethod
def _get_argparse_type_str(field_type: Type) -> str:
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
if argparse_type is int:
return "int"
elif argparse_type is float:
return "float"
elif argparse_type is bool:
return "bool"
else:
return "str"
@staticmethod
def _is_require_type(field_type: Type) -> bool:
return field_type not in [Optional[int], Optional[float], Optional[bool]]
@staticmethod
def _kwargs_to_env_key_value(
kwargs: Dict, prefix: str = "__dbgpt_gunicorn__env_prefix__"
) -> Dict[str, str]:
return {prefix + k: str(v) for k, v in kwargs.items()}
@staticmethod
def _read_env_key_value(
prefix: str = "__dbgpt_gunicorn__env_prefix__",
) -> List[str]:
env_args = []
for key, value in os.environ.items():
if key.startswith(prefix):
arg_key = "--" + key.replace(prefix, "")
if value.lower() in ["true", "1"]:
# Flag args
env_args.append(arg_key)
elif not value.lower() in ["false", "0"]:
env_args.extend([arg_key, value])
return env_args
def _merge_dataclass_types(
*dataclass_types: Type, _dynamic_factory: Callable[[], List[Type]] = None
) -> OrderedDict:
combined_fields = OrderedDict()
if _dynamic_factory:
_types = _dynamic_factory()
if _types:
dataclass_types = list(_types)
for dataclass_type in dataclass_types:
for field in fields(dataclass_type):
if field.name not in combined_fields:
combined_fields[field.name] = field
return combined_fields
def _type_str_to_python_type(type_str: str) -> Type:
type_mapping: Dict[str, Type] = {
"int": int,
"float": float,
"bool": bool,
"str": str,
}
return type_mapping.get(type_str, str)
def _get_parameter_descriptions(
dataclass_type: Type, **kwargs
) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
ext_metadata = {
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
}
default_value = field.default if field.default != MISSING else None
if field.name in kwargs:
default_value = kwargs[field.name]
descriptions.append(
ParameterDescription(
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
required=field.default is MISSING,
default_value=default_value,
valid_values=field.metadata.get("valid_values", None),
ext_metadata=ext_metadata,
)
)
return descriptions
def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
from dbgpt.util.module_utils import import_from_string
if not desc:
raise ValueError("Parameter descriptions cant be empty")
param_class_str = desc[0].param_class
if param_class_str:
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
module_name, _, class_name = param_class_str.rpartition(".")
fields_dict = {} # This will store field names and their default values or field()
annotations = {} # This will store the type annotations for the fields
for d in desc:
metadata = d.ext_metadata if d.ext_metadata else {}
metadata["help"] = d.description
metadata["valid_values"] = d.valid_values
annotations[d.param_name] = _type_str_to_python_type(
d.param_type
) # Set type annotation
fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)
# Create the new class. Note the setting of __annotations__ for type hints
new_class = type(
class_name, (object,), {**fields_dict, "__annotations__": annotations}
)
result_class = dataclass(new_class) # Make it a dataclass
return result_class
def _extract_parameter_details(
parser: argparse.ArgumentParser,
param_class: str = None,
skip_names: List[str] = None,
overwrite_default_values: Dict = {},
) -> List[ParameterDescription]:
descriptions = []
for action in parser._actions:
if (
action.default == argparse.SUPPRESS
): # typically this means the argument was not provided
continue
# determine parameter class (store_true/store_false are flags)
flag_or_option = (
"flag" if isinstance(action, argparse._StoreConstAction) else "option"
)
# extract parameter name (use the first option string, typically the long form)
param_name = action.option_strings[0] if action.option_strings else action.dest
if param_name.startswith("--"):
param_name = param_name[2:]
if param_name.startswith("-"):
param_name = param_name[1:]
param_name = param_name.replace("-", "_")
if skip_names and param_name in skip_names:
continue
# gather other details
default_value = action.default
if param_name in overwrite_default_values:
default_value = overwrite_default_values[param_name]
arg_type = (
action.type if not callable(action.type) else str(action.type.__name__)
)
description = action.help
# determine if the argument is required
required = action.required
# extract valid values for choices, if provided
valid_values = action.choices if action.choices is not None else None
# set ext_metadata as an empty dict for now, can be updated later if needed
ext_metadata = {}
descriptions.append(
ParameterDescription(
param_class=param_class,
param_name=param_name,
param_type=arg_type,
default_value=default_value,
description=description,
required=required,
valid_values=valid_values,
ext_metadata=ext_metadata,
)
)
return descriptions
def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
if not obj:
return None
if is_dataclass(type(obj)):
params = {}
for field_info in fields(obj):
value = _get_simple_privacy_field_value(obj, field_info)
params[field_info.name] = value
return params
if isinstance(obj, dict):
return obj
return default_value
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
from dbgpt._private import pydantic
version = int(pydantic.VERSION.split(".")[0])
schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
required_fields = set(schema.get("required", []))
param_descs = []
for field_name, field_schema in schema.get("properties", {}).items():
field = model_cls.model_fields[field_name] if version >= 2 else model_cls.__fields__[field_name]
param_type = field_schema.get("type")
if not param_type and "anyOf" in field_schema:
for any_of in field_schema["anyOf"]:
if any_of["type"] != "null":
param_type = any_of["type"]
break
if version >= 2:
default_value = (
field.default
if hasattr(field, "default")
and str(field.default) != "PydanticUndefined"
else None
)
else:
default_value = (
field.default
if not field.allow_none
else (
field.default_factory() if callable(field.default_factory) else None
)
)
description = field_schema.get("description", "")
is_required = field_name in required_fields
valid_values = None
ext_metadata = None
if hasattr(field, "field_info"):
valid_values = (
list(field.field_info.choices)
if hasattr(field.field_info, "choices")
else None
)
ext_metadata = (
field.field_info.extra if hasattr(field.field_info, "extra") else None
)
param_class = (f"{model_cls.__module__}.{model_cls.__name__}",)
param_desc = ParameterDescription(
param_class=param_class,
param_name=field_name,
param_type=param_type,
default_value=default_value,
description=description,
required=is_required,
valid_values=valid_values,
ext_metadata=ext_metadata,
)
param_descs.append(param_desc)
return param_descs
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}
def parse(self, args=None):
import sys
if args is None:
args = sys.argv[1:]
else:
args = list(args)
prev_arg = None
for arg in args:
if arg.startswith("--"):
if prev_arg:
self.params[prev_arg] = None
prev_arg = arg[2:]
else:
if prev_arg:
self.params[prev_arg] = arg
prev_arg = None
if prev_arg:
self.params[prev_arg] = None
def _get_param(self, key):
return self.params.get(key.replace("_", "-")) or self.params.get(key)
def __getattr__(self, item):
return self._get_param(item)
def __getitem__(self, key):
return self._get_param(key)
def get(self, key, default=None):
return self._get_param(key) or default
def __str__(self):
return "\n".join(
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
)
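# An illustrative sketch of _SimpleArgParser usage (the argument names below
# are hypothetical examples, not part of this module):
#
#     parser = _SimpleArgParser("model_name", "port")
#     parser.parse(["--model_name", "vicuna-13b", "--port", "8000"])
#     parser.model_name            # 'vicuna-13b'
#     parser.get("port")           # '8000' (values are kept as strings)
#     parser.get("missing", "x")   # 'x'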
def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
import click
class LazyCommand(click.Command):
def __init__(self, *args, **kwargs):
super(LazyCommand, self).__init__(*args, **kwargs)
self.dynamic_params_added = False
def get_params(self, ctx):
if ctx and not self.dynamic_params_added:
dynamic_params = EnvArgumentParser._create_raw_click_option(
*dataclass_types, _dynamic_factory=_dynamic_factory
)
for param in reversed(dynamic_params):
self.params.append(param)
self.dynamic_params_added = True
return super(LazyCommand, self).get_params(ctx)
return LazyCommand
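# A minimal, illustrative sketch of parsing environment variables and command-line
# arguments into a dataclass with EnvArgumentParser. The dataclass, the "demo_"
# prefix and the environment variable below are hypothetical examples.
if __name__ == "__main__":

    @dataclass
    class _DemoConfig:
        """Hypothetical demo configuration."""

        host: str = field(default="127.0.0.1", metadata={"help": "Bind host"})
        port: int = field(default=8000, metadata={"help": "Bind port"})

    os.environ["demo_port"] = "9000"
    demo_config = EnvArgumentParser().parse_args_into_dataclass(
        _DemoConfig, env_prefixes=["demo_"], command_args=["--host", "0.0.0.0"]
    )
    print(demo_config)  # _DemoConfig(host='0.0.0.0', port=9000)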

6
dbgpt/util/path_utils.py Normal file
View File

@@ -0,0 +1,6 @@
import os
def has_path(filename):
directory = os.path.dirname(filename)
return bool(directory)
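# Illustrative behaviour (hypothetical paths):
#
#     has_path("/tmp/data.csv")  # True  -> the name includes a directory part
#     has_path("data.csv")       # False -> bare filename, no directory part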

6
dbgpt/util/pd_utils.py Normal file
View File

@@ -0,0 +1,6 @@
def csv_colunm_foramt(val):
if str(val).find("$") >= 0:
return float(val.replace("$", "").replace(",", ""))
if str(val).find("¥") >= 0:
return float(val.replace("¥", "").replace(",", ""))
return val
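# Illustrative behaviour (hypothetical values):
#
#     csv_colunm_foramt("$1,234.50")   # 1234.5 (currency symbol and commas stripped)
#     csv_colunm_foramt("¥3,000")      # 3000.0
#     csv_colunm_foramt("plain text")  # returned unchanged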

239
dbgpt/util/prompt_util.py Normal file
View File

@@ -0,0 +1,239 @@
"""General prompt helper that can help deal with LLM context window token limitations.
At its core, it calculates the available context size by starting with the context window
size of an LLM and reserving token space for the prompt template and the output.
It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM calls
needed), or truncating them so that they fit in a single LLM call.
"""
import logging
from string import Formatter
from typing import Callable, List, Optional, Sequence
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
from dbgpt.util.global_helper import globals_helper
from dbgpt._private.llm_metadata import LLMMetadata
from dbgpt.rag.embedding_engine.loader.token_splitter import TokenTextSplitter
DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
DEFAULT_CONTEXT_WINDOW = 3000 # tokens
DEFAULT_NUM_OUTPUTS = 256 # tokens
logger = logging.getLogger(__name__)
class PromptHelper(BaseModel):
"""Prompt helper.
General prompt helper that can help deal with LLM context window token limitations.
At its core, it calculates the available context size by starting with the context
window size of an LLM and reserving token space for the prompt template and the
output.
It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM
calls needed), or truncating them so that they fit in a single LLM call.
Args:
context_window (int): Context window for the LLM.
num_output (int): Number of outputs for the LLM.
chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size
chunk_size_limit (Optional[int]): Maximum chunk size to use.
tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
separator (str): Separator for text splitter
"""
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description="The maximum context size that will get sent to the LLM.",
)
num_output: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="The amount of token-space to leave in input for generation.",
)
chunk_overlap_ratio: float = Field(
default=DEFAULT_CHUNK_OVERLAP_RATIO,
description="The percentage token amount that each chunk should overlap.",
)
chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
separator: str = Field(
default=" ", description="The separator when chunking tokens."
)
_tokenizer: Callable[[str], List] = PrivateAttr()
def __init__(
self,
context_window: int = DEFAULT_CONTEXT_WINDOW,
num_output: int = DEFAULT_NUM_OUTPUTS,
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
) -> None:
"""Init params."""
if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")
# TODO: make configurable
self._tokenizer = tokenizer or globals_helper.tokenizer
super().__init__(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
separator=separator,
)
@classmethod
def from_llm_metadata(
cls,
llm_metadata: LLMMetadata,
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
) -> "PromptHelper":
"""Create from llm predictor.
This will autofill values like context_window and num_output.
"""
context_window = llm_metadata.context_window
if llm_metadata.num_output == -1:
num_output = DEFAULT_NUM_OUTPUTS
else:
num_output = llm_metadata.num_output
return cls(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
tokenizer=tokenizer,
separator=separator,
)
@classmethod
def class_name(cls) -> str:
return "PromptHelper"
def _get_available_context_size(self, template: str) -> int:
"""Get available context size.
This is calculated as:
available context window = total context window
- input (partially filled prompt)
- output (room reserved for response)
Notes:
- Available context size is further clamped to be non-negative.
"""
empty_prompt_txt = get_empty_prompt_txt(template)
num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
context_size_tokens = (
self.context_window - num_empty_prompt_tokens - self.num_output
)
if context_size_tokens < 0:
raise ValueError(
f"Calculated available context size {context_size_tokens} was"
" not non-negative."
)
return context_size_tokens
def _get_available_chunk_size(
self, prompt_template: str, num_chunks: int = 1, padding: int = 5
) -> int:
"""Get available chunk size.
This is calculated as:
available chunk size = available context window // number_chunks
- padding
Notes:
- By default, we use padding of 5 (to save space for formatting needs).
- Available chunk size is further clamped to chunk_size_limit if specified.
"""
available_context_size = self._get_available_context_size(prompt_template)
result = available_context_size // num_chunks - padding
if self.chunk_size_limit is not None:
result = min(result, self.chunk_size_limit)
return result
def get_text_splitter_given_prompt(
self,
prompt_template: str,
num_chunks: int = 1,
padding: int = DEFAULT_PADDING,
) -> TokenTextSplitter:
"""Get text splitter configured to maximally pack available context window,
taking into account the given prompt and the desired number of chunks.
"""
chunk_size = self._get_available_chunk_size(
prompt_template, num_chunks, padding=padding
)
if chunk_size <= 0:
raise ValueError(f"Chunk size {chunk_size} is not positive.")
chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
return TokenTextSplitter(
separator=self.separator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
tokenizer=self._tokenizer,
)
def repack(
self,
prompt_template: str,
text_chunks: Sequence[str],
padding: int = DEFAULT_PADDING,
) -> List[str]:
"""Repack text chunks to fit available context window.
This will combine text chunks into consolidated chunks
that more fully "pack" the prompt template given the context_window.
"""
text_splitter = self.get_text_splitter_given_prompt(
prompt_template, padding=padding
)
combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
return text_splitter.split_text(combined_str)
def get_empty_prompt_txt(template: str) -> str:
"""Get empty prompt text.
Substitute empty strings in parts of the prompt that have
not yet been filled out. Skip variables that have already
been partially formatted. This is used to compute the initial tokens.
"""
# partial_kargs = prompt.kwargs
partial_kargs = {}
template_vars = get_template_vars(template)
empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
all_kwargs = {**partial_kargs, **empty_kwargs}
prompt = template.format(**all_kwargs)
return prompt
def get_template_vars(template_str: str) -> List[str]:
"""Get template variables from a template string."""
variables = []
formatter = Formatter()
for _, variable_name, _, _ in formatter.parse(template_str):
if variable_name:
variables.append(variable_name)
return variables
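# A small, illustrative sketch of the module-level template helpers. The template
# string below is a hypothetical example.
if __name__ == "__main__":
    demo_template = "Context: {context}\nQuestion: {question}\nAnswer:"
    print(get_template_vars(demo_template))  # ['context', 'question']
    # Variables are substituted with empty strings; the result is what the prompt
    # itself costs in tokens before any chunks are packed into it.
    print(get_empty_prompt_txt(demo_template))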

View File

@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Dict, Type
import json
from dbgpt.core.interface.serialization import Serializable, Serializer
JSON_ENCODING = "utf-8"
class JsonSerializable(Serializable, ABC):
@abstractmethod
def to_dict(self) -> Dict:
"""Return the dict of current serializable object"""
def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission."""
return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
class JsonSerializer(Serializer):
"""The serializer abstract class for serializing cache keys and values."""
def serialize(self, obj: Serializable) -> bytes:
"""Serialize a cache object.
Args:
obj (Serializable): The object to serialize
"""
return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
"""Deserialize data back into a cache object of the specified type.
Args:
data (bytes): The byte array to deserialize
cls (Type[Serializable]): The type of current object
Returns:
Serializable: The serializable object
"""
# Convert bytes back to JSON and then to the specified class
json_data = json.loads(data.decode(JSON_ENCODING))
# Assume that the cls has an __init__ that accepts a dictionary
return cls(**json_data)
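# A minimal, illustrative round-trip sketch. The _DemoValue class below is
# hypothetical and relies on duck typing: JsonSerializer only needs to_dict()
# for serialization and a constructor that accepts the same keys.
if __name__ == "__main__":

    class _DemoValue:
        def __init__(self, key: str, value: int):
            self.key = key
            self.value = value

        def to_dict(self) -> Dict:
            return {"key": self.key, "value": self.value}

    demo_serializer = JsonSerializer()
    raw = demo_serializer.serialize(_DemoValue("answer", 42))  # b'{"key": "answer", "value": 42}'
    restored = demo_serializer.deserialize(raw, _DemoValue)
    print(restored.to_dict())  # {'key': 'answer', 'value': 42}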

24
dbgpt/util/singleton.py Normal file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""The singleton metaclass for ensuring only one instance of a class."""
import abc
from typing import Any
class Singleton(abc.ABCMeta, type):
"""Singleton metaclass for ensuring only one instance of a class"""
_instances = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
"""Call method for the singleton metaclass"""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class AbstractSingleton(abc.ABC, metaclass=Singleton):
"""Abstract singleton class for ensuring only one instance of a class"""
pass
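# Illustrative sketch: two constructions of a class using the Singleton metaclass
# return the same instance. The _DemoRegistry class is hypothetical.
if __name__ == "__main__":

    class _DemoRegistry(metaclass=Singleton):
        def __init__(self):
            self.items = []

    a = _DemoRegistry()
    b = _DemoRegistry()
    a.items.append("shared")
    print(a is b, b.items)  # True ['shared']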

View File

50
dbgpt/util/speech/base.py Normal file
View File

@@ -0,0 +1,50 @@
"""Base class for all voice classes."""
import abc
from threading import Lock
from dbgpt.util.singleton import AbstractSingleton
class VoiceBase(AbstractSingleton):
"""
Base class for all voice classes.
"""
def __init__(self):
"""
Initialize the voice class.
"""
self._url = None
self._headers = None
self._api_key = None
self._voices = []
self._mutex = Lock()
self._setup()
def say(self, text: str, voice_index: int = 0) -> bool:
"""
Say the given text.
Args:
text (str): The text to say.
voice_index (int): The index of the voice to use.
"""
with self._mutex:
return self._speech(text, voice_index)
@abc.abstractmethod
def _setup(self) -> None:
"""
Setup the voices, API key, etc.
"""
pass
@abc.abstractmethod
def _speech(self, text: str, voice_index: int = 0) -> bool:
"""
Play the given text.
Args:
text (str): The text to play.
"""
pass

View File

@@ -0,0 +1,44 @@
import logging
import os
import requests
from dbgpt.util.speech.base import VoiceBase
class BrianSpeech(VoiceBase):
"""Brian speech module for autogpt"""
def _setup(self) -> None:
"""Setup the voices, API key, etc."""
pass
def _speech(self, text: str, _: int = 0) -> bool:
"""Speak text using Brian with the streamelements API
Args:
text (str): The text to speak
Returns:
bool: True if the request was successful, False otherwise
"""
from playsound import playsound
tts_url = (
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
)
response = requests.get(tts_url)
if response.status_code == 200:
with open("speech.mp3", "wb") as f:
f.write(response.content)
playsound("speech.mp3")
os.remove("speech.mp3")
return True
else:
logging.error(
"Request failed with status code: %s, response content: %s",
response.status_code,
response.content,
)
return False

View File

@@ -0,0 +1,89 @@
"""ElevenLabs speech module"""
import os
import logging
import requests
from dbgpt._private.config import Config
from dbgpt.util.speech.base import VoiceBase
PLACEHOLDERS = {"your-voice-id"}
logger = logging.getLogger(__name__)
class ElevenLabsSpeech(VoiceBase):
"""ElevenLabs speech class"""
def _setup(self) -> None:
"""Set up the voices, API key, etc.
Returns:
None: None
"""
cfg = Config()
default_voices = ["ErXwobaYiN019PkySvjV", "EXAVITQu4vr4xnSDxMaL"]
voice_options = {
"Rachel": "21m00Tcm4TlvDq8ikWAM",
"Domi": "AZnzlk1XvdvUeBnXmlld",
"Bella": "EXAVITQu4vr4xnSDxMaL",
"Antoni": "ErXwobaYiN019PkySvjV",
"Elli": "MF3mGyEYCl7XYWbV9V6O",
"Josh": "TxGEqnHWrfWFTfGW9XjX",
"Arnold": "VR6AewLTigWG4xSOukaG",
"Adam": "pNInz6obpgDQGcFmaJgB",
"Sam": "yoZ06aMxZJJ28mfd3POQ",
}
self._headers = {
"Content-Type": "application/json",
"xi-api_v1-key": cfg.elevenlabs_api_key,
}
self._voices = default_voices.copy()
if cfg.elevenlabs_voice_1_id in voice_options:
cfg.elevenlabs_voice_1_id = voice_options[cfg.elevenlabs_voice_1_id]
if cfg.elevenlabs_voice_2_id in voice_options:
cfg.elevenlabs_voice_2_id = voice_options[cfg.elevenlabs_voice_2_id]
self._use_custom_voice(cfg.elevenlabs_voice_1_id, 0)
self._use_custom_voice(cfg.elevenlabs_voice_2_id, 1)
def _use_custom_voice(self, voice, voice_index) -> None:
"""Use a custom voice if provided and not a placeholder
Args:
voice (str): The voice ID
voice_index (int): The voice index
Returns:
None: None
"""
# Placeholder values that should be treated as empty
if voice and voice not in PLACEHOLDERS:
self._voices[voice_index] = voice
def _speech(self, text: str, voice_index: int = 0) -> bool:
"""Speak text using elevenlabs.io's API
Args:
text (str): The text to speak
voice_index (int, optional): The voice to use. Defaults to 0.
Returns:
bool: True if the request was successful, False otherwise
"""
from playsound import playsound
tts_url = (
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
)
response = requests.post(tts_url, headers=self._headers, json={"text": text})
if response.status_code == 200:
with open("speech.mpeg", "wb") as f:
f.write(response.content)
playsound("speech.mpeg", True)
os.remove("speech.mpeg")
return True
else:
logger.warn("Request failed with status code:", response.status_code)
logger.info("Response content:", response.content)
return False

23
dbgpt/util/speech/gtts.py Normal file
View File

@@ -0,0 +1,23 @@
""" GTTS Voice. """
import os
import gtts
from dbgpt.util.speech.base import VoiceBase
class GTTSVoice(VoiceBase):
"""GTTS Voice."""
def _setup(self) -> None:
pass
def _speech(self, text: str, _: int = 0) -> bool:
"""Play the given text."""
from playsound import playsound
tts = gtts.gTTS(text)
tts.save("speech.mp3")
playsound("speech.mp3", True)
os.remove("speech.mp3")
return True

View File

@@ -0,0 +1,21 @@
""" MacOS TTS Voice. """
import os
from dbgpt.util.speech.base import VoiceBase
class MacOSTTS(VoiceBase):
"""MacOS TTS Voice."""
def _setup(self) -> None:
pass
def _speech(self, text: str, voice_index: int = 0) -> bool:
"""Play the given text."""
if voice_index == 0:
os.system(f'say "{text}"')
elif voice_index == 1:
os.system(f'say -v "Ava (Premium)" "{text}"')
else:
os.system(f'say -v Samantha "{text}"')
return True

46
dbgpt/util/speech/say.py Normal file
View File

@@ -0,0 +1,46 @@
""" Text to speech module """
import threading
from threading import Semaphore
from dbgpt._private.config import Config
from dbgpt.util.speech.base import VoiceBase
from dbgpt.util.speech.brian import BrianSpeech
from dbgpt.util.speech.eleven_labs import ElevenLabsSpeech
from dbgpt.util.speech.gtts import GTTSVoice
from dbgpt.util.speech.macos_tts import MacOSTTS
_QUEUE_SEMAPHORE = Semaphore(
1
) # The amount of sounds to queue before blocking the main thread
def say_text(text: str, voice_index: int = 0) -> None:
"""Speak the given text using the given voice index"""
cfg = Config()
default_voice_engine, voice_engine = _get_voice_engine(cfg)
def speak() -> None:
success = voice_engine.say(text, voice_index)
if not success:
default_voice_engine.say(text)
_QUEUE_SEMAPHORE.release()
_QUEUE_SEMAPHORE.acquire(True)
thread = threading.Thread(target=speak)
thread.start()
def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]:
"""Get the voice engine to use for the given configuration"""
default_voice_engine = GTTSVoice()
if config.elevenlabs_api_key:
voice_engine = ElevenLabsSpeech()
elif config.use_mac_os_tts == "True":
voice_engine = MacOSTTS()
elif config.use_brian_tts == "True":
voice_engine = BrianSpeech()
else:
voice_engine = GTTSVoice()
return default_voice_engine, voice_engine

View File

@@ -0,0 +1,81 @@
import re
def is_all_chinese(text):
### Determine whether the string is pure Chinese
pattern = re.compile(r"^[一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_number_chinese(text):
### Determine whether the string is numbers and Chinese
pattern = re.compile(r"^[\d一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_chinese_include_number(text):
### Determine whether the string is pure Chinese or Chinese containing numbers
pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
match = re.match(pattern, text)
return match is not None
def is_scientific_notation(string):
# Regular expression for scientific notation
pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
# Match the string against the pattern
match = re.match(pattern, str(string))
return match is not None
def extract_content(long_string, s1, s2, is_include: bool = False):
# extract text
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if is_include:
end_index = long_string.find(s2, start_index + len(s1) + 1)
extracted_content = long_string[start_index : end_index + len(s2)]
else:
end_index = long_string.find(s2, start_index + len(s1))
extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index = long_string.find(s1, start_index + 1)
return match_map
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
# extract text open ending
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if long_string.find(s2, start_index) <= 0:
end_index = len(long_string)
else:
if is_include:
end_index = long_string.find(s2, start_index + len(s1) + 1)
else:
end_index = long_string.find(s2, start_index + len(s1))
if is_include:
extracted_content = long_string[start_index : end_index + len(s2)]
else:
extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index = long_string.find(s1, start_index + 1)
return match_map
if __name__ == "__main__":
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
s1 = "123"
s2 = "456"
print(extract_content_open_ending(s, s1, s2, True))
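# Expected output of the demo above (keys are match start offsets; the last
# "123" has no closing "456", so the match runs to the end of the string):
# {4: '123efghijkjhhh456', 24: '123aa456', 35: '123bb456', 45: '123'}
# For comparison, extract_content(s, s1, s2) keeps only the text between the
# markers: {4: 'efghijkjhhh', 24: 'aa', 35: 'bb'}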

272
dbgpt/util/system_utils.py Normal file
View File

@@ -0,0 +1,272 @@
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Tuple, Dict
import os
import platform
import subprocess
import re
from functools import cache
@dataclass
class SystemInfo:
platform: str
distribution: str
python_version: str
cpu: str
cpu_avx: str
memory: str
torch_version: str
device: str
device_version: str
device_count: int
device_other: str
def to_dict(self) -> Dict:
return asdict(self)
class AVXType(Enum):
BASIC = "basic"
AVX = "AVX"
AVX2 = "AVX2"
AVX512 = "AVX512"
@staticmethod
def of_type(avx: str):
for item in AVXType:
if item.value == avx:
return item
return None
class OSType(str, Enum):
WINDOWS = "win"
LINUX = "linux"
DARWIN = "darwin"
OTHER = "other"
def get_cpu_avx_support() -> Tuple[OSType, AVXType, str]:
system = platform.system()
os_type = OSType.OTHER
cpu_avx = AVXType.BASIC
env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX"))
distribution = "Unknown Distribution"
if "windows" in system.lower():
os_type = OSType.WINDOWS
output = "avx2"
distribution = "Windows " + platform.release()
print("Current platform is windows, use avx2 as default cpu architecture")
elif system == "Linux":
os_type = OSType.LINUX
result = subprocess.run(
["lscpu"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
output = result.stdout.decode()
distribution = get_linux_distribution()
elif system == "Darwin":
os_type = OSType.DARWIN
result = subprocess.run(
["sysctl", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
distribution = "Mac OS " + platform.mac_ver()[0]
output = result.stdout.decode()
else:
os_type = OSType.OTHER
print("Unsupported OS to get cpu avx, use default")
return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
if "avx512" in output.lower():
cpu_avx = AVXType.AVX512
elif "avx2" in output.lower():
cpu_avx = AVXType.AVX2
elif "avx " in output.lower():
# cpu_avx = AVXType.AVX
pass
return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
def get_device() -> str:
try:
import torch
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
except ModuleNotFoundError:
return "cpu"
def get_device_info() -> Tuple[str, str, str, int, str]:
torch_version, device, device_version, device_count, device_other = (
None,
"cpu",
None,
0,
"",
)
try:
import torch
torch_version = torch.__version__
if torch.cuda.is_available():
device = "cuda"
device_version = torch.version.cuda
device_count = torch.cuda.device_count()
elif torch.backends.mps.is_available():
device = "mps"
except ModuleNotFoundError:
pass
if not device_version:
device_version = (
get_cuda_version_from_nvcc() or get_cuda_version_from_nvidia_smi()
)
if device == "cuda":
try:
output = subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=name,driver_version,memory.total,memory.free,memory.used",
"--format=csv",
]
)
device_other = output.decode("utf-8")
except Exception:
pass
return torch_version, device, device_version, device_count, device_other
def get_cuda_version_from_nvcc():
try:
output = subprocess.check_output(["nvcc", "--version"])
version_line = [
line for line in output.decode("utf-8").split("\n") if "release" in line
][0]
return version_line.split("release")[-1].strip().split(",")[0]
except Exception:
return None
def get_cuda_version_from_nvidia_smi():
try:
output = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
match = re.search(r"CUDA Version:\s+(\d+\.\d+)", output)
if match:
return match.group(1)
else:
return None
except Exception:
return None
def get_linux_distribution():
"""Get distribution of Linux"""
if os.path.isfile("/etc/os-release"):
with open("/etc/os-release", "r") as f:
info = {}
for line in f:
key, _, value = line.partition("=")
info[key] = value.strip().strip('"')
return f"{info.get('NAME', 'Unknown')} {info.get('VERSION_ID', '')}".strip()
return "Unknown Linux Distribution"
def get_cpu_info():
# Getting platform
os_type, avx_type, distribution = get_cpu_avx_support()
# Getting CPU information
cpu_info = "Unknown CPU"
if os_type == OSType.LINUX:
try:
output = subprocess.check_output(["lscpu"]).decode("utf-8")
match = re.search(r".*Model name:\s*(.+)", output)
if match:
cpu_info = match.group(1).strip()
match = re.search(f".*型号名称:\s*(.+)", output)
if match:
cpu_info = match.group(1).strip()
except Exception:
pass
elif os_type == OSType.DARWIN:
try:
output = subprocess.check_output(
["sysctl", "machdep.cpu.brand_string"]
).decode("utf-8")
match = re.search(r"machdep.cpu.brand_string:\s*(.+)", output)
if match:
cpu_info = match.group(1).strip()
except Exception:
pass
elif os_type == OSType.WINDOWS:
try:
output = subprocess.check_output("wmic cpu get Name", shell=True).decode(
"utf-8"
)
lines = output.splitlines()
cpu_info = lines[2].split(":")[-1].strip()
except Exception:
pass
return os_type, avx_type, cpu_info, distribution
def get_memory_info(os_type: OSType) -> str:
memory = "Unknown Memory"
try:
import psutil
memory = f"{psutil.virtual_memory().total // (1024 ** 3)} GB"
except ImportError:
pass
if os_type == OSType.LINUX:
try:
with open("/proc/meminfo", "r") as f:
mem_info = f.readlines()
for line in mem_info:
if "MemTotal" in line:
memory = line.split(":")[1].strip()
break
except Exception:
pass
return memory
@cache
def get_system_info() -> SystemInfo:
"""Get System information"""
os_type, avx_type, cpu_info, distribution = get_cpu_info()
# Getting Python version
python_version = platform.python_version()
memory = get_memory_info(os_type)
(
torch_version,
device,
device_version,
device_count,
device_other,
) = get_device_info()
return SystemInfo(
platform=os_type.value,
distribution=distribution,
python_version=python_version,
cpu=cpu_info,
cpu_avx=avx_type.value,
memory=memory,
torch_version=torch_version,
device=device,
device_version=device_version,
device_count=device_count,
device_other=device_other,
)
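# Illustrative usage: get_system_info() is cached via functools.cache, so repeated
# calls after the first one are cheap.
if __name__ == "__main__":
    print(get_system_info().to_dict())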

View File

@@ -0,0 +1,81 @@
import argparse
import pytest
from dbgpt.util.parameter_utils import _extract_parameter_details
def create_parser():
parser = argparse.ArgumentParser()
return parser
@pytest.mark.parametrize(
"argument, expected_param_name, default_value, param_type, expected_param_type, description",
[
("--option", "option", "value", str, "str", "An option argument"),
("-option", "option", "value", str, "str", "An option argument"),
("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
],
)
def test_extract_parameter_details_option_argument(
argument,
expected_param_name,
default_value,
param_type,
expected_param_type,
description,
):
parser = create_parser()
parser.add_argument(
argument, default=default_value, type=param_type, help=description
)
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == expected_param_name
assert desc.param_type == expected_param_type
assert desc.default_value == default_value
assert desc.description == description
assert desc.required == False
assert desc.valid_values is None
def test_extract_parameter_details_flag_argument():
parser = create_parser()
parser.add_argument("--flag", action="store_true", help="A flag argument")
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == "flag"
assert desc.description == "A flag argument"
assert desc.required == False
def test_extract_parameter_details_choice_argument():
parser = create_parser()
parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument")
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == "choice"
assert desc.valid_values == ["A", "B", "C"]
def test_extract_parameter_details_required_argument():
parser = create_parser()
parser.add_argument(
"--required", required=True, type=int, help="A required argument"
)
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == "required"
assert desc.required == True

View File

@@ -0,0 +1,39 @@
from dbgpt.util.tracer.base import (
SpanType,
Span,
SpanTypeRunName,
Tracer,
SpanStorage,
SpanStorageType,
TracerContext,
)
from dbgpt.util.tracer.span_storage import (
MemorySpanStorage,
FileSpanStorage,
SpanStorageContainer,
)
from dbgpt.util.tracer.tracer_impl import (
root_tracer,
trace,
initialize_tracer,
DefaultTracer,
TracerManager,
)
__all__ = [
"SpanType",
"Span",
"SpanTypeRunName",
"Tracer",
"SpanStorage",
"SpanStorageType",
"TracerContext",
"MemorySpanStorage",
"FileSpanStorage",
"SpanStorageContainer",
"root_tracer",
"trace",
"initialize_tracer",
"DefaultTracer",
"TracerManager",
]

189
dbgpt/util/tracer/base.py Normal file
View File

@@ -0,0 +1,189 @@
from __future__ import annotations
from typing import Dict, Callable, Optional, List
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import uuid
from datetime import datetime
from dbgpt.component import BaseComponent, SystemApp, ComponentType
class SpanType(str, Enum):
BASE = "base"
RUN = "run"
CHAT = "chat"
class SpanTypeRunName(str, Enum):
WEBSERVER = "Webserver"
WORKER_MANAGER = "WorkerManager"
MODEL_WORKER = "ModelWorker"
EMBEDDING_MODEL = "EmbeddingModel"
@staticmethod
def values():
return [item.value for item in SpanTypeRunName]
class Span:
"""Represents a unit of work that is being traced.
This can be any operation like a function call or a database query.
"""
def __init__(
self,
trace_id: str,
span_id: str,
span_type: SpanType = None,
parent_span_id: str = None,
operation_name: str = None,
metadata: Dict = None,
end_caller: Callable[[Span], None] = None,
):
if not span_type:
span_type = SpanType.BASE
self.span_type = span_type
# The unique identifier for the entire trace
self.trace_id = trace_id
# Unique identifier for this span within the trace
self.span_id = span_id
# Identifier of the parent span, if this is a child span
self.parent_span_id = parent_span_id
# Descriptive name for the operation being traced
self.operation_name = operation_name
# Timestamp when this span started
self.start_time = datetime.now()
# Timestamp when this span ended, initially None
self.end_time = None
# Additional metadata associated with the span
self.metadata = metadata
self._end_callers = []
if end_caller:
self._end_callers.append(end_caller)
def end(self, **kwargs):
"""Mark the end of this span by recording the current time."""
self.end_time = datetime.now()
if "metadata" in kwargs:
self.metadata = kwargs.get("metadata")
for caller in self._end_callers:
caller(self)
def add_end_caller(self, end_caller: Callable[[Span], None]):
if end_caller:
self._end_callers.append(end_caller)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.end()
return False
def to_dict(self) -> Dict:
return {
"span_type": self.span_type.value,
"trace_id": self.trace_id,
"span_id": self.span_id,
"parent_span_id": self.parent_span_id,
"operation_name": self.operation_name,
"start_time": None
if not self.start_time
else self.start_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
"end_time": None
if not self.end_time
else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
"metadata": self.metadata,
}
class SpanStorageType(str, Enum):
ON_CREATE = "on_create"
ON_END = "on_end"
ON_CREATE_END = "on_create_end"
class SpanStorage(BaseComponent, ABC):
"""Abstract base class for storing spans.
This allows different storage mechanisms (e.g., in-memory, database) to be implemented.
"""
name = ComponentType.TRACER_SPAN_STORAGE.value
def init_app(self, system_app: SystemApp):
"""Initialize the storage with the given application context."""
pass
@abstractmethod
def append_span(self, span: Span):
"""Store the given span. This needs to be implemented by subclasses."""
def append_span_batch(self, spans: List[Span]):
"""Store the span batch"""
for span in spans:
self.append_span(span)
class Tracer(BaseComponent, ABC):
"""Abstract base class for tracing operations.
Provides the core logic for starting, ending, and retrieving spans.
"""
name = ComponentType.TRACER.value
def __init__(self, system_app: SystemApp | None = None):
super().__init__(system_app)
self.system_app = system_app # Application context
def init_app(self, system_app: SystemApp):
"""Initialize the tracer with the given application context."""
self.system_app = system_app
@abstractmethod
def append_span(self, span: Span):
"""Append the given span to storage. This needs to be implemented by subclasses."""
@abstractmethod
def start_span(
self,
operation_name: str,
parent_span_id: str = None,
span_type: SpanType = None,
metadata: Dict = None,
) -> Span:
"""Begin a new span for the given operation. If provided, the span will be
a child of the span with the given parent_span_id.
"""
@abstractmethod
def end_span(self, span: Span, **kwargs):
"""
End the given span.
"""
@abstractmethod
def get_current_span(self) -> Optional[Span]:
"""
Retrieve the span that is currently being traced.
"""
@abstractmethod
def _get_current_storage(self) -> SpanStorage:
"""
Get the storage mechanism currently in use for storing spans.
This needs to be implemented by subclasses.
"""
def _new_uuid(self) -> str:
"""
Generate a new unique identifier.
"""
return str(uuid.uuid4())
@dataclass
class TracerContext:
span_id: Optional[str] = None
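# An illustrative sketch of the Span contract: spans work as context managers and
# notify an end_caller when they finish. The trace/span ids below are hypothetical
# literals; a real Tracer implementation generates them (e.g. via _new_uuid()).
if __name__ == "__main__":
    finished = []
    with Span(
        trace_id="trace-1",
        span_id="trace-1:span-1",
        operation_name="demo_operation",
        metadata={"key": "value"},
        end_caller=lambda s: finished.append(s),
    ):
        pass  # the traced work would happen here
    print(finished[0].to_dict()["operation_name"])  # demo_operation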

View File

@@ -0,0 +1,150 @@
import os
import json
import time
import datetime
import threading
import queue
import logging
from typing import Optional, List
from concurrent.futures import Executor, ThreadPoolExecutor
from dbgpt.component import SystemApp
from dbgpt.util.tracer.base import Span, SpanStorage
logger = logging.getLogger(__name__)
class MemorySpanStorage(SpanStorage):
def __init__(self, system_app: SystemApp | None = None):
super().__init__(system_app)
self.spans = []
self._lock = threading.Lock()
def append_span(self, span: Span):
with self._lock:
self.spans.append(span)
class SpanStorageContainer(SpanStorage):
def __init__(
self,
system_app: SystemApp | None = None,
batch_size=10,
flush_interval=10,
executor: Executor = None,
):
super().__init__(system_app)
if not executor:
executor = ThreadPoolExecutor(thread_name_prefix="trace_storage_sync_")
self.executor = executor
self.storages: List[SpanStorage] = []
self.last_date = (
datetime.datetime.now().date()
) # Store the current date for checking date changes
self.queue = queue.Queue()
self.batch_size = batch_size
self.flush_interval = flush_interval
self.last_flush_time = time.time()
self.flush_signal_queue = queue.Queue()
self.flush_thread = threading.Thread(
target=self._flush_to_storages, daemon=True
)
self.flush_thread.start()
def append_storage(self, storage: SpanStorage):
"""Append sotrage to container
Args:
storage ([`SpanStorage`]): The storage to be append to current container
"""
self.storages.append(storage)
def append_span(self, span: Span):
self.queue.put(span)
if self.queue.qsize() >= self.batch_size:
try:
self.flush_signal_queue.put_nowait(True)
except queue.Full:
pass # If the signal queue is full, it's okay. The flush thread will handle it.
def _flush_to_storages(self):
while True:
interval = time.time() - self.last_flush_time
if interval < self.flush_interval:
try:
self.flush_signal_queue.get(
block=True, timeout=self.flush_interval - interval
)
except Exception:
# Timeout
pass
spans_to_write = []
while not self.queue.empty():
spans_to_write.append(self.queue.get())
for s in self.storages:
def append_and_ignore_error(
storage: SpanStorage, spans_to_write: List[Span]
):
try:
storage.append_span_batch(spans_to_write)
except Exception as e:
logger.warning(
f"Append spans to storage {str(storage)} failed: {str(e)}, span_data: {spans_to_write}"
)
self.executor.submit(append_and_ignore_error, s, spans_to_write)
self.last_flush_time = time.time()
class FileSpanStorage(SpanStorage):
def __init__(self, filename: str):
super().__init__()
self.filename = filename
# Split filename into prefix and suffix
self.filename_prefix, self.filename_suffix = os.path.splitext(filename)
if not self.filename_suffix:
self.filename_suffix = ".log"
self.last_date = (
datetime.datetime.now().date()
) # Store the current date for checking date changes
self.queue = queue.Queue()
if not os.path.exists(filename):
# New file if not exist
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "a"):
pass
def append_span(self, span: Span):
self._write_to_file([span])
def append_span_batch(self, spans: List[Span]):
self._write_to_file(spans)
def _get_dated_filename(self, date: datetime.date) -> str:
"""Return the filename based on a specific date."""
date_str = date.strftime("%Y-%m-%d")
return f"{self.filename_prefix}_{date_str}{self.filename_suffix}"
def _roll_over_if_needed(self):
"""Checks if a day has changed since the last write, and if so, renames the current file."""
current_date = datetime.datetime.now().date()
if current_date != self.last_date:
if os.path.exists(self.filename):
os.rename(self.filename, self._get_dated_filename(self.last_date))
self.last_date = current_date
def _write_to_file(self, spans: List[Span]):
self._roll_over_if_needed()
with open(self.filename, "a") as file:
for span in spans:
span_data = span.to_dict()
try:
file.write(json.dumps(span_data, ensure_ascii=False) + "\n")
except Exception as e:
logger.warning(
f"Write span to file failed: {str(e)}, span_data: {span_data}"
)
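# An illustrative sketch: FileSpanStorage appends one JSON object per span to a
# JSON-lines file and rolls the file over when the date changes. The ids and
# operation names below are hypothetical.
if __name__ == "__main__":
    import tempfile

    demo_file = os.path.join(tempfile.mkdtemp(), "dbgpt_spans.jsonl")
    demo_storage = FileSpanStorage(demo_file)
    demo_storage.append_span(Span("trace-1", "trace-1:span-1", operation_name="op1"))
    demo_storage.append_span(Span("trace-1", "trace-1:span-2", operation_name="op2"))
    with open(demo_file) as f:
        print(f.read())  # one JSON object per line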

View File

@@ -0,0 +1,131 @@
from typing import Dict
from dbgpt.component import SystemApp
from dbgpt.util.tracer import Span, SpanType, SpanStorage, Tracer
# Mock implementations
class MockSpanStorage(SpanStorage):
def __init__(self):
self.spans = []
def append_span(self, span: Span):
self.spans.append(span)
class MockTracer(Tracer):
def __init__(self, system_app: SystemApp | None = None):
super().__init__(system_app)
self.current_span = None
self.storage = MockSpanStorage()
def append_span(self, span: Span):
self.storage.append_span(span)
def start_span(
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
) -> Span:
trace_id = (
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
)
span_id = f"{trace_id}:{self._new_uuid()}"
span = Span(
trace_id, span_id, SpanType.BASE, parent_span_id, operation_name, metadata
)
self.current_span = span
return span
def end_span(self, span: Span):
span.end()
self.append_span(span)
def get_current_span(self) -> Span:
return self.current_span
def _get_current_storage(self) -> SpanStorage:
return self.storage
# Tests
def test_span_creation():
span = Span(
"trace_id",
"span_id",
SpanType.BASE,
"parent_span_id",
"operation",
{"key": "value"},
)
assert span.trace_id == "trace_id"
assert span.span_id == "span_id"
assert span.parent_span_id == "parent_span_id"
assert span.operation_name == "operation"
assert span.metadata == {"key": "value"}
def test_span_end():
span = Span("trace_id", "span_id")
assert span.end_time is None
span.end()
assert span.end_time is not None
def test_mock_tracer_start_span():
tracer = MockTracer()
span = tracer.start_span("operation")
assert span.operation_name == "operation"
assert tracer.get_current_span() == span
def test_mock_tracer_end_span():
tracer = MockTracer()
span = tracer.start_span("operation")
tracer.end_span(span)
assert span in tracer._get_current_storage().spans
def test_mock_tracer_append_span():
tracer = MockTracer()
span = Span("trace_id", "span_id")
tracer.append_span(span)
assert span in tracer._get_current_storage().spans
def test_parent_child_span_relation():
tracer = MockTracer()
# Start a parent span
parent_span = tracer.start_span("parent_operation")
# Start a child span with parent span's ID
child_span = tracer.start_span(
"child_operation", parent_span_id=parent_span.span_id
)
# Assert the relationships
assert child_span.parent_span_id == parent_span.span_id
assert (
child_span.trace_id == parent_span.trace_id
) # Assuming children share the same trace ID
# End spans
tracer.end_span(child_span)
tracer.end_span(parent_span)
# Assert they are in the storage
assert child_span in tracer._get_current_storage().spans
assert parent_span in tracer._get_current_storage().spans
# This test checks if unique UUIDs are being generated.
# Note: This is a simple test and doesn't guarantee uniqueness for large numbers of UUIDs.
def test_new_uuid_unique():
tracer = MockTracer()
uuid_set = {tracer._new_uuid() for _ in range(1000)}
assert len(uuid_set) == 1000

View File

@@ -0,0 +1,174 @@
import os
import pytest
import asyncio
import json
import tempfile
import time
from unittest.mock import patch
from datetime import datetime, timedelta
from dbgpt.util.tracer import (
SpanStorage,
FileSpanStorage,
Span,
SpanType,
SpanStorageContainer,
)
@pytest.fixture
def storage(request):
if not request or not hasattr(request, "param"):
file_does_not_exist = False
else:
file_does_not_exist = request.param.get("file_does_not_exist", False)
if file_does_not_exist:
with tempfile.TemporaryDirectory() as tmp_dir:
filename = os.path.join(tmp_dir, "non_existent_file.jsonl")
storage_instance = FileSpanStorage(filename)
yield storage_instance
else:
with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
filename = tmp_file.name
storage_instance = FileSpanStorage(filename)
yield storage_instance
@pytest.fixture
def storage_container(request):
if not request or not hasattr(request, "param"):
batch_size = 10
flush_interval = 10
else:
batch_size = request.param.get("batch_size", 10)
flush_interval = request.param.get("flush_interval", 10)
storage_container = SpanStorageContainer(
batch_size=batch_size, flush_interval=flush_interval
)
yield storage_container
def read_spans_from_file(filename):
with open(filename, "r") as f:
return [json.loads(line) for line in f.readlines()]
def test_write_span(storage: SpanStorage):
span = Span("1", "a", SpanType.BASE, "b", "op1")
storage.append_span(span)
time.sleep(0.1)
spans_in_file = read_spans_from_file(storage.filename)
assert len(spans_in_file) == 1
assert spans_in_file[0]["trace_id"] == "1"
def test_incremental_write(storage: SpanStorage):
span1 = Span("1", "a", SpanType.BASE, "b", "op1")
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
storage.append_span(span1)
storage.append_span(span2)
time.sleep(0.1)
spans_in_file = read_spans_from_file(storage.filename)
assert len(spans_in_file) == 2
def test_sync_and_async_append(storage: SpanStorage):
span = Span("1", "a", SpanType.BASE, "b", "op1")
storage.append_span(span)
async def async_append():
storage.append_span(span)
asyncio.run(async_append())
time.sleep(0.1)
spans_in_file = read_spans_from_file(storage.filename)
assert len(spans_in_file) == 2
@pytest.mark.parametrize("storage", [{"file_does_not_exist": True}], indirect=True)
def test_non_existent_file(storage: SpanStorage):
span = Span("1", "a", SpanType.BASE, "b", "op1")
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
storage.append_span(span)
time.sleep(0.1)
spans_in_file = read_spans_from_file(storage.filename)
assert len(spans_in_file) == 1
storage.append_span(span2)
time.sleep(0.1)
spans_in_file = read_spans_from_file(storage.filename)
assert len(spans_in_file) == 2
assert spans_in_file[0]["trace_id"] == "1"
assert spans_in_file[1]["trace_id"] == "2"
@pytest.mark.parametrize("storage", [{"file_does_not_exist": True}], indirect=True)
def test_log_rollover(storage: SpanStorage):
# mock start date
mock_start_date = datetime(2023, 10, 18, 23, 59)
with patch("datetime.datetime") as mock_datetime:
mock_datetime.now.return_value = mock_start_date
span1 = Span("1", "a", SpanType.BASE, "b", "op1")
storage.append_span(span1)
time.sleep(0.1)
# mock new day
mock_datetime.now.return_value = mock_start_date + timedelta(minutes=1)
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
storage.append_span(span2)
time.sleep(0.1)
# origin filename need exists
assert os.path.exists(storage.filename)
# get roll over filename
dated_filename = os.path.join(
os.path.dirname(storage.filename),
f"{os.path.basename(storage.filename).split('.')[0]}_2023-10-18.jsonl",
)
assert os.path.exists(dated_filename)
# check origin filename just include the second span
spans_in_original_file = read_spans_from_file(storage.filename)
assert len(spans_in_original_file) == 1
assert spans_in_original_file[0]["trace_id"] == "2"
# check the roll over filename just include the first span
spans_in_dated_file = read_spans_from_file(dated_filename)
assert len(spans_in_dated_file) == 1
assert spans_in_dated_file[0]["trace_id"] == "1"
@pytest.mark.asyncio
@pytest.mark.parametrize("storage_container", [{"batch_size": 5}], indirect=True)
async def test_container_flush_policy(
storage_container: SpanStorageContainer, storage: FileSpanStorage
):
storage_container.append_storage(storage)
span = Span("1", "a", SpanType.BASE, "b", "op1")
filename = storage.filename
for _ in range(storage_container.batch_size - 1):
storage_container.append_span(span)
spans_in_file = read_spans_from_file(filename)
assert len(spans_in_file) == 0
# Trigger batch write
storage_container.append_span(span)
await asyncio.sleep(0.1)
spans_in_file = read_spans_from_file(filename)
assert len(spans_in_file) == storage_container.batch_size

View File

@@ -0,0 +1,103 @@
import pytest
from dbgpt.util.tracer import (
Span,
SpanStorageType,
SpanStorage,
DefaultTracer,
TracerManager,
Tracer,
MemorySpanStorage,
)
from dbgpt.component import SystemApp
@pytest.fixture
def system_app():
return SystemApp()
@pytest.fixture
def storage(system_app: SystemApp):
ms = MemorySpanStorage(system_app)
system_app.register_instance(ms)
return ms
@pytest.fixture
def tracer(request, system_app: SystemApp):
if not request or not hasattr(request, "param"):
return DefaultTracer(system_app)
else:
span_storage_type = request.param.get(
"span_storage_type", SpanStorageType.ON_CREATE_END
)
return DefaultTracer(system_app, span_storage_type=span_storage_type)
@pytest.fixture
def tracer_manager(system_app: SystemApp, tracer: Tracer):
system_app.register_instance(tracer)
manager = TracerManager()
manager.initialize(system_app)
return manager
def test_start_and_end_span(tracer: Tracer):
span = tracer.start_span("operation")
assert isinstance(span, Span)
assert span.operation_name == "operation"
tracer.end_span(span)
assert span.end_time is not None
stored_span = tracer._get_current_storage().spans[0]
assert stored_span == span
def test_start_and_end_span_with_tracer_manager(tracer_manager: TracerManager):
span = tracer_manager.start_span("operation")
assert isinstance(span, Span)
assert span.operation_name == "operation"
tracer_manager.end_span(span)
assert span.end_time is not None
def test_parent_child_span_relation(tracer: Tracer):
parent_span = tracer.start_span("parent_operation")
child_span = tracer.start_span(
"child_operation", parent_span_id=parent_span.span_id
)
assert child_span.parent_span_id == parent_span.span_id
assert child_span.trace_id == parent_span.trace_id
tracer.end_span(child_span)
tracer.end_span(parent_span)
assert parent_span in tracer._get_current_storage().spans
assert child_span in tracer._get_current_storage().spans
@pytest.mark.parametrize(
"tracer, expected_count, after_create_inc_count",
[
({"span_storage_type": SpanStorageType.ON_CREATE}, 1, 1),
({"span_storage_type": SpanStorageType.ON_END}, 1, 0),
({"span_storage_type": SpanStorageType.ON_CREATE_END}, 2, 1),
],
indirect=["tracer"],
)
def test_tracer_span_storage_type_and_with(
tracer: Tracer,
expected_count: int,
after_create_inc_count: int,
storage: SpanStorage,
):
span = tracer.start_span("new_span")
span.end()
assert len(storage.spans) == expected_count
with tracer.start_span("with_span") as ws:
assert len(storage.spans) == expected_count + after_create_inc_count
assert len(storage.spans) == expected_count + expected_count

View File

@@ -0,0 +1,597 @@
import os
import click
import logging
import glob
import json
from datetime import datetime
from typing import Iterable, Dict, Callable
from dbgpt.configs.model_config import LOGDIR
from dbgpt.util.tracer import SpanType, SpanTypeRunName
logger = logging.getLogger("dbgpt_cli")
_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl")
@click.group("trace")
def trace_cli_group():
"""Analyze and visualize trace spans."""
pass
@trace_cli_group.command()
@click.option(
"--trace_id",
required=False,
type=str,
default=None,
show_default=True,
help="Specify the trace ID to list",
)
@click.option(
"--span_id",
required=False,
type=str,
default=None,
show_default=True,
help="Specify the Span ID to list.",
)
@click.option(
"--span_type",
required=False,
type=str,
default=None,
show_default=True,
help="Specify the Span Type to list.",
)
@click.option(
"--parent_span_id",
required=False,
type=str,
default=None,
show_default=True,
help="Specify the Parent Span ID to list.",
)
@click.option(
"--search",
required=False,
type=str,
default=None,
show_default=True,
help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.",
)
@click.option(
"-l",
"--limit",
type=int,
default=20,
help="Limit the number of recent span displayed.",
)
@click.option(
"--start_time",
type=str,
help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"',
)
@click.option(
"--end_time", type=str, help='Filter by end time. Format: "YYYY-MM-DD HH:MM:SS.mmm"'
)
@click.option(
"--desc",
required=False,
type=bool,
default=False,
is_flag=True,
help="Whether to use reverse sorting. By default, sorting is based on start time.",
)
@click.option(
"--output",
required=False,
type=click.Choice(["text", "html", "csv", "latex", "json"]),
default="text",
help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def list(
trace_id: str,
span_id: str,
span_type: str,
parent_span_id: str,
search: str,
limit: int,
start_time: str,
end_time: str,
desc: bool,
output: str,
files=None,
):
"""List your trace spans"""
from prettytable import PrettyTable
# If no files are explicitly specified, use the default pattern to get them
spans = read_spans_from_files(files)
if trace_id:
spans = filter(lambda s: s["trace_id"] == trace_id, spans)
if span_id:
spans = filter(lambda s: s["span_id"] == span_id, spans)
if span_type:
spans = filter(lambda s: s["span_type"] == span_type, spans)
if parent_span_id:
spans = filter(lambda s: s["parent_span_id"] == parent_span_id, spans)
# Filter spans based on the start and end times
if start_time:
start_dt = _parse_datetime(start_time)
spans = filter(
lambda span: _parse_datetime(span["start_time"]) >= start_dt, spans
)
if end_time:
end_dt = _parse_datetime(end_time)
spans = filter(
lambda span: _parse_datetime(span["start_time"]) <= end_dt, spans
)
if search:
spans = filter(_new_search_span_func(search), spans)
# Sort spans based on the start time
spans = sorted(
spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc
)[:limit]
table = PrettyTable(
["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
)
for sp in spans:
conv_uid = None
if "metadata" in sp and sp:
metadata = sp["metadata"]
if isinstance(metadata, dict):
conv_uid = metadata.get("conv_uid")
table.add_row(
[
sp.get("trace_id"),
sp.get("span_id"),
# sp.get("parent_span_id"),
sp.get("operation_name"),
conv_uid,
]
)
out_kwargs = {"ensure_ascii": False} if output == "json" else {}
print(table.get_formatted_string(out_format=output, **out_kwargs))
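# Illustrative sketch (assumption, not part of this commit): the command above can
# be exercised programmatically with click's test runner, which is convenient for
# CLI tests. Wrapped in a function so nothing runs at import time.
def _example_invoke_list():
    from click.testing import CliRunner

    runner = CliRunner()
    # With no FILES argument, the default pattern under LOGDIR is used.
    result = runner.invoke(
        trace_cli_group, ["list", "--limit", "5", "--output", "json"]
    )
    return result.output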
@trace_cli_group.command()
@click.option(
"--trace_id",
required=True,
type=str,
help="Specify the trace ID to list",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def tree(trace_id: str, files):
"""Display trace links as a tree"""
hierarchy = _view_trace_hierarchy(trace_id, files)
if not hierarchy:
_print_empty_message(files)
return
_print_trace_hierarchy(hierarchy)
@trace_cli_group.command()
@click.option(
"--trace_id",
required=False,
type=str,
default=None,
help="Specify the trace ID to analyze. If None, show latest conversation details",
)
@click.option(
"--tree",
required=False,
type=bool,
default=False,
is_flag=True,
help="Display trace spans as a tree",
)
@click.option(
"--hide_conv",
required=False,
type=bool,
default=False,
is_flag=True,
help="Hide your conversation details",
)
@click.option(
"--hide_run_params",
required=False,
type=bool,
default=False,
is_flag=True,
help="Hide run params",
)
@click.option(
"--output",
required=False,
type=click.Choice(["text", "html", "csv", "latex", "json"]),
default="text",
help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True))
def chat(
trace_id: str,
tree: bool,
hide_conv: bool,
hide_run_params: bool,
output: str,
files,
):
"""Show conversation details"""
from prettytable import PrettyTable
spans = read_spans_from_files(files)
# Sort by start time
spans = sorted(
spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=True
)
spans = list(spans)
if not spans:
_print_empty_message(files)
return
service_spans = {}
service_names = set(SpanTypeRunName.values())
found_trace_id = None
for sp in spans:
span_type = sp["span_type"]
metadata = sp.get("metadata")
if span_type == SpanType.RUN:
service_name = metadata["run_service"]
service_spans[service_name] = sp.copy()
if set(service_spans.keys()) == service_names and found_trace_id:
break
elif span_type == SpanType.CHAT and not found_trace_id:
if not trace_id:
found_trace_id = sp["trace_id"]
if trace_id and trace_id == sp["trace_id"]:
found_trace_id = trace_id
service_tables = {}
system_infos_table = {}
out_kwargs = {"ensure_ascii": False} if output == "json" else {}
for service_name, sp in service_spans.items():
metadata = sp["metadata"]
table = PrettyTable(["Config Key", "Config Value"], title=service_name)
for k, v in metadata["params"].items():
table.add_row([k, v])
service_tables[service_name] = table
sys_infos = metadata.get("sys_infos")
if sys_infos and isinstance(sys_infos, dict):
sys_table = PrettyTable(
["System Config Key", "System Config Value"],
title=f"{service_name} System information",
)
for k, v in sys_infos.items():
sys_table.add_row([k, v])
system_infos_table[service_name] = sys_table
if not hide_run_params:
merged_table1 = merge_tables_horizontally(
[
service_tables.get(SpanTypeRunName.WEBSERVER.value),
service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value),
]
)
merged_table2 = merge_tables_horizontally(
[
service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
]
)
sys_table = system_infos_table.get(SpanTypeRunName.WORKER_MANAGER.value)
if not sys_table and system_infos_table:
# Fall back to the first available system information table
sys_table = next(iter(system_infos_table.values()))
if output == "text":
print(merged_table1)
print(merged_table2)
else:
for service_name, table in service_tables.items():
print(table.get_formatted_string(out_format=output, **out_kwargs))
if sys_table:
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
if not found_trace_id:
print(f"Can't found conversation with trace_id: {trace_id}")
return
trace_id = found_trace_id
trace_spans = [span for span in spans if span["trace_id"] == trace_id]
trace_spans = list(reversed(trace_spans))
hierarchy = _build_trace_hierarchy(trace_spans)
if tree:
print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
_print_trace_hierarchy(hierarchy)
if hide_conv:
return
trace_spans = _get_ordered_trace_from(hierarchy)
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
split_long_text = output == "text"
for sp in trace_spans:
op = sp["operation_name"]
metadata = sp.get("metadata")
if op == "get_chat_instance" and not sp["end_time"]:
table.add_row(["trace_id", trace_id])
table.add_row(["span_id", sp["span_id"]])
table.add_row(["conv_uid", metadata.get("conv_uid")])
table.add_row(["user_input", metadata.get("user_input")])
table.add_row(["chat_mode", metadata.get("chat_mode")])
table.add_row(["select_param", metadata.get("select_param")])
table.add_row(["model_name", metadata.get("model_name")])
if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
if not sp["end_time"]:
table.add_row(["temperature", metadata.get("temperature")])
table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
table.add_row(["echo", metadata.get("echo")])
elif "error" in metadata:
table.add_row(["BaseChat Error", metadata.get("error")])
if op == "BaseChat.do_action" and not sp["end_time"]:
if "model_output" in metadata:
table.add_row(
[
"BaseChat model_output",
split_string_by_terminal_width(
metadata.get("model_output").get("text"),
split=split_long_text,
),
]
)
if "ai_response_text" in metadata:
table.add_row(
[
"BaseChat ai_response_text",
split_string_by_terminal_width(
metadata.get("ai_response_text"), split=split_long_text
),
]
)
if "prompt_define_response" in metadata:
prompt_define_response = metadata.get("prompt_define_response") or ""
if isinstance(prompt_define_response, (dict, list)):
prompt_define_response = json.dumps(
prompt_define_response, ensure_ascii=False
)
table.add_row(
[
"BaseChat prompt_define_response",
split_string_by_terminal_width(
prompt_define_response,
split=split_long_text,
),
]
)
if op == "DefaultModelWorker_call.generate_stream_func":
if not sp["end_time"]:
table.add_row(["llm_adapter", metadata.get("llm_adapter")])
table.add_row(
[
"User prompt",
split_string_by_terminal_width(
metadata.get("prompt"), split=split_long_text
),
]
)
else:
table.add_row(
[
"Model output",
split_string_by_terminal_width(metadata.get("output")),
]
)
if (
op
in [
"DefaultModelWorker.async_generate_stream",
"DefaultModelWorker.generate_stream",
]
and metadata
and "error" in metadata
):
table.add_row(["Model Error", metadata.get("error")])
print(table.get_formatted_string(out_format=output, **out_kwargs))
def read_spans_from_files(files=None) -> Iterable[Dict]:
"""
Reads spans from multiple files based on the provided file paths.
"""
if not files:
files = [_DEFAULT_FILE_PATTERN]
for filepath in files:
for filename in glob.glob(filepath):
with open(filename, "r") as file:
for line in file:
yield json.loads(line)
def _print_empty_message(files=None):
if not files:
files = [_DEFAULT_FILE_PATTERN]
file_names = ",".join(files)
print(f"No trace span records found in your tracer files: {file_names}")
def _new_search_span_func(search: str):
def func(span: Dict) -> bool:
items = [span["trace_id"], span["span_id"], span["parent_span_id"]]
if "operation_name" in span:
items.append(span["operation_name"])
if "metadata" in span:
metadata = span["metadata"]
if isinstance(metadata, dict):
for k, v in metadata.items():
items.append(k)
items.append(v)
return any(search in str(item) for item in items if item)
return func
def _parse_datetime(dt_str):
"""Parse a datetime string to a datetime object."""
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f")
def _build_trace_hierarchy(spans, parent_span_id=None, indent=0):
# Start records (end_time is None) at the current level of the hierarchy
current_level_spans = [
span
for span in spans
if span["parent_span_id"] == parent_span_id and span["end_time"] is None
]
hierarchy = []
for start_span in current_level_spans:
# Find end span
end_span = next(
(
span
for span in spans
if span["span_id"] == start_span["span_id"]
and span["end_time"] is not None
),
None,
)
entry = {
"operation_name": start_span["operation_name"],
"parent_span_id": start_span["parent_span_id"],
"span_id": start_span["span_id"],
"start_time": start_span["start_time"],
"end_time": start_span["end_time"],
"metadata": start_span["metadata"],
"children": _build_trace_hierarchy(
spans, start_span["span_id"], indent + 1
),
}
hierarchy.append(entry)
# Append end span
if end_span:
entry_end = {
"operation_name": end_span["operation_name"],
"parent_span_id": end_span["parent_span_id"],
"span_id": end_span["span_id"],
"start_time": end_span["start_time"],
"end_time": end_span["end_time"],
"metadata": end_span["metadata"],
"children": [],
}
hierarchy.append(entry_end)
return hierarchy
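# Illustrative sketch (assumption): every span appears twice in the trace file, once
# when it is created (end_time is None) and once when it ends. _build_trace_hierarchy
# pairs those records and nests children under their parent. The sample spans below
# are hypothetical.
def _example_hierarchy():
    base = {"trace_id": "t1", "metadata": {}}
    spans = [
        {**base, "span_id": "t1:a", "parent_span_id": None, "operation_name": "root",
         "start_time": "2023-10-18 10:00:00.000", "end_time": None},
        {**base, "span_id": "t1:b", "parent_span_id": "t1:a", "operation_name": "child",
         "start_time": "2023-10-18 10:00:00.100", "end_time": None},
        {**base, "span_id": "t1:b", "parent_span_id": "t1:a", "operation_name": "child",
         "start_time": "2023-10-18 10:00:00.100", "end_time": "2023-10-18 10:00:00.200"},
        {**base, "span_id": "t1:a", "parent_span_id": None, "operation_name": "root",
         "start_time": "2023-10-18 10:00:00.000", "end_time": "2023-10-18 10:00:00.300"},
    ]
    # Result: [root start (children=[child start, child end]), root end]
    return _build_trace_hierarchy(spans)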
def _view_trace_hierarchy(trace_id, files=None):
"""Find and display the calls of the entire link based on the given trace_id"""
spans = read_spans_from_files(files)
trace_spans = [span for span in spans if span["trace_id"] == trace_id]
if not trace_spans:
return None
hierarchy = _build_trace_hierarchy(trace_spans)
return hierarchy
def _print_trace_hierarchy(hierarchy, indent=0):
"""Print link hierarchy"""
for entry in hierarchy:
print(
" " * indent
+ f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})"
)
_print_trace_hierarchy(entry["children"], indent + 1)
def _get_ordered_trace_from(hierarchy):
traces = []
def func(items):
for item in items:
traces.append(item)
func(item["children"])
func(hierarchy)
return traces
def _print(service_spans: Dict):
for names in [
[SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL],
[SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER],
]:
pass
def merge_tables_horizontally(tables):
from prettytable import PrettyTable
if not tables:
return None
tables = [t for t in tables if t]
if not tables:
return None
max_rows = max(len(table._rows) for table in tables)
merged_table = PrettyTable()
new_field_names = []
for table in tables:
new_field_names.extend(
[
f"{name} ({table.title})" if table.title else f"{name}"
for name in table.field_names
]
)
merged_table.field_names = new_field_names
for i in range(max_rows):
merged_row = []
for table in tables:
if i < len(table._rows):
merged_row.extend(table._rows[i])
else:
# Fill empty cells for shorter tables
merged_row.extend([""] * len(table.field_names))
merged_table.add_row(merged_row)
return merged_table
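# Illustrative sketch (assumption): merging two small PrettyTables side by side; the
# shorter table is padded with empty cells so every merged row has the same width.
def _example_merge_tables():
    from prettytable import PrettyTable

    t1 = PrettyTable(["Config Key", "Config Value"], title="webserver")
    t1.add_row(["host", "0.0.0.0"])
    t2 = PrettyTable(["Config Key", "Config Value"], title="model_worker")
    t2.add_row(["model_name", "vicuna-13b"])
    t2.add_row(["port", "8001"])
    return merge_tables_horizontally([t1, t2])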
def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"):
"""
Split a string into substrings based on the current terminal width.
Parameters:
- s: the input string
- split: if False, return the string unchanged
- max_len: maximum length of each chunk; defaults to about 80% of the terminal width
- sp: separator inserted between chunks
"""
if not split:
return s
if not max_len:
try:
max_len = int(os.get_terminal_size().columns * 0.8)
except OSError:
# Fall back to a fixed width if the terminal size can't be determined
max_len = 100
return sp.join([s[i : i + max_len] for i in range(0, len(s), max_len)])

View File

@@ -0,0 +1,235 @@
from typing import Dict, Optional
from contextvars import ContextVar
from functools import wraps
import asyncio
import inspect
import logging
from dbgpt.component import SystemApp, ComponentType
from dbgpt.util.tracer.base import (
SpanType,
Span,
Tracer,
SpanStorage,
SpanStorageType,
TracerContext,
)
from dbgpt.util.tracer.span_storage import MemorySpanStorage
from dbgpt.util.module_utils import import_from_checked_string
logger = logging.getLogger(__name__)
class DefaultTracer(Tracer):
def __init__(
self,
system_app: SystemApp | None = None,
default_storage: SpanStorage = None,
span_storage_type: SpanStorageType = SpanStorageType.ON_CREATE_END,
):
super().__init__(system_app)
self._span_stack_var = ContextVar("span_stack", default=[])
if not default_storage:
default_storage = MemorySpanStorage(system_app)
self._default_storage = default_storage
self._span_storage_type = span_storage_type
def append_span(self, span: Span):
self._get_current_storage().append_span(span)
def start_span(
self,
operation_name: str,
parent_span_id: str = None,
span_type: SpanType = None,
metadata: Dict = None,
) -> Span:
trace_id = (
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
)
span_id = f"{trace_id}:{self._new_uuid()}"
span = Span(
trace_id,
span_id,
span_type,
parent_span_id,
operation_name,
metadata=metadata,
)
if self._span_storage_type in [
SpanStorageType.ON_END,
SpanStorageType.ON_CREATE_END,
]:
span.add_end_caller(self.append_span)
if self._span_storage_type in [
SpanStorageType.ON_CREATE,
SpanStorageType.ON_CREATE_END,
]:
self.append_span(span)
current_stack = self._span_stack_var.get()
current_stack.append(span)
self._span_stack_var.set(current_stack)
span.add_end_caller(self._remove_from_stack_top)
return span
def end_span(self, span: Span, **kwargs):
""""""
span.end(**kwargs)
def _remove_from_stack_top(self, span: Span):
current_stack = self._span_stack_var.get()
if current_stack:
current_stack.pop()
self._span_stack_var.set(current_stack)
def get_current_span(self) -> Optional[Span]:
current_stack = self._span_stack_var.get()
return current_stack[-1] if current_stack else None
def _get_current_storage(self) -> SpanStorage:
return self.system_app.get_component(
ComponentType.TRACER_SPAN_STORAGE, SpanStorage, self._default_storage
)
class TracerManager:
"""The manager of current tracer"""
def __init__(self) -> None:
self._system_app: Optional[SystemApp] = None
self._trace_context_var: ContextVar[TracerContext] = ContextVar(
"trace_context",
default=TracerContext(),
)
def initialize(
self, system_app: SystemApp, trace_context_var: ContextVar[TracerContext] = None
) -> None:
self._system_app = system_app
if trace_context_var:
self._trace_context_var = trace_context_var
def _get_tracer(self) -> Tracer:
if not self._system_app:
return None
return self._system_app.get_component(ComponentType.TRACER, Tracer, None)
def start_span(
self,
operation_name: str,
parent_span_id: str = None,
span_type: SpanType = None,
metadata: Dict = None,
) -> Span:
"""Start a new span with operation_name
This method must not throw an exception under any case and try not to block as much as possible
"""
tracer = self._get_tracer()
if not tracer:
return Span("empty_span", "empty_span")
if not parent_span_id:
parent_span_id = self.get_current_span_id()
return tracer.start_span(
operation_name, parent_span_id, span_type=span_type, metadata=metadata
)
def end_span(self, span: Span, **kwargs):
tracer = self._get_tracer()
if not tracer or not span:
return
tracer.end_span(span, **kwargs)
def get_current_span(self) -> Optional[Span]:
tracer = self._get_tracer()
if not tracer:
return None
return tracer.get_current_span()
def get_current_span_id(self) -> Optional[str]:
current_span = self.get_current_span()
if current_span:
return current_span.span_id
ctx = self._trace_context_var.get()
return ctx.span_id if ctx else None
root_tracer: TracerManager = TracerManager()
def trace(operation_name: Optional[str] = None, **trace_kwargs):
def decorator(func):
@wraps(func)
def sync_wrapper(*args, **kwargs):
name = (
operation_name if operation_name else _parse_operation_name(func, *args)
)
with root_tracer.start_span(name, **trace_kwargs):
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
name = (
operation_name if operation_name else _parse_operation_name(func, *args)
)
with root_tracer.start_span(name, **trace_kwargs):
return await func(*args, **kwargs)
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
def _parse_operation_name(func, *args):
self_name = None
if inspect.signature(func).parameters.get("self"):
self_name = args[0].__class__.__name__
func_name = func.__name__
if self_name:
return f"{self_name}.{func_name}"
return func_name
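# Illustrative sketch (assumption): typical usage of the trace decorator and the
# global root_tracer. The function and operation names here are hypothetical; spans
# are only persisted once a Tracer component has been registered on the SystemApp
# (see initialize_tracer below).
@trace("example.load_data")
async def _example_load_data(source: str):
    # Nested spans automatically pick up the current span as their parent.
    with root_tracer.start_span("example.parse", metadata={"source": source}):
        return source.upper()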
def initialize_tracer(
system_app: SystemApp,
tracer_filename: str,
root_operation_name: str = "DB-GPT-Web-Entry",
tracer_storage_cls: str = None,
):
if not system_app:
return
from dbgpt.util.tracer.span_storage import FileSpanStorage, SpanStorageContainer
trace_context_var = ContextVar(
"trace_context",
default=TracerContext(),
)
tracer = DefaultTracer(system_app)
storage_container = SpanStorageContainer(system_app)
storage_container.append_storage(FileSpanStorage(tracer_filename))
if tracer_storage_cls:
logger.info(f"Begin parse storage class {tracer_storage_cls}")
storage = import_from_checked_string(tracer_storage_cls, SpanStorage)
storage_container.append_storage(storage())
system_app.register_instance(storage_container)
system_app.register_instance(tracer)
root_tracer.initialize(system_app, trace_context_var)
if system_app.app:
from dbgpt.util.tracer.tracer_middleware import TraceIDMiddleware
system_app.app.add_middleware(
TraceIDMiddleware,
trace_context_var=trace_context_var,
tracer=tracer,
root_operation_name=root_operation_name,
)
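# Illustrative sketch (assumption): wiring the tracer into an application at startup.
# The file path is hypothetical; DB-GPT performs equivalent wiring when the webserver
# and model workers boot.
def _example_setup_tracing():
    system_app = SystemApp()
    initialize_tracer(system_app, tracer_filename="/tmp/dbgpt_example_tracer.jsonl")
    with root_tracer.start_span("example.operation"):
        pass  # spans are buffered by SpanStorageContainer and flushed to the file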

View File

@@ -0,0 +1,45 @@
import uuid
from contextvars import ContextVar
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp
from dbgpt.util.tracer import TracerContext, Tracer
_DEFAULT_EXCLUDE_PATHS = ["/api/controller/heartbeat"]
class TraceIDMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app: ASGIApp,
trace_context_var: ContextVar[TracerContext],
tracer: Tracer,
root_operation_name: str = "DB-GPT-Web-Entry",
include_prefix: str = "/api",
exclude_paths=_DEFAULT_EXCLUDE_PATHS,
):
super().__init__(app)
self.trace_context_var = trace_context_var
self.tracer = tracer
self.root_operation_name = root_operation_name
self.include_prefix = include_prefix
self.exclude_paths = exclude_paths
async def dispatch(self, request: Request, call_next):
if request.url.path in self.exclude_paths or not request.url.path.startswith(
self.include_prefix
):
return await call_next(request)
span_id = request.headers.get("DBGPT_TRACER_SPAN_ID")
# if not span_id:
# span_id = str(uuid.uuid4())
# self.trace_context_var.set(TracerContext(span_id=span_id))
with self.tracer.start_span(
self.root_operation_name, span_id, metadata={"path": request.url.path}
):
response = await call_next(request)
return response

229
dbgpt/util/utils.py Normal file
View File

@@ -0,0 +1,229 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import logging
import logging.handlers
from typing import Any, List
import os
import sys
import asyncio
from dbgpt.configs.model_config import LOGDIR
server_error_msg = (
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
handler = None
def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging_level(logging_level=None, logger_name: str = None):
if not logging_level:
logging_level = _get_logging_level()
if isinstance(logging_level, str):
logging_level = logging.getLevelName(logging_level.upper())
if logger_name:
logger = logging.getLogger(logger_name)
logger.setLevel(logging_level)
else:
logging.basicConfig(level=logging_level, encoding="utf-8")
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
if not logging_level:
logging_level = _get_logging_level()
logger = _build_logger(logger_name, logging_level, logger_filename)
try:
import coloredlogs
color_level = logging_level if logging_level else "INFO"
coloredlogs.install(level=color_level, logger=logger)
except ImportError:
pass
def get_gpu_memory(max_gpus=None):
import torch
gpu_memory = []
num_gpus = (
torch.cuda.device_count()
if max_gpus is None
else min(max_gpus, torch.cuda.device_count())
)
for gpu_id in range(num_gpus):
with torch.cuda.device(gpu_id):
device = torch.cuda.current_device()
gpu_properties = torch.cuda.get_device_properties(device)
total_memory = gpu_properties.total_memory / (1024**3)
allocated_memory = torch.cuda.memory_allocated() / (1024**3)
available_memory = total_memory - allocated_memory
gpu_memory.append(available_memory)
return gpu_memory
def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
setup_logging_level(logging_level=logging_level)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
# stdout_logger = logging.getLogger("stdout")
# stdout_logger.setLevel(logging.INFO)
# sl_1 = StreamToLogger(stdout_logger, logging.INFO)
# sys.stdout = sl_1
#
# stderr_logger = logging.getLogger("stderr")
# stderr_logger.setLevel(logging.ERROR)
# sl = StreamToLogger(stderr_logger, logging.ERROR)
# sys.stderr = sl
# Add a file handler for all loggers
if handler is None and logger_filename:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when="D", utc=True
)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
# Get logger
logger = logging.getLogger(logger_name)
setup_logging_level(logging_level=logging_level, logger_name=logger_name)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ""
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ""
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == "\n":
encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
self.logger.log(self.log_level, encoded_message.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != "":
encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
self.logger.log(self.log_level, encoded_message.rstrip())
self.linebuf = ""
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
loop = None
try:
loop = asyncio.get_event_loop()
assert loop is not None
return loop
except RuntimeError as e:
if not "no running event loop" in str(e) and not "no current event loop" in str(
e
):
raise e
logging.warning("Cant not get running event loop, create new event loop now")
return asyncio.get_event_loop_policy().new_event_loop()
def logging_str_to_uvicorn_level(log_level_str):
level_str_mapping = {
"CRITICAL": "critical",
"ERROR": "error",
"WARNING": "warning",
"INFO": "info",
"DEBUG": "debug",
"NOTSET": "info",
}
return level_str_mapping.get(log_level_str.upper(), "info")
class EndpointFilter(logging.Filter):
"""Disable access log on certain endpoint
source: https://github.com/encode/starlette/issues/864#issuecomment-1254987630
"""
def __init__(
self,
path: str,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self._path = path
def filter(self, record: logging.LogRecord) -> bool:
return record.getMessage().find(self._path) == -1
def setup_http_service_logging(exclude_paths: List[str] = None):
"""Setup http service logging
Now just disable some logs
Args:
exclude_paths (List[str]): The paths to disable log
"""
if not exclude_paths:
# Not show heartbeat log
exclude_paths = ["/api/controller/heartbeat"]
uvicorn_logger = logging.getLogger("uvicorn.access")
if uvicorn_logger:
for path in exclude_paths:
uvicorn_logger.addFilter(EndpointFilter(path=path))
httpx_logger = logging.getLogger("httpx")
if httpx_logger:
httpx_logger.setLevel(logging.WARNING)
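# Illustrative sketch (assumption): call this before starting the HTTP server so the
# heartbeat endpoint stops flooding the access log. The app object and port below
# are hypothetical.
def _example_run_server(app):
    import uvicorn

    setup_http_service_logging()
    uvicorn.run(app, host="0.0.0.0", port=5670)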