Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-12 20:53:48 +00:00
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
8
dbgpt/util/__init__.py
Normal file
@@ -0,0 +1,8 @@
from .utils import (
    get_gpu_memory,
    StreamToLogger,
    disable_torch_init,
    pretty_print_semaphore,
    server_error_msg,
    get_or_create_event_loop,
)
67
dbgpt/util/annotations.py
Normal file
@@ -0,0 +1,67 @@
def PublicAPI(*args, **kwargs):
    """Decorator to mark a function or class as a public API.

    Args:
        stability: The stability of the API. Can be "alpha", "beta" or "stable".
            If "alpha", the API is in alpha and breaking changes may occur before it becomes beta.
            If "beta", the API is in beta and may change before becoming stable.
            If "stable", the API will remain backwards compatible with the current major version.
            Defaults to "stable".
    Examples:
        >>> from dbgpt.util.annotations import PublicAPI
        >>> @PublicAPI
        ... def foo():
        ...     pass

        >>> @PublicAPI(stability="beta")
        ... def bar():
        ...     pass

    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        return PublicAPI(stability="stable")(args[0])
    stability = None
    if "stability" in kwargs:
        stability = kwargs["stability"]
    if not stability:
        stability = "stable"
    assert stability in ["alpha", "beta", "stable"]

    def decorator(obj):
        if stability in ["alpha", "beta"]:
            _modify_docstring(
                obj,
                f"**PublicAPI ({stability}):** This API is in {stability} and may change before becoming stable.",
            )
            _modify_annotation(obj, stability)
        return obj

    return decorator


def _modify_docstring(obj, message: str = None):
    if not message:
        return
    if not obj.__doc__:
        obj.__doc__ = ""
    original_doc = obj.__doc__

    lines = original_doc.splitlines()

    min_indent = float("inf")
    for line in lines[1:]:
        stripped = line.lstrip()
        if stripped:
            min_indent = min(min_indent, len(line) - len(stripped))

    if min_indent == float("inf"):
        min_indent = 0
    indented_message = message.rstrip() + "\n" + (" " * min_indent)
    obj.__doc__ = indented_message + original_doc


def _modify_annotation(obj, stability) -> None:
    if stability:
        obj._public_stability = stability
    if hasattr(obj, "__name__"):
        obj._annotated = obj.__name__
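
A quick usage sketch of the new PublicAPI decorator (the function name below is made up for illustration):

from dbgpt.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
def experimental_feature():
    """Do something experimental."""


# The decorator records the stability level and prepends a notice to the docstring.
print(experimental_feature._public_stability)  # -> "alpha"
print(experimental_feature.__doc__)  # starts with "**PublicAPI (alpha):** ..."
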
127
dbgpt/util/api_utils.py
Normal file
@@ -0,0 +1,127 @@
from inspect import signature
import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict

T = TypeVar("T")

logger = logging.getLogger(__name__)


def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
    import typing_inspect

    """Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
    if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
        return typing_inspect.get_args(type_hint)[0]
    return None


def _build_request(self, func, path, method, *args, **kwargs):
    return_type = get_type_hints(func).get("return")
    if return_type is None:
        raise TypeError("Return type must be annotated in the decorated function.")

    actual_dataclass = _extract_dataclass_from_generic(return_type)
    logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
    if not actual_dataclass:
        actual_dataclass = return_type
    sig = signature(func)
    base_url = self.base_url  # Get base_url from class instance

    bound = sig.bind(self, *args, **kwargs)
    bound.apply_defaults()

    formatted_url = base_url + path.format(**bound.arguments)

    # Extract args names from signature, except "self"
    arg_names = list(sig.parameters.keys())[1:]

    # Combine args and kwargs into a single dictionary
    combined_args = dict(zip(arg_names, args))
    combined_args.update(kwargs)

    request_data = {}
    for key, value in combined_args.items():
        if is_dataclass(value):
            # Here, instead of adding it as a nested dictionary,
            # we set request_data directly to its dictionary representation.
            request_data = asdict(value)
        else:
            request_data[key] = value

    request_params = {"method": method, "url": formatted_url}

    if method in ["POST", "PUT", "PATCH"]:
        request_params["json"] = request_data
    else:  # For GET, DELETE, etc.
        request_params["params"] = request_data

    logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
    return return_type, actual_dataclass, request_params


def _api_remote(path, method="GET"):
    def decorator(func):
        async def wrapper(self, *args, **kwargs):
            import httpx

            return_type, actual_dataclass, request_params = _build_request(
                self, func, path, method, *args, **kwargs
            )
            async with httpx.AsyncClient() as client:
                response = await client.request(**request_params)
                if response.status_code == 200:
                    return _parse_response(
                        response.json(), return_type, actual_dataclass
                    )
                else:
                    error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
                    raise Exception(error_msg)

        return wrapper

    return decorator


def _sync_api_remote(path, method="GET"):
    def decorator(func):
        def wrapper(self, *args, **kwargs):
            import requests

            return_type, actual_dataclass, request_params = _build_request(
                self, func, path, method, *args, **kwargs
            )

            response = requests.request(**request_params)

            if response.status_code == 200:
                return _parse_response(response.json(), return_type, actual_dataclass)
            else:
                error_msg = f"Remote request error, error code: {response.status_code}, error msg: {response.text}"
                raise Exception(error_msg)

        return wrapper

    return decorator


def _parse_response(json_response, return_type, actual_dataclass):
    # print(f'return_type.__origin__: {return_type.__origin__}, actual_dataclass: {actual_dataclass}, json_response: {json_response}')
    if is_dataclass(actual_dataclass):
        if return_type.__origin__ is list:  # for List[dataclass]
            if isinstance(json_response, list):
                return [actual_dataclass(**item) for item in json_response]
            else:
                raise TypeError(
                    f"Expected list in response but got {type(json_response)}"
                )
        else:
            if isinstance(json_response, dict):
                return actual_dataclass(**json_response)
            else:
                raise TypeError(
                    f"Expected dictionary in response but got {type(json_response)}"
                )
    else:
        return json_response
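
For orientation, a sketch of how these decorators are meant to be applied. The client class, endpoint path, and dataclass below are hypothetical; the only thing _build_request actually requires from the instance is a base_url attribute:

from dataclasses import dataclass
from typing import List

from dbgpt.util.api_utils import _api_remote, _sync_api_remote


@dataclass
class WorkerInfo:  # hypothetical response payload
    host: str
    port: int


class RemoteWorkerClient:  # hypothetical client class
    def __init__(self, base_url: str):
        # _build_request reads base_url from the decorated method's instance
        self.base_url = base_url

    @_api_remote("/api/workers", method="GET")
    async def list_workers(self) -> List[WorkerInfo]:
        """The List[dataclass] return annotation drives response parsing."""

    @_sync_api_remote("/api/workers", method="GET")
    def list_workers_sync(self) -> List[WorkerInfo]:
        """Same request, but sent synchronously via requests."""
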
0
dbgpt/util/benchmarks/__init__.py
Normal file
0
dbgpt/util/benchmarks/llm/__init__.py
Normal file
296
dbgpt/util/benchmarks/llm/fastchat_benchmarks_inference.py
Normal file
@@ -0,0 +1,296 @@
"""
Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py.
For benchmarks.

"""
import gc
from typing import Iterable, Dict

import torch
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    if hasattr(model, "device"):
        device = model.device

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 1

    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    # Don't stop generating until max_new_tokens is reached.
    stop_token_ids = []
    stop_str = None

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Switch to CPU to avoid some bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)
        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # TODO: Incomplete sentences can interrupt the output; this is a patch and can be reworked into something more elegant.
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # Finish stream event, which contains finish reason
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
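
generate_stream yields partial results as a generator. A minimal consumption sketch, assuming a Hugging Face causal LM and its tokenizer have already been loaded elsewhere (they are not defined in this file):

# Sketch only: `model` and `tokenizer` are assumed to be a loaded Hugging Face
# model/tokenizer pair already placed on the target device.
params = {
    "prompt": "Hello, my name is",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "echo": False,
}
last_chunk = None
for chunk in generate_stream(
    model, tokenizer, params, device="cuda", context_len=2048, stream_interval=2
):
    last_chunk = chunk
    print(chunk["text"], end="\r")  # each chunk carries the text decoded so far
print()
print(last_chunk["usage"], last_chunk["finish_reason"])
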
282
dbgpt/util/benchmarks/llm/llm_benchmarks.py
Normal file
@@ -0,0 +1,282 @@
from typing import Dict, List
import asyncio
import os
import sys
import time
import csv
import argparse
import logging
import traceback
from dbgpt.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
from datetime import datetime

from dbgpt.model.cluster.worker.manager import (
    run_worker_manager,
    initialize_worker_manager_in_client,
    WorkerManager,
)

from dbgpt.core import ModelOutput, ModelInferenceMetrics
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType


model_name = "vicuna-7b-v1.5"
model_path = LLM_MODEL_CONFIG[model_name]
# or vllm
model_type = "huggingface"

controller_addr = "http://127.0.0.1:5000"

result_csv_file = None

parallel_nums = [1, 2, 4, 16, 32]
# parallel_nums = [1, 2, 4]


def get_result_csv_file() -> str:
    return os.path.join(
        ROOT_PATH, f"pilot/data/{model_name}_{model_type}_benchmarks_llm.csv"
    )


input_lens = [64, 64]
output_lens = [256, 512]


prompt_file_map = {
    "11k": os.path.join(
        ROOT_PATH, "docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt"
    )
}

METRICS_HEADERS = [
    # Params
    "model_name",
    "gpu_nums",
    "parallel_nums",
    "input_length",
    "output_length",
    # Merge parallel result
    "test_time_cost_ms",
    "test_total_tokens",
    # avg_test_speed_per_second: (tokens / s), test_total_tokens / (test_time_cost_ms / 1000.0)
    "avg_test_speed_per_second(tokens/s)",
    # avg_first_token_latency_ms: sum(first_token_time_ms) / parallel_nums
    "avg_first_token_latency_ms",
    # avg_latency_ms: sum(end_time_ms - start_time_ms) / parallel_nums
    "avg_latency_ms",
    "gpu_mem(GiB)",
    # Detail for each task
    "start_time_ms",
    "end_time_ms",
    "current_time_ms",
    "first_token_time_ms",
    "first_completion_time_ms",
    "first_completion_tokens",
    "prompt_tokens",
    "completion_tokens",
    "total_tokens",
    "speed_per_second",
]


def read_prompt_from_file(file_key: str) -> str:
    full_path = prompt_file_map[file_key]
    with open(full_path, "r+", encoding="utf-8") as f:
        return f.read()


def build_param(
    input_len: int,
    output_len: int,
    user_input: str,
    system_prompt: str = None,
) -> Dict:
    hist = []
    if system_prompt is not None:
        hist.append(
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
        )
    hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
    hist = list(h.dict() for h in hist)
    context_len = input_len + output_len + 2
    params = {
        "prompt": user_input,
        "messages": hist,
        "model": model_name,
        "echo": False,
        "max_new_tokens": output_len,
        "context_len": context_len,
    }
    return params


async def run_batch(
    wh: WorkerManager,
    input_len: int,
    output_len: int,
    parallel_num: int,
    output_file: str,
):
    tasks = []
    prompt = read_prompt_from_file("11k")
    if model_type == "vllm":
        max_input_str_len = input_len
        if "baichuan" in model_name:
            # TODO prompt handle first
            max_input_str_len *= 2
        prompt = prompt[-max_input_str_len:]

    # Warmup first
    params = build_param(input_len, output_len, prompt, system_prompt="")
    await wh.generate(params)

    for _ in range(parallel_num):
        params = build_param(input_len, output_len, prompt, system_prompt="")
        tasks.append(wh.generate(params))
    print(
        f"Begin run benchmarks, model name: {model_name}, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
    )
    start_time_ms = time.time_ns() // 1_000_000
    results: List[ModelOutput] = await asyncio.gather(*tasks)
    end_time_ms = time.time_ns() // 1_000_000

    test_time_cost_ms = end_time_ms - start_time_ms
    test_total_tokens = 0
    first_token_latency_ms = 0
    latency_ms = 0
    gpu_nums = 0
    avg_gpu_mem = 0
    rows = []
    for r in results:
        metrics = r.metrics
        if isinstance(metrics, dict):
            metrics = ModelInferenceMetrics(**metrics)
        print(r)
        test_total_tokens += metrics.total_tokens
        first_token_latency_ms += metrics.first_token_time_ms - metrics.start_time_ms
        latency_ms += metrics.end_time_ms - metrics.start_time_ms
        row_data = metrics.to_dict()
        del row_data["collect_index"]
        if "avg_gpu_infos" in row_data:
            avg_gpu_infos = row_data["avg_gpu_infos"]
            gpu_nums = len(avg_gpu_infos)
            avg_gpu_mem = (
                sum(i["allocated_memory_gb"] for i in avg_gpu_infos) / gpu_nums
            )
            del row_data["avg_gpu_infos"]
            del row_data["current_gpu_infos"]
        rows.append(row_data)
    avg_test_speed_per_second = test_total_tokens / (test_time_cost_ms / 1000.0)
    avg_first_token_latency_ms = first_token_latency_ms / len(results)
    avg_latency_ms = latency_ms / len(results)

    with open(output_file, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=METRICS_HEADERS)
        if f.tell() == 0:
            # First time writing to this file, emit the CSV header
            writer.writeheader()
        for row in rows:
            row["model_name"] = model_name
            row["parallel_nums"] = parallel_num
            row["input_length"] = input_len
            row["output_length"] = output_len
            row["test_time_cost_ms"] = test_time_cost_ms
            row["test_total_tokens"] = test_total_tokens
            row["avg_test_speed_per_second(tokens/s)"] = avg_test_speed_per_second
            row["avg_first_token_latency_ms"] = avg_first_token_latency_ms
            row["avg_latency_ms"] = avg_latency_ms
            row["gpu_nums"] = gpu_nums
            row["gpu_mem(GiB)"] = avg_gpu_mem
            writer.writerow(row)
    print(
        f"input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
    )


async def run_model(wh: WorkerManager) -> None:
    global result_csv_file
    if not result_csv_file:
        result_csv_file = get_result_csv_file()
    if os.path.exists(result_csv_file):
        now = datetime.now()
        now_str = now.strftime("%Y-%m-%d")
        os.rename(result_csv_file, f"{result_csv_file}.bak_{now_str}.csv")
    for parallel_num in parallel_nums:
        for input_len, output_len in zip(input_lens, output_lens):
            try:
                await run_batch(
                    wh, input_len, output_len, parallel_num, result_csv_file
                )
            except Exception:
                msg = traceback.format_exc()
                logging.error(
                    f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
                )
                if "torch.cuda.OutOfMemoryError" in msg:
                    return

    sys.exit(0)


def startup_llm_env():
    from fastapi import FastAPI

    app = FastAPI()
    initialize_worker_manager_in_client(
        app=app,
        model_name=model_name,
        model_path=model_path,
        run_locally=False,
        controller_addr=controller_addr,
        local_port=6000,
        start_listener=run_model,
    )


def connect_to_remote_model():
    startup_llm_env()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default=model_name)
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--model_type", type=str, default="huggingface")
    parser.add_argument("--result_csv_file", type=str, default=None)
    parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
    parser.add_argument("--output_lens", type=str, default="256,512,1024,1024")
    parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
    parser.add_argument(
        "--remote_model", type=bool, default=False, help="Connect to remote model"
    )
    parser.add_argument("--controller_addr", type=str, default="http://127.0.0.1:8000")
    parser.add_argument("--limit_model_concurrency", type=int, default=200)

    args = parser.parse_args()
    print(f"args: {args}")
    model_name = args.model_name
    model_path = args.model_path or LLM_MODEL_CONFIG[model_name]
    result_csv_file = args.result_csv_file
    input_lens = [int(i) for i in args.input_lens.strip().split(",")]
    output_lens = [int(i) for i in args.output_lens.strip().split(",")]
    parallel_nums = [int(i) for i in args.parallel_nums.strip().split(",")]
    remote_model = args.remote_model
    controller_addr = args.controller_addr
    limit_model_concurrency = args.limit_model_concurrency
    model_type = args.model_type
    if len(input_lens) != len(output_lens):
        raise ValueError("input_lens size must equal output_lens size")

    if remote_model:
        # Connect to remote model and run benchmarks
        connect_to_remote_model()
    else:
        # Start worker manager and run benchmarks
        run_worker_manager(
            model_name=model_name,
            model_path=model_path,
            start_listener=run_model,
            limit_model_concurrency=limit_model_concurrency,
            model_type=model_type,
        )
162
dbgpt/util/command_utils.py
Normal file
@@ -0,0 +1,162 @@
import sys
import os
import subprocess
from typing import List, Dict
import psutil
import platform
from functools import lru_cache


def _get_abspath_of_current_command(command_path: str):
    if not command_path.endswith(".py"):
        return command_path
    # This implementation is very ugly
    command_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "scripts",
        "cli_scripts.py",
    )
    return command_path


def _run_current_with_daemon(name: str, log_file: str):
    # Get all arguments except for --daemon
    args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
    args[0] = _get_abspath_of_current_command(args[0])

    daemon_cmd = [sys.executable] + args
    daemon_cmd = " ".join(daemon_cmd)
    daemon_cmd += f" > {log_file} 2>&1"

    print(f"daemon cmd: {daemon_cmd}")
    # Check the platform and set the appropriate flags or functions
    if "windows" in platform.system().lower():
        process = subprocess.Popen(
            daemon_cmd,
            creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            shell=True,
        )
    else:  # macOS, Linux, and other Unix-like systems
        process = subprocess.Popen(
            daemon_cmd,
            preexec_fn=os.setsid,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            shell=True,
        )

    print(f"Started {name} in background with pid: {process.pid}")


def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
    try:
        import gunicorn
    except ImportError as e:
        raise ValueError(
            "Could not import python package: gunicorn. "
            "Daemon mode requires gunicorn, please install it with `pip install gunicorn`"
        ) from e

    from dbgpt.util.parameter_utils import EnvArgumentParser

    env_to_app = {}
    env_to_app.update(os.environ)
    app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
    env_to_app.update(app_env)
    cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
    if "windows" in platform.system().lower():
        raise Exception("Not supported on Windows")
    else:  # macOS, Linux, and other Unix-like systems
        process = subprocess.Popen(cmd, shell=True, env=env_to_app)
        print(f"Started {app} with gunicorn in background with pid: {process.pid}")


def _stop_service(
    key: str, fullname: str, service_keys: List[str] = None, port: int = None
):
    if not service_keys:
        service_keys = [sys.argv[0], "start", key]
    not_found = True
    for process in psutil.process_iter(attrs=["pid", "connections", "cmdline"]):
        try:
            cmdline = " ".join(process.info["cmdline"])

            # Check if all key fragments are in the cmdline
            if all(fragment in cmdline for fragment in service_keys):
                if port:
                    for conn in process.info["connections"]:
                        if (
                            conn.status == psutil.CONN_LISTEN
                            and conn.laddr.port == port
                        ):
                            psutil.Process(process.info["pid"]).terminate()
                            print(
                                f"Terminated the {fullname} with PID: {process.info['pid']} listening on port: {port}"
                            )
                            not_found = False
                else:
                    psutil.Process(process.info["pid"]).terminate()
                    print(f"Terminated the {fullname} with PID: {process.info['pid']}")
                    not_found = False
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

    if not_found:
        print(f"{fullname} process not found.")


def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
    """
    Return a list of ports that are associated with processes that have all the service_keys in their cmdline.

    Args:
        service_keys (List[str]): List of strings that should all be present in the process's cmdline.

    Returns:
        List[int]: List of ports sorted with preference for 8000 and 5000, and then in ascending order.
    """
    ports = []

    for process in psutil.process_iter(attrs=["pid", "name", "cmdline", "connections"]):
        try:
            # Convert the cmdline list to a single string for easier checking
            cmdline = ""
            if process.info.get("cmdline"):
                cmdline = " ".join(process.info["cmdline"])

            # Check if all the service keys are present in the cmdline
            if cmdline and all(fragment in cmdline for fragment in service_keys):
                connections = process.info.get("connections")
                if connections is not None and len(ports) == 0:
                    for connection in connections:
                        if connection.status == psutil.CONN_LISTEN:
                            ports.append(connection.laddr.port)
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass

    # Sort ports with preference for 8000 and 5000
    ports.sort(key=lambda x: (x != 8000, x != 5000, x))
    return ports


@lru_cache()
def _detect_controller_address() -> str:
    controller_addr = os.getenv("CONTROLLER_ADDRESS")
    if controller_addr:
        return controller_addr

    cmdline_fragments = [
        ["python", "start", "controller"],
        ["python", "controller"],
        ["python", "start", "webserver"],
        ["python", "dbgpt_server"],
    ]

    for fragments in cmdline_fragments:
        ports = _get_ports_by_cmdline_part(fragments)
        if ports:
            return f"http://127.0.0.1:{ports[0]}"

    return f"http://127.0.0.1:8000"
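
A rough sketch of how these helpers would be called from CLI commands; the key/fullname values below are illustrative, not taken from this commit:

from dbgpt.util.command_utils import _detect_controller_address, _stop_service

# Find a running controller/webserver by scanning process command lines,
# falling back to http://127.0.0.1:8000 when nothing matching is listening.
print(_detect_controller_address())

# Terminate a process started with "... start webserver" that listens on port 5000.
_stop_service("webserver", "WebServer", port=5000)
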
34
dbgpt/util/custom_data_structure.py
Normal file
@@ -0,0 +1,34 @@
from collections import OrderedDict
from collections import deque


class FixedSizeDict(OrderedDict):
    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        if len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)


class FixedSizeList:
    def __init__(self, max_size):
        self.max_size = max_size
        self.list = deque(maxlen=max_size)

    def append(self, value):
        self.list.append(value)

    def __getitem__(self, index):
        return self.list[index]

    def __setitem__(self, index, value):
        self.list[index] = value

    def __len__(self):
        return len(self.list)

    def __str__(self):
        return str(list(self.list))
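
Both containers evict the oldest entries once the size cap is reached, as in this short demo:

from dbgpt.util.custom_data_structure import FixedSizeDict, FixedSizeList

cache = FixedSizeDict(max_size=2)
cache["a"], cache["b"], cache["c"] = 1, 2, 3  # inserting "c" evicts the oldest key "a"
print(list(cache.keys()))  # -> ['b', 'c']

recent = FixedSizeList(max_size=3)
for i in range(5):
    recent.append(i)  # the underlying deque silently drops the oldest items
print(recent)  # -> [2, 3, 4]
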
87
dbgpt/util/executor_utils.py
Normal file
@@ -0,0 +1,87 @@
from typing import Callable, Awaitable, Any
import asyncio
import contextvars
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial

from dbgpt.component import BaseComponent, ComponentType, SystemApp


class ExecutorFactory(BaseComponent, ABC):
    name = ComponentType.EXECUTOR_DEFAULT.value

    @abstractmethod
    def create(self) -> "Executor":
        """Create executor"""


class DefaultExecutorFactory(ExecutorFactory):
    def __init__(self, system_app: SystemApp | None = None, max_workers=None):
        super().__init__(system_app)
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix=self.name
        )

    def init_app(self, system_app: SystemApp):
        pass

    def create(self) -> Executor:
        return self._executor


BlockingFunction = Callable[..., Any]


async def blocking_func_to_async(
    executor: Executor, func: BlockingFunction, *args, **kwargs
):
    """Run a potentially blocking function within an executor.

    Args:
        executor (Executor): The concurrent.futures.Executor to run the function within.
        func (BlockingFunction): The callable function, which should be a synchronous
            function. It can accept any number and type of arguments and returns its
            result directly.
        *args (Any): Any additional arguments to pass to the function.
        **kwargs (Any): Other arguments to pass to the function

    Returns:
        Any: The result of the function's execution.

    Raises:
        ValueError: If the provided function 'func' is an asynchronous coroutine function.

    This function allows you to execute a potentially blocking function within an executor.
    It expects 'func' to be a synchronous function and will raise an error if 'func' is an asynchronous coroutine.
    """
    if asyncio.iscoroutinefunction(func):
        raise ValueError(f"The function {func} is a coroutine function, not a blocking function")

    # This function will be called within the new thread, capturing the current context
    ctx = contextvars.copy_context()

    def run_with_context():
        return ctx.run(partial(func, *args, **kwargs))

    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(executor, run_with_context)


class AsyncToSyncIterator:
    def __init__(self, async_iterable, loop: asyncio.BaseEventLoop):
        self.async_iterable = async_iterable
        self.async_iterator = None
        self._loop = loop

    def __iter__(self):
        self.async_iterator = self.async_iterable.__aiter__()
        return self

    def __next__(self):
        if self.async_iterator is None:
            raise StopIteration

        try:
            return self._loop.run_until_complete(self.async_iterator.__anext__())
        except StopAsyncIteration:
            raise StopIteration
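
A minimal sketch of blocking_func_to_async in use; the slow_add function is made up to stand in for blocking work:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

from dbgpt.util.executor_utils import blocking_func_to_async


def slow_add(a: int, b: int) -> int:
    time.sleep(1)  # stands in for blocking I/O or CPU-bound work
    return a + b


async def main():
    executor = ThreadPoolExecutor(max_workers=2)
    # The event loop stays free while slow_add runs in the thread pool.
    result = await blocking_func_to_async(executor, slow_add, 1, 2)
    print(result)  # -> 3


asyncio.run(main())
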
61
dbgpt/util/formatting.py
Normal file
@@ -0,0 +1,61 @@
"""Utilities for formatting strings."""
import json
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union


class StrictFormatter(Formatter):
    """A subclass of formatter that checks for extra keys."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Check to see if extra parameters are passed."""
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Check that no arguments are provided."""
        if len(args) > 0:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)

    def validate_input_variables(
        self, format_string: str, input_variables: List[str]
    ) -> None:
        dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
        super().format(format_string, **dummy_inputs)


class NoStrictFormatter(StrictFormatter):
    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Do not check unused args."""
        pass


formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()


class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        elif hasattr(obj, "__dict__"):
            return obj.__dict__
        else:
            return json.JSONEncoder.default(self, obj)
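
The difference between the two module-level formatter instances in one short example:

from dbgpt.util.formatting import formatter, no_strict_formatter

print(formatter.format("Hello {name}!", name="DB-GPT"))  # -> "Hello DB-GPT!"

# The strict formatter rejects keyword arguments the template does not use ...
try:
    formatter.format("Hello {name}!", name="DB-GPT", extra="ignored")
except KeyError as e:
    print(f"unused keys: {e}")

# ... while the non-strict variant silently ignores them.
print(no_strict_formatter.format("Hello {name}!", name="DB-GPT", extra="ignored"))
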
448
dbgpt/util/global_helper.py
Normal file
@@ -0,0 +1,448 @@
"""General util functions."""

import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Type,
    Union,
    cast,
)


class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to load.
    (e.g. tokenization)

    """

    _tokenizer: Optional[Callable[[str], List]] = None
    _stopwords: Optional[List[str]] = None

    @property
    def tokenizer(self) -> Callable[[str], List]:
        """Get tokenizer."""
        if self._tokenizer is None:
            tiktoken_import_err = (
                "`tiktoken` package not found, please run `pip install tiktoken`"
            )
            try:
                import tiktoken
            except ImportError:
                raise ImportError(tiktoken_import_err)
            enc = tiktoken.get_encoding("gpt2")
            self._tokenizer = cast(Callable[[str], List], enc.encode)
            self._tokenizer = partial(self._tokenizer, allowed_special="all")
        return self._tokenizer  # type: ignore

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords."""
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                )

            from llama_index.utils import get_cache_dir

            cache_dir = get_cache_dir()
            nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

            # update nltk path for nltk so that it finds the data
            if nltk_data_dir not in nltk.data.path:
                nltk.data.path.append(nltk_data_dir)

            try:
                nltk.data.find("corpora/stopwords")
            except LookupError:
                nltk.download("stopwords", download_dir=nltk_data_dir)
            self._stopwords = stopwords.words("english")
        return self._stopwords


globals_helper = GlobalsHelper()


def get_new_id(d: Set) -> str:
    """Get a new ID."""
    while True:
        new_id = str(uuid.uuid4())
        if new_id not in d:
            break
    return new_id


def get_new_int_id(d: Set) -> int:
    """Get a new integer ID."""
    while True:
        new_id = random.randint(0, sys.maxsize)
        if new_id not in d:
            break
    return new_id


@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Utility class for setting a temporary value for an attribute on a class.
    Taken from: https://tinyurl.com/2p89xymh

    """
    prev_values = {k: getattr(obj, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(obj, k, v)
    try:
        yield
    finally:
        for k, v in prev_values.items():
            setattr(obj, k, v)


@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.

    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None


def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            check_fn = error_checks.get(e.__class__)
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)


def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Iterate over an iterable in batches.

    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    while source_iter:
        b = list(islice(source_iter, size))
        if len(b) == 0:
            break
        yield b


def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    os.path.join(dirname, basename) will add a backslash before dirname if
    basename does not end with a slash, so we make sure it does.
    """
    dirname += "/" if dirname[-1] != "/" else ""
    return os.path.join(dirname, basename)


def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally get a tqdm iterable. Ensures tqdm.auto is used.
    """
    _iterator = items
    if show_progress:
        try:
            from tqdm.auto import tqdm

            return tqdm(items, desc=desc)
        except ImportError:
            pass
    return _iterator


def count_tokens(text: str) -> int:
    tokens = globals_helper.tokenizer(text)
    return len(tokens)


def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """
    Args:
        model_name(str): the model name of the tokenizer.
            For instance, fxmarty/tiny-llama-fast-tokenizer.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError:
        raise ValueError(
            "`transformers` package not found, please run `pip install transformers`"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer.tokenize


def get_cache_dir() -> str:
    """Locate a platform-appropriate cache directory for llama_index,
    and create it if it doesn't yet exist.
    """
    # User override
    if "LLAMA_INDEX_CACHE_DIR" in os.environ:
        path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])

    # Linux, Unix, AIX, etc.
    elif os.name == "posix" and sys.platform != "darwin":
        path = Path("/tmp/llama_index")

    # Mac OS
    elif sys.platform == "darwin":
        path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")

    # Windows (hopefully)
    else:
        local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
            "~\\AppData\\Local"
        )
        path = Path(local, "llama_index")

    if not os.path.exists(path):
        os.makedirs(
            path, exist_ok=True
        )  # prevents https://github.com/jerryjliu/llama_index/issues/7362
    return str(path)


def add_sync_version(func: Any) -> Any:
    """Decorator for adding sync version of an async function. The sync version
    is added as a function attribute to the original function, func.

    Args:
        func(Any): the async function for which a sync variant will be built.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _wrapper(*args: Any, **kwds: Any) -> Any:
        return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))

    func.sync = _wrapper
    return func


# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""

_LLAMA_INDEX_COLORS = {
    "llama_pink": "38;2;237;90;200",
    "llama_blue": "38;2;90;149;237",
    "llama_turquoise": "38;2;11;159;203",
    "llama_lavender": "38;2;155;135;227",
}

_ANSI_COLORS = {
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "pink": "38;5;200",
}


def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Get a mapping of items to colors.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate
            whether to use LlamaIndex colors or ANSI colors.
            Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to colors.
    """
    if use_llama_index_colors:
        color_palette = _LLAMA_INDEX_COLORS
    else:
        color_palette = _ANSI_COLORS

    colors = list(color_palette.keys())
    return {item: colors[i % len(colors)] for i, item in enumerate(items)}


def _get_colored_text(text: str, color: str) -> str:
    """
    Get the colored version of the input text.

    Args:
        text (str): Input text.
        color (str): Color to be applied to the text.

    Returns:
        str: Colored version of the input text.
    """
    all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}

    if color not in all_colors:
        return f"\033[1;3m{text}\033[0m"  # just bolded and italicized

    color = all_colors[color]

    return f"\033[1;3;{color}m{text}\033[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print the text with the specified color.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
        end (str, optional): String appended after the last character of the text.

    Returns:
        None
    """
    text_to_print = _get_colored_text(text, color) if color is not None else text
    print(text_to_print, end=end)


def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    try:
        has_cuda = torch.cuda.is_available()
    except NameError:
        import torch

        has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def unit_generator(x: Any) -> Generator[Any, None, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to build yield

    Yields:
        Any: the single element
    """
    yield x


async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to build yield

    Yields:
        Any: the single element
    """
    yield x
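
A short demo of two of the helpers above; the flaky_call function is invented here to stand in for an unreliable remote call:

from dbgpt.util.global_helper import ErrorToRetry, iter_batch, retry_on_exceptions_with_backoff

print(list(iter_batch([1, 2, 3, 4, 5], 3)))  # -> [[1, 2, 3], [4, 5]]

calls = {"count": 0}


def flaky_call() -> str:
    # Fails twice, then succeeds; stands in for an unreliable remote call.
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


result = retry_on_exceptions_with_backoff(
    flaky_call,
    [ErrorToRetry(ConnectionError)],
    max_tries=5,
    min_backoff_secs=0.1,
)
print(result)  # -> "ok" after two retried failures with exponential backoff
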
14
dbgpt/util/json_utils.py
Normal file
@@ -0,0 +1,14 @@
import json
from datetime import date, datetime


def serialize(obj):
    if isinstance(obj, date):
        return obj.isoformat()


class DateTimeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)
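
DateTimeEncoder plugs into json.dumps via the cls argument, for example:

import json
from datetime import datetime

from dbgpt.util.json_utils import DateTimeEncoder

payload = {"created_at": datetime(2023, 12, 25, 10, 30)}
# datetime values are serialized as ISO-8601 strings instead of raising TypeError.
print(json.dumps(payload, cls=DateTimeEncoder))  # -> {"created_at": "2023-12-25T10:30:00"}
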
11
dbgpt/util/memory_utils.py
Normal file
@@ -0,0 +1,11 @@
from typing import Any
from pympler import asizeof


def _get_object_bytes(obj: Any) -> int:
    """Get the size of an object in memory, in bytes.

    Args:
        obj (Any): The object to measure
    """
    return asizeof.asizeof(obj)
84
dbgpt/util/model_utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _clear_model_cache(device="cuda"):
|
||||
try:
|
||||
# clear torch cache
|
||||
import torch
|
||||
|
||||
_clear_torch_cache(device)
|
||||
except ImportError:
|
||||
logger.warn("Torch not installed, skip clear torch cache")
|
||||
# TODO clear other cache
|
||||
|
||||
|
||||
def _clear_torch_cache(device="cuda"):
|
||||
import torch
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if device != "cpu":
|
||||
if torch.has_mps:
|
||||
try:
|
||||
from torch.mps import empty_cache
|
||||
|
||||
empty_cache()
|
||||
except Exception as e:
|
||||
logger.warn(f"Clear mps torch cache error, {str(e)}")
|
||||
elif torch.has_cuda:
|
||||
device_count = torch.cuda.device_count()
|
||||
for device_id in range(device_count):
|
||||
cuda_device = f"cuda:{device_id}"
|
||||
logger.info(f"Clear torch cache of device: {cuda_device}")
|
||||
with torch.cuda.device(cuda_device):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
else:
|
||||
logger.info("No cuda or mps, not support clear torch cache yet")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUInfo:
|
||||
total_memory_gb: float
|
||||
allocated_memory_gb: float
|
||||
cached_memory_gb: float
|
||||
available_memory_gb: float
|
||||
|
||||
|
||||
def _get_current_cuda_memory() -> List[GPUInfo]:
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
logger.warn("Torch not installed")
|
||||
return []
|
||||
if torch.cuda.is_available():
|
||||
num_gpus = torch.cuda.device_count()
|
||||
gpu_infos = []
|
||||
for gpu_id in range(num_gpus):
|
||||
with torch.cuda.device(gpu_id):
|
||||
device = torch.cuda.current_device()
|
||||
gpu_properties = torch.cuda.get_device_properties(device)
|
||||
total_memory = round(gpu_properties.total_memory / (1.0 * 1024**3), 2)
|
||||
allocated_memory = round(
|
||||
torch.cuda.memory_allocated() / (1.0 * 1024**3), 2
|
||||
)
|
||||
cached_memory = round(
|
||||
torch.cuda.memory_reserved() / (1.0 * 1024**3), 2
|
||||
)
|
||||
available_memory = total_memory - allocated_memory
|
||||
gpu_infos.append(
|
||||
GPUInfo(
|
||||
total_memory_gb=total_memory,
|
||||
allocated_memory_gb=allocated_memory,
|
||||
cached_memory_gb=cached_memory,
|
||||
available_memory_gb=available_memory,
|
||||
)
|
||||
)
|
||||
return gpu_infos
|
||||
else:
|
||||
logger.warn("CUDA is not available.")
|
||||
return []
|
28
dbgpt/util/module_utils.py
Normal file
28
dbgpt/util/module_utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Type
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def import_from_string(module_path: str, ignore_import_error: bool = False):
|
||||
try:
|
||||
module_path, class_name = module_path.rsplit(".", 1)
|
||||
except ValueError:
|
||||
raise ImportError(f"{module_path} doesn't look like a module path")
|
||||
module = import_module(module_path)
|
||||
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
if ignore_import_error:
|
||||
return None
|
||||
raise ImportError(
|
||||
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
||||
)
|
||||
|
||||
|
||||
def import_from_checked_string(module_path: str, supper_cls: Type):
|
||||
cls = import_from_string(module_path)
|
||||
if not issubclass(cls, supper_cls):
|
||||
raise ImportError(
|
||||
f'Module "{module_path}" does not the subclass of {str(supper_cls)}'
|
||||
)
|
||||
return cls
|
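# --- Usage sketch (illustrative, not part of the original file) ---
# Load a class from its dotted path, optionally verifying its base class; the
# stdlib classes below are only examples.
if __name__ == "__main__":
    import json

    encoder_cls = import_from_string("json.JSONEncoder")
    decoder_cls = import_from_checked_string("json.JSONDecoder", json.JSONDecoder)
    print(encoder_cls, decoder_cls)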
24
dbgpt/util/net_utils.py
Normal file
24
dbgpt/util/net_utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import socket
|
||||
import errno
|
||||
|
||||
|
||||
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
|
||||
ip, port = address.split(":")
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.settimeout(0)
|
||||
curr_address = "127.0.0.1"
|
||||
try:
|
||||
# doesn't even have to be reachable
|
||||
s.connect((ip, int(port)))
|
||||
curr_address = s.getsockname()[0]
|
||||
except OSError as e:
|
||||
IP = "127.0.0.1"
|
||||
if e.errno == errno.ENETUNREACH:
|
||||
try:
|
||||
hostname = socket.getfqdn(socket.gethostname())
|
||||
curr_address = socket.gethostbyname(hostname)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
s.close()
|
||||
return curr_address
|
99
dbgpt/util/openai_utils.py
Normal file
99
dbgpt/util/openai_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
|
||||
import httpx
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MessageCaller = Callable[[str], Awaitable[None]]
|
||||
|
||||
|
||||
async def _do_chat_completion(
|
||||
url: str,
|
||||
chat_data: Dict[str, Any],
|
||||
client: httpx.AsyncClient,
|
||||
headers: Dict[str, Any] = {},
|
||||
timeout: int = 60,
|
||||
caller: Optional[MessageCaller] = None,
|
||||
) -> AsyncIterator[str]:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
headers=headers,
|
||||
json=chat_data,
|
||||
timeout=timeout,
|
||||
) as res:
|
||||
if res.status_code != 200:
|
||||
error_message = await res.aread()
|
||||
if error_message:
|
||||
error_message = error_message.decode("utf-8")
|
||||
logger.error(
|
||||
f"Request failed with status {res.status_code}. Error: {error_message}"
|
||||
)
|
||||
raise httpx.RequestError(
|
||||
f"Request failed with status {res.status_code}",
|
||||
request=res.request,
|
||||
)
|
||||
async for line in res.aiter_lines():
|
||||
if line:
|
||||
if not line.startswith("data: "):
|
||||
if caller:
|
||||
await caller(line)
|
||||
yield line
|
||||
else:
|
||||
decoded_line = line.split("data: ", 1)[1]
|
||||
if decoded_line.lower().strip() != "[DONE]".lower():
|
||||
obj = json.loads(decoded_line)
|
||||
if obj["choices"][0]["delta"].get("content") is not None:
|
||||
text = obj["choices"][0]["delta"].get("content")
|
||||
if caller:
|
||||
await caller(text)
|
||||
yield text
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
|
||||
async def chat_completion_stream(
|
||||
url: str,
|
||||
chat_data: Dict[str, Any],
|
||||
client: Optional[httpx.AsyncClient] = None,
|
||||
headers: Dict[str, Any] = {},
|
||||
timeout: int = 60,
|
||||
caller: Optional[MessageCaller] = None,
|
||||
) -> AsyncIterator[str]:
|
||||
if client:
|
||||
async for text in _do_chat_completion(
|
||||
url,
|
||||
chat_data,
|
||||
client=client,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
caller=caller,
|
||||
):
|
||||
yield text
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async for text in _do_chat_completion(
|
||||
url,
|
||||
chat_data,
|
||||
client=client,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
caller=caller,
|
||||
):
|
||||
yield text
|
||||
|
||||
|
||||
async def chat_completion(
|
||||
url: str,
|
||||
chat_data: Dict[str, Any],
|
||||
client: Optional[httpx.AsyncClient] = None,
|
||||
headers: Dict[str, Any] = {},
|
||||
timeout: int = 60,
|
||||
caller: Optional[MessageCaller] = None,
|
||||
) -> str:
|
||||
full_text = ""
|
||||
async for text in chat_completion_stream(
|
||||
url, chat_data, client, headers=headers, timeout=timeout, caller=caller
|
||||
):
|
||||
full_text += text
|
||||
return full_text
|
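# --- Usage sketch (illustrative, not part of the original file) ---
# The URL, model name and token below are placeholders for an
# OpenAI-compatible chat completion endpoint.
async def _chat_completion_demo():
    chat_data = {
        "model": "example-model",  # placeholder
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    full_text = await chat_completion(
        "http://localhost:8100/api/v1/chat/completions",  # placeholder URL
        chat_data,
        headers={"Authorization": "Bearer EMPTY"},  # placeholder token
    )
    print(full_text)


if __name__ == "__main__":
    asyncio.run(_chat_completion_demo())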
740
dbgpt/util/parameter_utils.py
Normal file
740
dbgpt/util/parameter_utils.py
Normal file
@@ -0,0 +1,740 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
|
||||
from typing import Any, List, Optional, Type, Union, Callable, Dict, TYPE_CHECKING
|
||||
from collections import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
MISSING_DEFAULT_VALUE = "__MISSING_DEFAULT_VALUE__"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterDescription:
|
||||
param_class: str
|
||||
param_name: str
|
||||
param_type: str
|
||||
default_value: Optional[Any]
|
||||
description: str
|
||||
required: Optional[bool]
|
||||
valid_values: Optional[List[Any]]
|
||||
ext_metadata: Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseParameters:
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = False
|
||||
) -> "BaseParameters":
|
||||
"""Create an instance of the dataclass from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing values for the dataclass fields.
|
||||
ignore_extra_fields: If True, any extra fields in the data dictionary that are
|
||||
not part of the dataclass will be ignored.
|
||||
If False, extra fields will raise an error. Defaults to False.
|
||||
Returns:
|
||||
An instance of the dataclass with values populated from the given dictionary.
|
||||
|
||||
Raises:
|
||||
TypeError: If `ignore_extra_fields` is False and there are fields in the
|
||||
dictionary that aren't present in the dataclass.
|
||||
"""
|
||||
all_field_names = {f.name for f in fields(cls)}
|
||||
if ignore_extra_fields:
|
||||
data = {key: value for key, value in data.items() if key in all_field_names}
|
||||
else:
|
||||
extra_fields = set(data.keys()) - all_field_names
|
||||
if extra_fields:
|
||||
raise TypeError(f"Unexpected fields: {', '.join(extra_fields)}")
|
||||
return cls(**data)
|
||||
|
||||
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
||||
"""
|
||||
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
||||
Only update if the new value is different from the current value and the field is not marked as "fixed" in metadata.
|
||||
|
||||
Args:
|
||||
source (Union[BaseParameters, dict]): The source to update from. Can be another object of the same type or a dictionary.
|
||||
|
||||
Returns:
|
||||
bool: True if at least one field was updated, otherwise False.
|
||||
"""
|
||||
updated = False # Flag to indicate whether any field was updated
|
||||
if isinstance(source, (BaseParameters, dict)):
|
||||
for field_info in fields(self):
|
||||
# Check if the field has a "fixed" tag in metadata
|
||||
tags = field_info.metadata.get("tags")
|
||||
tags = [] if not tags else tags.split(",")
|
||||
if tags and "fixed" in tags:
|
||||
continue # skip this field
|
||||
# Get the new value from source (either another BaseParameters object or a dict)
|
||||
new_value = (
|
||||
getattr(source, field_info.name)
|
||||
if isinstance(source, BaseParameters)
|
||||
else source.get(field_info.name, None)
|
||||
)
|
||||
|
||||
# If the new value is not None and different from the current value, update the field and set the flag
|
||||
if new_value is not None and new_value != getattr(
|
||||
self, field_info.name
|
||||
):
|
||||
setattr(self, field_info.name, new_value)
|
||||
updated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Source must be an instance of BaseParameters (or its derived class) or a dictionary."
|
||||
)
|
||||
|
||||
return updated
|
||||
|
||||
def __str__(self) -> str:
|
||||
return _get_dataclass_print_str(self)
|
||||
|
||||
def to_command_args(self, args_prefix: str = "--") -> List[str]:
|
||||
"""Convert the fields of the dataclass to a list of command line arguments.
|
||||
|
||||
Args:
|
||||
args_prefix: args prefix
|
||||
Returns:
|
||||
A list of strings where each field is represented by two items:
|
||||
one for the field name prefixed by args_prefix, and one for its value.
|
||||
"""
|
||||
return _dict_to_command_args(asdict(self), args_prefix=args_prefix)
|
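# --- Usage sketch (illustrative, not part of the original file) ---
# _ExampleParameters is a made-up parameter class for demonstration only.
@dataclass
class _ExampleParameters(BaseParameters):
    host: str = field(default="127.0.0.1", metadata={"help": "Bind host"})
    port: int = field(default=8000, metadata={"help": "Bind port"})


if __name__ == "__main__":
    _params = _ExampleParameters.from_dict({"host": "0.0.0.0", "port": 9000})
    _params.update_from({"port": 9001})  # True: the port field changed
    print(_params.to_command_args())  # ['--host', '0.0.0.0', '--port', '9001']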
||||
|
||||
|
||||
def _get_dataclass_print_str(obj):
|
||||
class_name = obj.__class__.__name__
|
||||
parameters = [
|
||||
f"\n\n=========================== {class_name} ===========================\n"
|
||||
]
|
||||
for field_info in fields(obj):
|
||||
value = _get_simple_privacy_field_value(obj, field_info)
|
||||
parameters.append(f"{field_info.name}: {value}")
|
||||
parameters.append(
|
||||
"\n======================================================================\n\n"
|
||||
)
|
||||
return "\n".join(parameters)
|
||||
|
||||
|
||||
def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
|
||||
"""Convert dict to a list of command line arguments
|
||||
|
||||
Args:
|
||||
obj: dict
|
||||
Returns:
|
||||
A list of strings where each field is represented by two items:
|
||||
one for the field name prefixed by args_prefix, and one for its value.
|
||||
"""
|
||||
args = []
|
||||
for key, value in obj.items():
|
||||
if value is None:
|
||||
continue
|
||||
args.append(f"{args_prefix}{key}")
|
||||
args.append(str(value))
|
||||
return args
|
||||
|
||||
|
||||
def _get_simple_privacy_field_value(obj, field_info):
|
||||
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
|
||||
|
||||
This function reads the metadata of a field to check if it's tagged with 'privacy'.
|
||||
If the 'privacy' tag is present, then it modifies the value based on its type
|
||||
for privacy concerns:
|
||||
- int: returns -999
|
||||
- float: returns -999.0
|
||||
- bool: returns False
|
||||
- str: if length > 5, masks the middle part and returns first and last char;
|
||||
otherwise, returns "******"
|
||||
|
||||
Args:
|
||||
obj: The dataclass instance.
|
||||
field_info: A Field object that contains information about the dataclass field.
|
||||
|
||||
Returns:
|
||||
The original or modified value of the field based on the privacy rules.
|
||||
|
||||
Example usage:
|
||||
@dataclass
|
||||
class Person:
|
||||
name: str
|
||||
age: int
|
||||
ssn: str = field(metadata={"tags": "privacy"})
|
||||
p = Person("Alice", 30, "123-45-6789")
|
||||
print(_get_simple_privacy_field_value(p, Person.ssn)) # A******9
|
||||
"""
|
||||
tags = field_info.metadata.get("tags")
|
||||
tags = [] if not tags else tags.split(",")
|
||||
is_privacy = False
|
||||
if tags and "privacy" in tags:
|
||||
is_privacy = True
|
||||
value = getattr(obj, field_info.name)
|
||||
if not is_privacy or not value:
|
||||
return value
|
||||
field_type = EnvArgumentParser._get_argparse_type(field_info.type)
|
||||
if field_type is int:
|
||||
return -999
|
||||
if field_type is float:
|
||||
return -999.0
|
||||
if field_type is bool:
|
||||
return False
|
||||
# str
|
||||
if len(value) > 5:
|
||||
return value[0] + "******" + value[-1]
|
||||
return "******"
|
||||
|
||||
|
||||
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
|
||||
"""Get the value from the environment variable, ignoring the case of the key"""
|
||||
if env_prefix:
|
||||
env_key = env_prefix + env_key
|
||||
return os.getenv(
|
||||
env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
|
||||
)
|
||||
|
||||
|
||||
def _genenv_ignoring_key_case_with_prefixes(
|
||||
env_key: str, env_prefixes: List[str] = None, default_value=None
|
||||
) -> str:
|
||||
if env_prefixes:
|
||||
for env_prefix in env_prefixes:
|
||||
env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
|
||||
if env_var_value:
|
||||
return env_var_value
|
||||
return _genenv_ignoring_key_case(env_key, default_value=default_value)
|
||||
|
||||
|
||||
class EnvArgumentParser:
|
||||
@staticmethod
|
||||
def get_env_prefix(env_key: str) -> str:
|
||||
if not env_key:
|
||||
return None
|
||||
env_key = env_key.replace("-", "_")
|
||||
return env_key + "_"
|
||||
|
||||
def parse_args_into_dataclass(
|
||||
self,
|
||||
dataclass_type: Type,
|
||||
env_prefixes: List[str] = None,
|
||||
command_args: List[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Parse parameters from environment variables and command lines and populate them into data class"""
|
||||
parser = argparse.ArgumentParser()
|
||||
for field in fields(dataclass_type):
|
||||
env_var_value = _genenv_ignoring_key_case_with_prefixes(
|
||||
field.name, env_prefixes
|
||||
)
|
||||
if env_var_value:
|
||||
env_var_value = env_var_value.strip()
|
||||
if field.type is int or field.type == Optional[int]:
|
||||
env_var_value = int(env_var_value)
|
||||
elif field.type is float or field.type == Optional[float]:
|
||||
env_var_value = float(env_var_value)
|
||||
elif field.type is bool or field.type == Optional[bool]:
|
||||
env_var_value = env_var_value.lower() == "true"
|
||||
elif field.type is str or field.type == Optional[str]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type {field.type}")
|
||||
if not env_var_value:
|
||||
env_var_value = kwargs.get(field.name)
|
||||
|
||||
# print(f"env_var_value: {env_var_value} for {field.name}")
|
||||
# Add a command-line argument for this field
|
||||
EnvArgumentParser._build_single_argparse_option(
|
||||
parser, field, env_var_value
|
||||
)
|
||||
|
||||
# Parse the command-line arguments
|
||||
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
|
||||
# cmd_args = parser.parse_args(args=command_args)
|
||||
# print(f"cmd_args: {cmd_args}")
|
||||
for field in fields(dataclass_type):
|
||||
# cmd_line_value = getattr(cmd_args, field.name)
|
||||
if field.name in cmd_args:
|
||||
cmd_line_value = getattr(cmd_args, field.name)
|
||||
if cmd_line_value is not None:
|
||||
kwargs[field.name] = cmd_line_value
|
||||
|
||||
return dataclass_type(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
|
||||
for field in fields(dataclass_type):
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
argument_kwargs = {
|
||||
"type": EnvArgumentParser._get_argparse_type(field.type),
|
||||
"help": help_text,
|
||||
"choices": valid_values,
|
||||
"required": EnvArgumentParser._is_require_type(field.type),
|
||||
}
|
||||
if field.default != MISSING:
|
||||
argument_kwargs["default"] = field.default
|
||||
argument_kwargs["required"] = False
|
||||
parser.add_argument(f"--{field.name}", **argument_kwargs)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def _create_click_option_from_field(field_name: str, field: Type, is_func=True):
|
||||
import click
|
||||
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
cli_params = {
|
||||
"default": None if field.default is MISSING else field.default,
|
||||
"help": help_text,
|
||||
"show_default": True,
|
||||
"required": field.default is MISSING,
|
||||
}
|
||||
if valid_values:
|
||||
cli_params["type"] = click.Choice(valid_values)
|
||||
real_type = EnvArgumentParser._get_argparse_type(field.type)
|
||||
if real_type is int:
|
||||
cli_params["type"] = click.INT
|
||||
elif real_type is float:
|
||||
cli_params["type"] = click.FLOAT
|
||||
elif real_type is str:
|
||||
cli_params["type"] = click.STRING
|
||||
elif real_type is bool:
|
||||
cli_params["is_flag"] = True
|
||||
name = f"--{field_name}"
|
||||
if is_func:
|
||||
return click.option(
|
||||
name,
|
||||
**cli_params,
|
||||
)
|
||||
else:
|
||||
return click.Option([name], **cli_params)
|
||||
|
||||
@staticmethod
|
||||
def create_click_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
):
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
|
||||
combined_fields = OrderedDict()
|
||||
if _dynamic_factory:
|
||||
_types = _dynamic_factory()
|
||||
if _types:
|
||||
dataclass_types = list(_types)
|
||||
for dataclass_type in dataclass_types:
|
||||
for field in fields(dataclass_type):
|
||||
if field.name not in combined_fields:
|
||||
combined_fields[field.name] = field
|
||||
|
||||
def decorator(func):
|
||||
for field_name, field in reversed(combined_fields.items()):
|
||||
option_decorator = EnvArgumentParser._create_click_option_from_field(
|
||||
field_name, field
|
||||
)
|
||||
func = option_decorator(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _create_raw_click_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
):
|
||||
combined_fields = _merge_dataclass_types(
|
||||
*dataclass_types, _dynamic_factory=_dynamic_factory
|
||||
)
|
||||
options = []
|
||||
|
||||
for field_name, field in reversed(combined_fields.items()):
|
||||
options.append(
|
||||
EnvArgumentParser._create_click_option_from_field(
|
||||
field_name, field, is_func=False
|
||||
)
|
||||
)
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def create_argparse_option(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
) -> argparse.ArgumentParser:
|
||||
combined_fields = _merge_dataclass_types(
|
||||
*dataclass_types, _dynamic_factory=_dynamic_factory
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
for _, field in reversed(combined_fields.items()):
|
||||
EnvArgumentParser._build_single_argparse_option(parser, field)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def _build_single_argparse_option(
|
||||
parser: argparse.ArgumentParser, field, default_value=None
|
||||
):
|
||||
# Add a command-line argument for this field
|
||||
help_text = field.metadata.get("help", "")
|
||||
valid_values = field.metadata.get("valid_values", None)
|
||||
short_name = field.metadata.get("short", None)
|
||||
argument_kwargs = {
|
||||
"type": EnvArgumentParser._get_argparse_type(field.type),
|
||||
"help": help_text,
|
||||
"choices": valid_values,
|
||||
"required": EnvArgumentParser._is_require_type(field.type),
|
||||
}
|
||||
if field.default != MISSING:
|
||||
argument_kwargs["default"] = field.default
|
||||
argument_kwargs["required"] = False
|
||||
if default_value:
|
||||
argument_kwargs["default"] = default_value
|
||||
argument_kwargs["required"] = False
|
||||
if field.type is bool or field.type == Optional[bool]:
|
||||
argument_kwargs["action"] = "store_true"
|
||||
del argument_kwargs["type"]
|
||||
del argument_kwargs["choices"]
|
||||
names = []
|
||||
if short_name:
|
||||
names.append(f"-{short_name}")
|
||||
names.append(f"--{field.name}")
|
||||
parser.add_argument(*names, **argument_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_argparse_type(field_type: Type) -> Type:
|
||||
# Return the appropriate type for argparse to use based on the field type
|
||||
if field_type is int or field_type == Optional[int]:
|
||||
return int
|
||||
elif field_type is float or field_type == Optional[float]:
|
||||
return float
|
||||
elif field_type is bool or field_type == Optional[bool]:
|
||||
return bool
|
||||
elif field_type is str or field_type == Optional[str]:
|
||||
return str
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type {field_type}")
|
||||
|
||||
@staticmethod
|
||||
def _get_argparse_type_str(field_type: Type) -> str:
|
||||
argparse_type = EnvArgumentParser._get_argparse_type(field_type)
|
||||
if argparse_type is int:
|
||||
return "int"
|
||||
elif argparse_type is float:
|
||||
return "float"
|
||||
elif argparse_type is bool:
|
||||
return "bool"
|
||||
else:
|
||||
return "str"
|
||||
|
||||
@staticmethod
|
||||
def _is_require_type(field_type: Type) -> bool:
|
||||
return field_type not in [Optional[int], Optional[float], Optional[bool]]
|
||||
|
||||
@staticmethod
|
||||
def _kwargs_to_env_key_value(
|
||||
kwargs: Dict, prefix: str = "__dbgpt_gunicorn__env_prefix__"
|
||||
) -> Dict[str, str]:
|
||||
return {prefix + k: str(v) for k, v in kwargs.items()}
|
||||
|
||||
@staticmethod
|
||||
def _read_env_key_value(
|
||||
prefix: str = "__dbgpt_gunicorn__env_prefix__",
|
||||
) -> List[str]:
|
||||
env_args = []
|
||||
for key, value in os.environ.items():
|
||||
if key.startswith(prefix):
|
||||
arg_key = "--" + key.replace(prefix, "")
|
||||
if value.lower() in ["true", "1"]:
|
||||
# Flag args
|
||||
env_args.append(arg_key)
|
||||
elif value.lower() not in ["false", "0"]:
|
||||
env_args.extend([arg_key, value])
|
||||
return env_args
|
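# --- Usage sketch (illustrative, not part of the original file) ---
# Values are resolved from prefixed environment variables first, then from
# command-line flags; _ExampleParameters is the made-up class sketched above.
if __name__ == "__main__":
    _parsed = EnvArgumentParser().parse_args_into_dataclass(
        _ExampleParameters,
        env_prefixes=["EXAMPLE_"],  # e.g. EXAMPLE_PORT=9000
        command_args=["--host", "0.0.0.0"],
    )
    print(_parsed)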
||||
|
||||
|
||||
def _merge_dataclass_types(
|
||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||
) -> OrderedDict:
|
||||
combined_fields = OrderedDict()
|
||||
if _dynamic_factory:
|
||||
_types = _dynamic_factory()
|
||||
if _types:
|
||||
dataclass_types = list(_types)
|
||||
for dataclass_type in dataclass_types:
|
||||
for field in fields(dataclass_type):
|
||||
if field.name not in combined_fields:
|
||||
combined_fields[field.name] = field
|
||||
return combined_fields
|
||||
|
||||
|
||||
def _type_str_to_python_type(type_str: str) -> Type:
|
||||
type_mapping: Dict[str, Type] = {
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"str": str,
|
||||
}
|
||||
return type_mapping.get(type_str, str)
|
||||
|
||||
|
||||
def _get_parameter_descriptions(
|
||||
dataclass_type: Type, **kwargs
|
||||
) -> List[ParameterDescription]:
|
||||
descriptions = []
|
||||
for field in fields(dataclass_type):
|
||||
ext_metadata = {
|
||||
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
|
||||
}
|
||||
default_value = field.default if field.default != MISSING else None
|
||||
if field.name in kwargs:
|
||||
default_value = kwargs[field.name]
|
||||
descriptions.append(
|
||||
ParameterDescription(
|
||||
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
|
||||
param_name=field.name,
|
||||
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
||||
description=field.metadata.get("help", None),
|
||||
required=field.default is MISSING,
|
||||
default_value=default_value,
|
||||
valid_values=field.metadata.get("valid_values", None),
|
||||
ext_metadata=ext_metadata,
|
||||
)
|
||||
)
|
||||
return descriptions
|
||||
|
||||
|
||||
def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
|
||||
from dbgpt.util.module_utils import import_from_string
|
||||
|
||||
if not desc:
|
||||
raise ValueError("Parameter descriptions cant be empty")
|
||||
param_class_str = desc[0].param_class
|
||||
if param_class_str:
|
||||
param_class = import_from_string(param_class_str, ignore_import_error=True)
|
||||
if param_class:
|
||||
return param_class
|
||||
module_name, _, class_name = param_class_str.rpartition(".")
|
||||
|
||||
fields_dict = {} # This will store field names and their default values or field()
|
||||
annotations = {} # This will store the type annotations for the fields
|
||||
|
||||
for d in desc:
|
||||
metadata = d.ext_metadata if d.ext_metadata else {}
|
||||
metadata["help"] = d.description
|
||||
metadata["valid_values"] = d.valid_values
|
||||
|
||||
annotations[d.param_name] = _type_str_to_python_type(
|
||||
d.param_type
|
||||
) # Set type annotation
|
||||
fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)
|
||||
|
||||
# Create the new class. Note the setting of __annotations__ for type hints
|
||||
new_class = type(
|
||||
class_name, (object,), {**fields_dict, "__annotations__": annotations}
|
||||
)
|
||||
result_class = dataclass(new_class) # Make it a dataclass
|
||||
|
||||
return result_class
|
||||
|
||||
|
||||
def _extract_parameter_details(
|
||||
parser: argparse.ArgumentParser,
|
||||
param_class: str = None,
|
||||
skip_names: List[str] = None,
|
||||
overwrite_default_values: Dict = {},
|
||||
) -> List[ParameterDescription]:
|
||||
descriptions = []
|
||||
|
||||
for action in parser._actions:
|
||||
if (
|
||||
action.default == argparse.SUPPRESS
|
||||
): # typically this means the argument was not provided
|
||||
continue
|
||||
|
||||
# determine parameter class (store_true/store_false are flags)
|
||||
flag_or_option = (
|
||||
"flag" if isinstance(action, argparse._StoreConstAction) else "option"
|
||||
)
|
||||
|
||||
# extract parameter name (use the first option string, typically the long form)
|
||||
param_name = action.option_strings[0] if action.option_strings else action.dest
|
||||
if param_name.startswith("--"):
|
||||
param_name = param_name[2:]
|
||||
if param_name.startswith("-"):
|
||||
param_name = param_name[1:]
|
||||
|
||||
param_name = param_name.replace("-", "_")
|
||||
|
||||
if skip_names and param_name in skip_names:
|
||||
continue
|
||||
|
||||
# gather other details
|
||||
default_value = action.default
|
||||
if param_name in overwrite_default_values:
|
||||
default_value = overwrite_default_values[param_name]
|
||||
arg_type = (
|
||||
action.type if not callable(action.type) else str(action.type.__name__)
|
||||
)
|
||||
description = action.help
|
||||
|
||||
# determine if the argument is required
|
||||
required = action.required
|
||||
|
||||
# extract valid values for choices, if provided
|
||||
valid_values = action.choices if action.choices is not None else None
|
||||
|
||||
# set ext_metadata as an empty dict for now, can be updated later if needed
|
||||
ext_metadata = {}
|
||||
|
||||
descriptions.append(
|
||||
ParameterDescription(
|
||||
param_class=param_class,
|
||||
param_name=param_name,
|
||||
param_type=arg_type,
|
||||
default_value=default_value,
|
||||
description=description,
|
||||
required=required,
|
||||
valid_values=valid_values,
|
||||
ext_metadata=ext_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return descriptions
|
||||
|
||||
|
||||
def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
|
||||
if not obj:
|
||||
return None
|
||||
if is_dataclass(type(obj)):
|
||||
params = {}
|
||||
for field_info in fields(obj):
|
||||
value = _get_simple_privacy_field_value(obj, field_info)
|
||||
params[field_info.name] = value
|
||||
return params
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
return default_value
|
||||
|
||||
|
||||
def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]:
|
||||
from dbgpt._private import pydantic
|
||||
|
||||
version = int(pydantic.VERSION.split(".")[0])
|
||||
schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema()
|
||||
required_fields = set(schema.get("required", []))
|
||||
param_descs = []
|
||||
for field_name, field_schema in schema.get("properties", {}).items():
|
||||
field = model_cls.model_fields[field_name]
|
||||
param_type = field_schema.get("type")
|
||||
if not param_type and "anyOf" in field_schema:
|
||||
for any_of in field_schema["anyOf"]:
|
||||
if any_of["type"] != "null":
|
||||
param_type = any_of["type"]
|
||||
break
|
||||
if version >= 2:
|
||||
default_value = (
|
||||
field.default
|
||||
if hasattr(field, "default")
|
||||
and str(field.default) != "PydanticUndefined"
|
||||
else None
|
||||
)
|
||||
else:
|
||||
default_value = (
|
||||
field.default
|
||||
if not field.allow_none
|
||||
else (
|
||||
field.default_factory() if callable(field.default_factory) else None
|
||||
)
|
||||
)
|
||||
description = field_schema.get("description", "")
|
||||
is_required = field_name in required_fields
|
||||
valid_values = None
|
||||
ext_metadata = None
|
||||
if hasattr(field, "field_info"):
|
||||
valid_values = (
|
||||
list(field.field_info.choices)
|
||||
if hasattr(field.field_info, "choices")
|
||||
else None
|
||||
)
|
||||
ext_metadata = (
|
||||
field.field_info.extra if hasattr(field.field_info, "extra") else None
|
||||
)
|
||||
param_class = (f"{model_cls.__module__}.{model_cls.__name__}",)
|
||||
param_desc = ParameterDescription(
|
||||
param_class=param_class,
|
||||
param_name=field_name,
|
||||
param_type=param_type,
|
||||
default_value=default_value,
|
||||
description=description,
|
||||
required=is_required,
|
||||
valid_values=valid_values,
|
||||
ext_metadata=ext_metadata,
|
||||
)
|
||||
param_descs.append(param_desc)
|
||||
return param_descs
|
||||
|
||||
|
||||
class _SimpleArgParser:
|
||||
def __init__(self, *args):
|
||||
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||
|
||||
def parse(self, args=None):
|
||||
import sys
|
||||
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
else:
|
||||
args = list(args)
|
||||
prev_arg = None
|
||||
for arg in args:
|
||||
if arg.startswith("--"):
|
||||
if prev_arg:
|
||||
self.params[prev_arg] = None
|
||||
prev_arg = arg[2:]
|
||||
else:
|
||||
if prev_arg:
|
||||
self.params[prev_arg] = arg
|
||||
prev_arg = None
|
||||
|
||||
if prev_arg:
|
||||
self.params[prev_arg] = None
|
||||
|
||||
def _get_param(self, key):
|
||||
return self.params.get(key.replace("_", "-")) or self.params.get(key)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return self._get_param(item)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._get_param(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._get_param(key) or default
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join(
|
||||
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
|
||||
)
|
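# --- Usage sketch (illustrative, not part of the original file) ---
# Declare the expected flags first, then parse a raw argv-style list.
if __name__ == "__main__":
    _args = _SimpleArgParser("model_name", "port")
    _args.parse(["--model_name", "vicuna-13b", "--port", "8000"])
    print(_args.model_name, _args.get("port", "5000"))  # vicuna-13b 8000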
||||
|
||||
|
||||
def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
|
||||
import click
|
||||
|
||||
class LazyCommand(click.Command):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LazyCommand, self).__init__(*args, **kwargs)
|
||||
self.dynamic_params_added = False
|
||||
|
||||
def get_params(self, ctx):
|
||||
if ctx and not self.dynamic_params_added:
|
||||
dynamic_params = EnvArgumentParser._create_raw_click_option(
|
||||
*dataclass_types, _dynamic_factory=_dynamic_factory
|
||||
)
|
||||
for param in reversed(dynamic_params):
|
||||
self.params.append(param)
|
||||
self.dynamic_params_added = True
|
||||
return super(LazyCommand, self).get_params(ctx)
|
||||
|
||||
return LazyCommand
|
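# --- Usage sketch (illustrative, not part of the original file) ---
# Attach lazily-built options to a click command; _ExampleParameters is the
# made-up dataclass sketched above and the command name is arbitrary.
if __name__ == "__main__":
    import click

    @click.command(cls=build_lazy_click_command(_ExampleParameters))
    def _start_example(**kwargs):
        print(kwargs)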
6
dbgpt/util/path_utils.py
Normal file
6
dbgpt/util/path_utils.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import os
|
||||
|
||||
|
||||
def has_path(filename):
|
||||
directory = os.path.dirname(filename)
|
||||
return bool(directory)
|
6
dbgpt/util/pd_utils.py
Normal file
6
dbgpt/util/pd_utils.py
Normal file
@@ -0,0 +1,6 @@
|
||||
def csv_colunm_foramt(val):
|
||||
if str(val).find("$") >= 0:
|
||||
return float(val.replace("$", "").replace(",", ""))
|
||||
if str(val).find("¥") >= 0:
|
||||
return float(val.replace("¥", "").replace(",", ""))
|
||||
return val
|
239
dbgpt/util/prompt_util.py
Normal file
239
dbgpt/util/prompt_util.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""General prompt helper that can help deal with LLM context window token limitations.
|
||||
|
||||
At its core, it calculates available context size by starting with the context window
|
||||
size of an LLM and reserving token space for the prompt template and the output.
|
||||
|
||||
It provides utility for "repacking" text chunks (retrieved from index) to maximally
|
||||
make use of the available context window (and thereby reducing the number of LLM calls
|
||||
needed), or truncating them so that they fit in a single LLM call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from string import Formatter
|
||||
from typing import Callable, List, Optional, Sequence
|
||||
|
||||
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
|
||||
|
||||
from dbgpt.util.global_helper import globals_helper
|
||||
from dbgpt._private.llm_metadata import LLMMetadata
|
||||
from dbgpt.rag.embedding_engine.loader.token_splitter import TokenTextSplitter
|
||||
|
||||
DEFAULT_PADDING = 5
|
||||
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW = 3000 # tokens
|
||||
DEFAULT_NUM_OUTPUTS = 256 # tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptHelper(BaseModel):
|
||||
"""Prompt helper.
|
||||
|
||||
General prompt helper that can help deal with LLM context window token limitations.
|
||||
|
||||
At its core, it calculates available context size by starting with the context
|
||||
window size of an LLM and reserving token space for the prompt template and the
|
||||
output.
|
||||
|
||||
It provides utility for "repacking" text chunks (retrieved from index) to maximally
|
||||
make use of the available context window (and thereby reducing the number of LLM
|
||||
calls needed), or truncating them so that they fit in a single LLM call.
|
||||
|
||||
Args:
|
||||
context_window (int): Context window for the LLM.
|
||||
num_output (int): Number of outputs for the LLM.
|
||||
chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size
|
||||
chunk_size_limit (Optional[int]): Maximum chunk size to use.
|
||||
tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
|
||||
separator (str): Separator for text splitter
|
||||
|
||||
"""
|
||||
|
||||
context_window: int = Field(
|
||||
default=DEFAULT_CONTEXT_WINDOW,
|
||||
description="The maximum context size that will get sent to the LLM.",
|
||||
)
|
||||
num_output: int = Field(
|
||||
default=DEFAULT_NUM_OUTPUTS,
|
||||
description="The amount of token-space to leave in input for generation.",
|
||||
)
|
||||
chunk_overlap_ratio: float = Field(
|
||||
default=DEFAULT_CHUNK_OVERLAP_RATIO,
|
||||
description="The percentage token amount that each chunk should overlap.",
|
||||
)
|
||||
chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
|
||||
separator: str = Field(
|
||||
default=" ", description="The separator when chunking tokens."
|
||||
)
|
||||
|
||||
_tokenizer: Callable[[str], List] = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context_window: int = DEFAULT_CONTEXT_WINDOW,
|
||||
num_output: int = DEFAULT_NUM_OUTPUTS,
|
||||
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
|
||||
chunk_size_limit: Optional[int] = None,
|
||||
tokenizer: Optional[Callable[[str], List]] = None,
|
||||
separator: str = " ",
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
|
||||
raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")
|
||||
|
||||
# TODO: make configurable
|
||||
self._tokenizer = tokenizer or globals_helper.tokenizer
|
||||
|
||||
super().__init__(
|
||||
context_window=context_window,
|
||||
num_output=num_output,
|
||||
chunk_overlap_ratio=chunk_overlap_ratio,
|
||||
chunk_size_limit=chunk_size_limit,
|
||||
separator=separator,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm_metadata(
|
||||
cls,
|
||||
llm_metadata: LLMMetadata,
|
||||
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
|
||||
chunk_size_limit: Optional[int] = None,
|
||||
tokenizer: Optional[Callable[[str], List]] = None,
|
||||
separator: str = " ",
|
||||
) -> "PromptHelper":
|
||||
"""Create from llm predictor.
|
||||
|
||||
This will autofill values like context_window and num_output.
|
||||
|
||||
"""
|
||||
context_window = llm_metadata.context_window
|
||||
if llm_metadata.num_output == -1:
|
||||
num_output = DEFAULT_NUM_OUTPUTS
|
||||
else:
|
||||
num_output = llm_metadata.num_output
|
||||
|
||||
return cls(
|
||||
context_window=context_window,
|
||||
num_output=num_output,
|
||||
chunk_overlap_ratio=chunk_overlap_ratio,
|
||||
chunk_size_limit=chunk_size_limit,
|
||||
tokenizer=tokenizer,
|
||||
separator=separator,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "PromptHelper"
|
||||
|
||||
def _get_available_context_size(self, template: str) -> int:
|
||||
"""Get available context size.
|
||||
|
||||
This is calculated as:
|
||||
available context window = total context window
|
||||
- input (partially filled prompt)
|
||||
- output (room reserved for response)
|
||||
|
||||
Notes:
|
||||
- Available context size is further clamped to be non-negative.
|
||||
"""
|
||||
empty_prompt_txt = get_empty_prompt_txt(template)
|
||||
num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
|
||||
context_size_tokens = (
|
||||
self.context_window - num_empty_prompt_tokens - self.num_output
|
||||
)
|
||||
if context_size_tokens < 0:
|
||||
raise ValueError(
|
||||
f"Calculated available context size {context_size_tokens} was"
|
||||
" not non-negative."
|
||||
)
|
||||
return context_size_tokens
|
||||
|
||||
def _get_available_chunk_size(
|
||||
self, prompt_template: str, num_chunks: int = 1, padding: int = 5
|
||||
) -> int:
|
||||
"""Get available chunk size.
|
||||
|
||||
This is calculated as:
|
||||
available chunk size = available context window // number_chunks
|
||||
- padding
|
||||
|
||||
Notes:
|
||||
- By default, we use padding of 5 (to save space for formatting needs).
|
||||
- Available chunk size is further clamped to chunk_size_limit if specified.
|
||||
"""
|
||||
available_context_size = self._get_available_context_size(prompt_template)
|
||||
result = available_context_size // num_chunks - padding
|
||||
if self.chunk_size_limit is not None:
|
||||
result = min(result, self.chunk_size_limit)
|
||||
return result
|
||||
|
||||
def get_text_splitter_given_prompt(
|
||||
self,
|
||||
prompt_template: str,
|
||||
num_chunks: int = 1,
|
||||
padding: int = DEFAULT_PADDING,
|
||||
) -> TokenTextSplitter:
|
||||
"""Get text splitter configured to maximally pack available context window,
|
||||
taking into account the given prompt and the desired number of chunks.
|
||||
"""
|
||||
chunk_size = self._get_available_chunk_size(
|
||||
prompt_template, num_chunks, padding=padding
|
||||
)
|
||||
if chunk_size <= 0:
|
||||
raise ValueError(f"Chunk size {chunk_size} is not positive.")
|
||||
chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
|
||||
return TokenTextSplitter(
|
||||
separator=self.separator,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
tokenizer=self._tokenizer,
|
||||
)
|
||||
|
||||
def repack(
|
||||
self,
|
||||
prompt_template: str,
|
||||
text_chunks: Sequence[str],
|
||||
padding: int = DEFAULT_PADDING,
|
||||
) -> List[str]:
|
||||
"""Repack text chunks to fit available context window.
|
||||
|
||||
This will combine text chunks into consolidated chunks
|
||||
that more fully "pack" the prompt template given the context_window.
|
||||
|
||||
"""
|
||||
text_splitter = self.get_text_splitter_given_prompt(
|
||||
prompt_template, padding=padding
|
||||
)
|
||||
combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
|
||||
return text_splitter.split_text(combined_str)
|
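# --- Usage sketch (illustrative, not part of the original file) ---
# A plain whitespace tokenizer is passed explicitly so the sketch does not
# depend on the global tokenizer; the numbers and template are arbitrary.
if __name__ == "__main__":
    _helper = PromptHelper(
        context_window=256,
        num_output=32,
        tokenizer=lambda text: text.split(),
    )
    _template = "Answer the question using the context below.\n{context}\n{question}"
    _chunks = ["first retrieved passage ...", "second retrieved passage ..."]
    print(_helper.repack(_template, _chunks))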
||||
|
||||
|
||||
def get_empty_prompt_txt(template: str) -> str:
|
||||
"""Get empty prompt text.
|
||||
|
||||
Substitute empty strings in parts of the prompt that have
|
||||
not yet been filled out. Skip variables that have already
|
||||
been partially formatted. This is used to compute the initial tokens.
|
||||
|
||||
"""
|
||||
# partial_kargs = prompt.kwargs
|
||||
|
||||
partial_kargs = {}
|
||||
template_vars = get_template_vars(template)
|
||||
empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
|
||||
all_kwargs = {**partial_kargs, **empty_kwargs}
|
||||
prompt = template.format(**all_kwargs)
|
||||
return prompt
|
||||
|
||||
|
||||
def get_template_vars(template_str: str) -> List[str]:
|
||||
"""Get template variables from a template string."""
|
||||
variables = []
|
||||
formatter = Formatter()
|
||||
|
||||
for _, variable_name, _, _ in formatter.parse(template_str):
|
||||
if variable_name:
|
||||
variables.append(variable_name)
|
||||
|
||||
return variables
|
0
dbgpt/util/serialization/__init__.py
Normal file
0
dbgpt/util/serialization/__init__.py
Normal file
44
dbgpt/util/serialization/json_serialization.py
Normal file
44
dbgpt/util/serialization/json_serialization.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Type
|
||||
import json
|
||||
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
|
||||
JSON_ENCODING = "utf-8"
|
||||
|
||||
|
||||
class JsonSerializable(Serializable, ABC):
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
"""Return the dict of current serializable object"""
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Convert the object into bytes for storage or transmission."""
|
||||
return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
|
||||
|
||||
|
||||
class JsonSerializer(Serializer):
|
||||
"""The serializer abstract class for serializing cache keys and values."""
|
||||
|
||||
def serialize(self, obj: Serializable) -> bytes:
|
||||
"""Serialize a cache object.
|
||||
|
||||
Args:
|
||||
obj (Serializable): The object to serialize
|
||||
"""
|
||||
return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
|
||||
|
||||
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
|
||||
"""Deserialize data back into a cache object of the specified type.
|
||||
|
||||
Args:
|
||||
data (bytes): The byte array to deserialize
|
||||
cls (Type[Serializable]): The type of current object
|
||||
|
||||
Returns:
|
||||
Serializable: The serializable object
|
||||
"""
|
||||
# Convert bytes back to JSON and then to the specified class
|
||||
json_data = json.loads(data.decode(JSON_ENCODING))
|
||||
# Assume that the cls has an __init__ that accepts a dictionary
|
||||
return cls(**json_data)
|
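# --- Usage sketch (illustrative, not part of the original file) ---
# Round-trip a toy object through JSON bytes; this assumes to_dict is the
# only abstract hook a concrete JsonSerializable must implement.
class _ExampleValue(JsonSerializable):
    def __init__(self, name: str = "", score: int = 0):
        self.name = name
        self.score = score

    def to_dict(self) -> Dict:
        return {"name": self.name, "score": self.score}


if __name__ == "__main__":
    _serializer = JsonSerializer()
    _data = _serializer.serialize(_ExampleValue("dbgpt", 1))
    _restored = _serializer.deserialize(_data, _ExampleValue)
    print(_restored.to_dict())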
24
dbgpt/util/singleton.py
Normal file
24
dbgpt/util/singleton.py
Normal file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""The singleton metaclass for ensuring only one instance of a class."""
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Singleton(abc.ABCMeta, type):
|
||||
"""Singleton metaclass for ensuring only one instance of a class"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call method for the singleton metaclass"""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class AbstractSingleton(abc.ABC, metaclass=Singleton):
|
||||
"""Abstract singleton class for ensuring only one instance of a class"""
|
||||
|
||||
pass
|
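# --- Usage sketch (illustrative, not part of the original file) ---
# Classes that use the Singleton metaclass always return the same instance.
if __name__ == "__main__":

    class _ExampleRegistry(metaclass=Singleton):
        def __init__(self):
            self.items = []

    assert _ExampleRegistry() is _ExampleRegistry()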
0
dbgpt/util/speech/__init__.py
Normal file
0
dbgpt/util/speech/__init__.py
Normal file
50
dbgpt/util/speech/base.py
Normal file
50
dbgpt/util/speech/base.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Base class for all voice classes."""
|
||||
import abc
|
||||
from threading import Lock
|
||||
|
||||
from dbgpt.util.singleton import AbstractSingleton
|
||||
|
||||
|
||||
class VoiceBase(AbstractSingleton):
|
||||
"""
|
||||
Base class for all voice classes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the voice class.
|
||||
"""
|
||||
self._url = None
|
||||
self._headers = None
|
||||
self._api_key = None
|
||||
self._voices = []
|
||||
self._mutex = Lock()
|
||||
self._setup()
|
||||
|
||||
def say(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
Say the given text.
|
||||
|
||||
Args:
|
||||
text (str): The text to say.
|
||||
voice_index (int): The index of the voice to use.
|
||||
"""
|
||||
with self._mutex:
|
||||
return self._speech(text, voice_index)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _setup(self) -> None:
|
||||
"""
|
||||
Setup the voices, API key, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
Play the given text.
|
||||
|
||||
Args:
|
||||
text (str): The text to play.
|
||||
"""
|
||||
pass
|
44
dbgpt/util/speech/brian.py
Normal file
44
dbgpt/util/speech/brian.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from dbgpt.util.speech.base import VoiceBase
|
||||
|
||||
|
||||
class BrianSpeech(VoiceBase):
|
||||
"""Brian speech module for autogpt"""
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Setup the voices, API key, etc."""
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
"""Speak text using Brian with the streamelements API
|
||||
|
||||
Args:
|
||||
text (str): The text to speak
|
||||
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from playsound import playsound
|
||||
|
||||
tts_url = (
|
||||
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
|
||||
)
|
||||
response = requests.get(tts_url)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open("speech.mp3", "wb") as f:
|
||||
f.write(response.content)
|
||||
playsound("speech.mp3")
|
||||
os.remove("speech.mp3")
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
"Request failed with status code: %s, response content: %s",
|
||||
response.status_code,
|
||||
response.content,
|
||||
)
|
||||
return False
|
89
dbgpt/util/speech/eleven_labs.py
Normal file
89
dbgpt/util/speech/eleven_labs.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""ElevenLabs speech module"""
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.util.speech.base import VoiceBase
|
||||
|
||||
PLACEHOLDERS = {"your-voice-id"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElevenLabsSpeech(VoiceBase):
|
||||
"""ElevenLabs speech class"""
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up the voices, API key, etc.
|
||||
|
||||
Returns:
|
||||
None: None
|
||||
"""
|
||||
|
||||
cfg = Config()
|
||||
default_voices = ["ErXwobaYiN019PkySvjV", "EXAVITQu4vr4xnSDxMaL"]
|
||||
voice_options = {
|
||||
"Rachel": "21m00Tcm4TlvDq8ikWAM",
|
||||
"Domi": "AZnzlk1XvdvUeBnXmlld",
|
||||
"Bella": "EXAVITQu4vr4xnSDxMaL",
|
||||
"Antoni": "ErXwobaYiN019PkySvjV",
|
||||
"Elli": "MF3mGyEYCl7XYWbV9V6O",
|
||||
"Josh": "TxGEqnHWrfWFTfGW9XjX",
|
||||
"Arnold": "VR6AewLTigWG4xSOukaG",
|
||||
"Adam": "pNInz6obpgDQGcFmaJgB",
|
||||
"Sam": "yoZ06aMxZJJ28mfd3POQ",
|
||||
}
|
||||
self._headers = {
|
||||
"Content-Type": "application/json",
|
||||
"xi-api_v1-key": cfg.elevenlabs_api_key,
|
||||
}
|
||||
self._voices = default_voices.copy()
|
||||
if cfg.elevenlabs_voice_1_id in voice_options:
|
||||
cfg.elevenlabs_voice_1_id = voice_options[cfg.elevenlabs_voice_1_id]
|
||||
if cfg.elevenlabs_voice_2_id in voice_options:
|
||||
cfg.elevenlabs_voice_2_id = voice_options[cfg.elevenlabs_voice_2_id]
|
||||
self._use_custom_voice(cfg.elevenlabs_voice_1_id, 0)
|
||||
self._use_custom_voice(cfg.elevenlabs_voice_2_id, 1)
|
||||
|
||||
def _use_custom_voice(self, voice, voice_index) -> None:
|
||||
"""Use a custom voice if provided and not a placeholder
|
||||
|
||||
Args:
|
||||
voice (str): The voice ID
|
||||
voice_index (int): The voice index
|
||||
|
||||
Returns:
|
||||
None: None
|
||||
"""
|
||||
# Placeholder values that should be treated as empty
|
||||
if voice and voice not in PLACEHOLDERS:
|
||||
self._voices[voice_index] = voice
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""Speak text using elevenlabs.io's API
|
||||
|
||||
Args:
|
||||
text (str): The text to speak
|
||||
voice_index (int, optional): The voice to use. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from playsound import playsound
|
||||
|
||||
tts_url = (
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
|
||||
)
|
||||
response = requests.post(tts_url, headers=self._headers, json={"text": text})
|
||||
|
||||
if response.status_code == 200:
|
||||
with open("speech.mpeg", "wb") as f:
|
||||
f.write(response.content)
|
||||
playsound("speech.mpeg", True)
|
||||
os.remove("speech.mpeg")
|
||||
return True
|
||||
else:
|
||||
logger.warn("Request failed with status code:", response.status_code)
|
||||
logger.info("Response content:", response.content)
|
||||
return False
|
23
dbgpt/util/speech/gtts.py
Normal file
23
dbgpt/util/speech/gtts.py
Normal file
@@ -0,0 +1,23 @@
|
||||
""" GTTS Voice. """
|
||||
import os
|
||||
|
||||
import gtts
|
||||
|
||||
from dbgpt.util.speech.base import VoiceBase
|
||||
|
||||
|
||||
class GTTSVoice(VoiceBase):
|
||||
"""GTTS Voice."""
|
||||
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
from playsound import playsound
|
||||
|
||||
tts = gtts.gTTS(text)
|
||||
tts.save("speech.mp3")
|
||||
playsound("speech.mp3", True)
|
||||
os.remove("speech.mp3")
|
||||
return True
|
21
dbgpt/util/speech/macos_tts.py
Normal file
21
dbgpt/util/speech/macos_tts.py
Normal file
@@ -0,0 +1,21 @@
|
||||
""" MacOS TTS Voice. """
|
||||
import os
|
||||
|
||||
from dbgpt.util.speech.base import VoiceBase
|
||||
|
||||
|
||||
class MacOSTTS(VoiceBase):
|
||||
"""MacOS TTS Voice."""
|
||||
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
if voice_index == 0:
|
||||
os.system(f'say "{text}"')
|
||||
elif voice_index == 1:
|
||||
os.system(f'say -v "Ava (Premium)" "{text}"')
|
||||
else:
|
||||
os.system(f'say -v Samantha "{text}"')
|
||||
return True
|
46
dbgpt/util/speech/say.py
Normal file
46
dbgpt/util/speech/say.py
Normal file
@@ -0,0 +1,46 @@
|
||||
""" Text to speech module """
|
||||
import threading
|
||||
from threading import Semaphore
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.util.speech.base import VoiceBase
|
||||
from dbgpt.util.speech.brian import BrianSpeech
|
||||
from dbgpt.util.speech.eleven_labs import ElevenLabsSpeech
|
||||
from dbgpt.util.speech.gtts import GTTSVoice
|
||||
from dbgpt.util.speech.macos_tts import MacOSTTS
|
||||
|
||||
_QUEUE_SEMAPHORE = Semaphore(
|
||||
1
|
||||
) # The amount of sounds to queue before blocking the main thread
|
||||
|
||||
|
||||
def say_text(text: str, voice_index: int = 0) -> None:
|
||||
"""Speak the given text using the given voice index"""
|
||||
cfg = Config()
|
||||
default_voice_engine, voice_engine = _get_voice_engine(cfg)
|
||||
|
||||
def speak() -> None:
|
||||
success = voice_engine.say(text, voice_index)
|
||||
if not success:
|
||||
default_voice_engine.say(text)
|
||||
|
||||
_QUEUE_SEMAPHORE.release()
|
||||
|
||||
_QUEUE_SEMAPHORE.acquire(True)
|
||||
thread = threading.Thread(target=speak)
|
||||
thread.start()
|
||||
|
||||
|
||||
def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]:
|
||||
"""Get the voice engine to use for the given configuration"""
|
||||
default_voice_engine = GTTSVoice()
|
||||
if config.elevenlabs_api_key:
|
||||
voice_engine = ElevenLabsSpeech()
|
||||
elif config.use_mac_os_tts == "True":
|
||||
voice_engine = MacOSTTS()
|
||||
elif config.use_brian_tts == "True":
|
||||
voice_engine = BrianSpeech()
|
||||
else:
|
||||
voice_engine = GTTSVoice()
|
||||
|
||||
return default_voice_engine, voice_engine
|
81
dbgpt/util/string_utils.py
Normal file
81
dbgpt/util/string_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import re
|
||||
|
||||
|
||||
def is_all_chinese(text):
|
||||
### Determine whether the string is pure Chinese
|
||||
pattern = re.compile(r"^[一-龥]+$")
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def is_number_chinese(text):
|
||||
### Determine whether the string is numbers and Chinese
|
||||
pattern = re.compile(r"^[\d一-龥]+$")
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def is_chinese_include_number(text):
|
||||
### Determine whether the string is pure Chinese or Chinese containing numbers
|
||||
pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def is_scientific_notation(string):
    # Regular expression for scientific notation
    pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
    # Match the string against the pattern and return whether it succeeded
    return re.match(pattern, str(string)) is not None
|
||||
|
||||
|
||||
def extract_content(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text
|
||||
match_map = {}
|
||||
start_index = long_string.find(s1)
|
||||
while start_index != -1:
|
||||
if is_include:
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
extracted_content = long_string[start_index : end_index + len(s2)]
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
extracted_content = long_string[start_index + len(s1) : end_index]
|
||||
if extracted_content:
|
||||
match_map[start_index] = extracted_content
|
||||
start_index = long_string.find(s1, start_index + 1)
|
||||
return match_map
|
||||
|
||||
|
||||
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text open ending
|
||||
match_map = {}
|
||||
start_index = long_string.find(s1)
|
||||
while start_index != -1:
|
||||
if long_string.find(s2, start_index) <= 0:
|
||||
end_index = len(long_string)
|
||||
else:
|
||||
if is_include:
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
if is_include:
|
||||
extracted_content = long_string[start_index : end_index + len(s2)]
|
||||
else:
|
||||
extracted_content = long_string[start_index + len(s1) : end_index]
|
||||
if extracted_content:
|
||||
match_map[start_index] = extracted_content
|
||||
start_index = long_string.find(s1, start_index + 1)
|
||||
return match_map
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
|
||||
s1 = "123"
|
||||
s2 = "456"
|
||||
|
||||
print(extract_content_open_ending(s, s1, s2, True))
|
272
dbgpt/util/system_utils.py
Normal file
272
dbgpt/util/system_utils.py
Normal file
@@ -0,0 +1,272 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from typing import Tuple, Dict
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import re
|
||||
from functools import cache
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemInfo:
|
||||
platform: str
|
||||
distribution: str
|
||||
python_version: str
|
||||
cpu: str
|
||||
cpu_avx: str
|
||||
memory: str
|
||||
torch_version: str
|
||||
device: str
|
||||
device_version: str
|
||||
device_count: int
|
||||
device_other: str
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
class AVXType(Enum):
|
||||
BASIC = "basic"
|
||||
AVX = "AVX"
|
||||
AVX2 = "AVX2"
|
||||
AVX512 = "AVX512"
|
||||
|
||||
@staticmethod
|
||||
def of_type(avx: str):
|
||||
for item in AVXType:
|
||||
if item._value_ == avx:
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
class OSType(str, Enum):
|
||||
WINDOWS = "win"
|
||||
LINUX = "linux"
|
||||
DARWIN = "darwin"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
def get_cpu_avx_support() -> Tuple[OSType, AVXType, str]:
|
||||
system = platform.system()
|
||||
os_type = OSType.OTHER
|
||||
cpu_avx = AVXType.BASIC
|
||||
env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX"))
|
||||
distribution = "Unknown Distribution"
|
||||
if "windows" in system.lower():
|
||||
os_type = OSType.WINDOWS
|
||||
output = "avx2"
|
||||
distribution = "Windows " + platform.release()
|
||||
print("Current platform is windows, use avx2 as default cpu architecture")
|
||||
elif system == "Linux":
|
||||
os_type = OSType.LINUX
|
||||
result = subprocess.run(
|
||||
["lscpu"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
output = result.stdout.decode()
|
||||
distribution = get_linux_distribution()
|
||||
elif system == "Darwin":
|
||||
os_type = OSType.DARWIN
|
||||
result = subprocess.run(
|
||||
["sysctl", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
distribution = "Mac OS " + platform.mac_ver()[0]
|
||||
output = result.stdout.decode()
|
||||
else:
|
||||
os_type = OSType.OTHER
|
||||
print("Unsupported OS to get cpu avx, use default")
|
||||
return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
|
||||
|
||||
if "avx512" in output.lower():
|
||||
cpu_avx = AVXType.AVX512
|
||||
elif "avx2" in output.lower():
|
||||
cpu_avx = AVXType.AVX2
|
||||
elif "avx " in output.lower():
|
||||
# cpu_avx = AVXType.AVX
|
||||
pass
|
||||
return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
|
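# Note (not part of the original file): the detected AVX type can be overridden
# through the DBGPT_LLAMA_CPP_AVX environment variable, e.g.
#
#     DBGPT_LLAMA_CPP_AVX=AVX2 python your_entrypoint.py
#
# AVXType.of_type matches the enum values "basic", "AVX", "AVX2" and "AVX512";
# any other value returns None and the detected type is used instead.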
||||
|
||||
|
||||
def get_device() -> str:
|
||||
try:
|
||||
import torch
|
||||
|
||||
return (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_device_info() -> Tuple[str, str, str, int, str]:
|
||||
torch_version, device, device_version, device_count, device_other = (
|
||||
None,
|
||||
"cpu",
|
||||
None,
|
||||
0,
|
||||
"",
|
||||
)
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch_version = torch.__version__
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
device_version = torch.version.cuda
|
||||
device_count = torch.cuda.device_count()
|
||||
elif torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
if not device_version:
|
||||
device_version = (
|
||||
get_cuda_version_from_nvcc() or get_cuda_version_from_nvidia_smi()
|
||||
)
|
||||
if device == "cuda":
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-gpu=name,driver_version,memory.total,memory.free,memory.used",
|
||||
"--format=csv",
|
||||
]
|
||||
)
|
||||
device_other = output.decode("utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
return torch_version, device, device_version, device_count, device_other
|
||||
|
||||
|
||||
def get_cuda_version_from_nvcc():
|
||||
try:
|
||||
output = subprocess.check_output(["nvcc", "--version"])
|
||||
version_line = [
|
||||
line for line in output.decode("utf-8").split("\n") if "release" in line
|
||||
][0]
|
||||
return version_line.split("release")[-1].strip().split(",")[0]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_cuda_version_from_nvidia_smi():
|
||||
try:
|
||||
output = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
|
||||
match = re.search(r"CUDA Version:\s+(\d+\.\d+)", output)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_linux_distribution():
|
||||
"""Get distribution of Linux"""
|
||||
if os.path.isfile("/etc/os-release"):
|
||||
with open("/etc/os-release", "r") as f:
|
||||
info = {}
|
||||
for line in f:
|
||||
key, _, value = line.partition("=")
|
||||
info[key] = value.strip().strip('"')
|
||||
return f"{info.get('NAME', 'Unknown')} {info.get('VERSION_ID', '')}".strip()
|
||||
return "Unknown Linux Distribution"
|
||||
|
||||
|
||||
def get_cpu_info():
|
||||
# Getting platform
|
||||
os_type, avx_type, distribution = get_cpu_avx_support()
|
||||
|
||||
# Getting CPU information
|
||||
cpu_info = "Unknown CPU"
|
||||
if os_type == OSType.LINUX:
|
||||
try:
|
||||
output = subprocess.check_output(["lscpu"]).decode("utf-8")
|
||||
match = re.search(r".*Model name:\s*(.+)", output)
|
||||
if match:
|
||||
cpu_info = match.group(1).strip()
|
||||
match = re.search(r".*型号名称:\s*(.+)", output)
|
||||
if match:
|
||||
cpu_info = match.group(1).strip()
|
||||
except Exception:
|
||||
pass
|
||||
elif os_type == OSType.DARWIN:
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["sysctl", "machdep.cpu.brand_string"]
|
||||
).decode("utf-8")
|
||||
match = re.search(r"machdep.cpu.brand_string:\s*(.+)", output)
|
||||
if match:
|
||||
cpu_info = match.group(1).strip()
|
||||
except Exception:
|
||||
pass
|
||||
elif os_type == OSType.WINDOWS:
|
||||
try:
|
||||
output = subprocess.check_output("wmic cpu get Name", shell=True).decode(
|
||||
"utf-8"
|
||||
)
|
||||
lines = output.splitlines()
|
||||
cpu_info = lines[2].split(":")[-1].strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return os_type, avx_type, cpu_info, distribution
|
||||
|
||||
|
||||
def get_memory_info(os_type: OSType) -> str:
|
||||
memory = "Unknown Memory"
|
||||
try:
|
||||
import psutil
|
||||
|
||||
memory = f"{psutil.virtual_memory().total // (1024 ** 3)} GB"
|
||||
except ImportError:
|
||||
pass
|
||||
if os_type == OSType.LINUX:
|
||||
try:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
mem_info = f.readlines()
|
||||
for line in mem_info:
|
||||
if "MemTotal" in line:
|
||||
memory = line.split(":")[1].strip()
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
return memory
|
||||
|
||||
|
||||
@cache
|
||||
def get_system_info() -> SystemInfo:
|
||||
"""Get System information"""
|
||||
|
||||
os_type, avx_type, cpu_info, distribution = get_cpu_info()
|
||||
|
||||
# Getting Python version
|
||||
python_version = platform.python_version()
|
||||
|
||||
memory = get_memory_info(os_type)
|
||||
|
||||
(
|
||||
torch_version,
|
||||
device,
|
||||
device_version,
|
||||
device_count,
|
||||
device_other,
|
||||
) = get_device_info()
|
||||
|
||||
return SystemInfo(
|
||||
platform=os_type._value_,
|
||||
distribution=distribution,
|
||||
python_version=python_version,
|
||||
cpu=cpu_info,
|
||||
cpu_avx=avx_type._value_,
|
||||
memory=memory,
|
||||
torch_version=torch_version,
|
||||
device=device,
|
||||
device_version=device_version,
|
||||
device_count=device_count,
|
||||
device_other=device_other,
|
||||
)
|
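A minimal usage sketch for the module above (assuming the package is importable as dbgpt; the snippet is not part of the commit):

>>> from dbgpt.util.system_utils import get_system_info
>>> info = get_system_info()  # cached after the first call via functools.cache
>>> info.to_dict()["platform"]  # e.g. "linux", "darwin" or "win"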
0
dbgpt/util/tests/__init__.py
Normal file
81
dbgpt/util/tests/test_parameter_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import argparse
|
||||
import pytest
|
||||
from dbgpt.util.parameter_utils import _extract_parameter_details
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
return parser
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"argument, expected_param_name, default_value, param_type, expected_param_type, description",
|
||||
[
|
||||
("--option", "option", "value", str, "str", "An option argument"),
|
||||
("-option", "option", "value", str, "str", "An option argument"),
|
||||
("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
|
||||
("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
|
||||
],
|
||||
)
|
||||
def test_extract_parameter_details_option_argument(
|
||||
argument,
|
||||
expected_param_name,
|
||||
default_value,
|
||||
param_type,
|
||||
expected_param_type,
|
||||
description,
|
||||
):
|
||||
parser = create_parser()
|
||||
parser.add_argument(
|
||||
argument, default=default_value, type=param_type, help=description
|
||||
)
|
||||
descriptions = _extract_parameter_details(parser)
|
||||
|
||||
assert len(descriptions) == 1
|
||||
desc = descriptions[0]
|
||||
|
||||
assert desc.param_name == expected_param_name
|
||||
assert desc.param_type == expected_param_type
|
||||
assert desc.default_value == default_value
|
||||
assert desc.description == description
|
||||
assert desc.required == False
|
||||
assert desc.valid_values is None
|
||||
|
||||
|
||||
def test_extract_parameter_details_flag_argument():
|
||||
parser = create_parser()
|
||||
parser.add_argument("--flag", action="store_true", help="A flag argument")
|
||||
descriptions = _extract_parameter_details(parser)
|
||||
|
||||
assert len(descriptions) == 1
|
||||
desc = descriptions[0]
|
||||
|
||||
assert desc.param_name == "flag"
|
||||
assert desc.description == "A flag argument"
|
||||
assert desc.required == False
|
||||
|
||||
|
||||
def test_extract_parameter_details_choice_argument():
|
||||
parser = create_parser()
|
||||
parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument")
|
||||
descriptions = _extract_parameter_details(parser)
|
||||
|
||||
assert len(descriptions) == 1
|
||||
desc = descriptions[0]
|
||||
|
||||
assert desc.param_name == "choice"
|
||||
assert desc.valid_values == ["A", "B", "C"]
|
||||
|
||||
|
||||
def test_extract_parameter_details_required_argument():
|
||||
parser = create_parser()
|
||||
parser.add_argument(
|
||||
"--required", required=True, type=int, help="A required argument"
|
||||
)
|
||||
descriptions = _extract_parameter_details(parser)
|
||||
|
||||
assert len(descriptions) == 1
|
||||
desc = descriptions[0]
|
||||
|
||||
assert desc.param_name == "required"
|
||||
assert desc.required == True
|
39
dbgpt/util/tracer/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from dbgpt.util.tracer.base import (
|
||||
SpanType,
|
||||
Span,
|
||||
SpanTypeRunName,
|
||||
Tracer,
|
||||
SpanStorage,
|
||||
SpanStorageType,
|
||||
TracerContext,
|
||||
)
|
||||
from dbgpt.util.tracer.span_storage import (
|
||||
MemorySpanStorage,
|
||||
FileSpanStorage,
|
||||
SpanStorageContainer,
|
||||
)
|
||||
from dbgpt.util.tracer.tracer_impl import (
|
||||
root_tracer,
|
||||
trace,
|
||||
initialize_tracer,
|
||||
DefaultTracer,
|
||||
TracerManager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SpanType",
|
||||
"Span",
|
||||
"SpanTypeRunName",
|
||||
"Tracer",
|
||||
"SpanStorage",
|
||||
"SpanStorageType",
|
||||
"TracerContext",
|
||||
"MemorySpanStorage",
|
||||
"FileSpanStorage",
|
||||
"SpanStorageContainer",
|
||||
"root_tracer",
|
||||
"trace",
|
||||
"initialize_tracer",
|
||||
"DefaultTracer",
|
||||
"TracerManager",
|
||||
]
|
189
dbgpt/util/tracer/base.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Callable, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp, ComponentType
|
||||
|
||||
|
||||
class SpanType(str, Enum):
|
||||
BASE = "base"
|
||||
RUN = "run"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class SpanTypeRunName(str, Enum):
|
||||
WEBSERVER = "Webserver"
|
||||
WORKER_MANAGER = "WorkerManager"
|
||||
MODEL_WORKER = "ModelWorker"
|
||||
EMBEDDING_MODEL = "EmbeddingModel"
|
||||
|
||||
@staticmethod
|
||||
def values():
|
||||
return [item.value for item in SpanTypeRunName]
|
||||
|
||||
|
||||
class Span:
|
||||
"""Represents a unit of work that is being traced.
|
||||
This can be any operation like a function call or a database query.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_id: str,
|
||||
span_id: str,
|
||||
span_type: SpanType = None,
|
||||
parent_span_id: str = None,
|
||||
operation_name: str = None,
|
||||
metadata: Dict = None,
|
||||
end_caller: Callable[[Span], None] = None,
|
||||
):
|
||||
if not span_type:
|
||||
span_type = SpanType.BASE
|
||||
self.span_type = span_type
|
||||
# The unique identifier for the entire trace
|
||||
self.trace_id = trace_id
|
||||
# Unique identifier for this span within the trace
|
||||
self.span_id = span_id
|
||||
# Identifier of the parent span, if this is a child span
|
||||
self.parent_span_id = parent_span_id
|
||||
# Descriptive name for the operation being traced
|
||||
self.operation_name = operation_name
|
||||
# Timestamp when this span started
|
||||
self.start_time = datetime.now()
|
||||
# Timestamp when this span ended, initially None
|
||||
self.end_time = None
|
||||
# Additional metadata associated with the span
|
||||
self.metadata = metadata
|
||||
self._end_callers = []
|
||||
if end_caller:
|
||||
self._end_callers.append(end_caller)
|
||||
|
||||
def end(self, **kwargs):
|
||||
"""Mark the end of this span by recording the current time."""
|
||||
self.end_time = datetime.now()
|
||||
if "metadata" in kwargs:
|
||||
self.metadata = kwargs.get("metadata")
|
||||
for caller in self._end_callers:
|
||||
caller(self)
|
||||
|
||||
def add_end_caller(self, end_caller: Callable[[Span], None]):
|
||||
if end_caller:
|
||||
self._end_callers.append(end_caller)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.end()
|
||||
return False
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"span_type": self.span_type.value,
|
||||
"trace_id": self.trace_id,
|
||||
"span_id": self.span_id,
|
||||
"parent_span_id": self.parent_span_id,
|
||||
"operation_name": self.operation_name,
|
||||
"start_time": None
|
||||
if not self.start_time
|
||||
else self.start_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
|
||||
"end_time": None
|
||||
if not self.end_time
|
||||
else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
class SpanStorageType(str, Enum):
|
||||
ON_CREATE = "on_create"
|
||||
ON_END = "on_end"
|
||||
ON_CREATE_END = "on_create_end"
|
||||
|
||||
|
||||
class SpanStorage(BaseComponent, ABC):
|
||||
"""Abstract base class for storing spans.
|
||||
|
||||
This allows different storage mechanisms (e.g., in-memory, database) to be implemented.
|
||||
"""
|
||||
|
||||
name = ComponentType.TRACER_SPAN_STORAGE.value
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the storage with the given application context."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def append_span(self, span: Span):
|
||||
"""Store the given span. This needs to be implemented by subclasses."""
|
||||
|
||||
def append_span_batch(self, spans: List[Span]):
|
||||
"""Store the span batch"""
|
||||
for span in spans:
|
||||
self.append_span(span)
|
||||
|
||||
|
||||
class Tracer(BaseComponent, ABC):
|
||||
"""Abstract base class for tracing operations.
|
||||
Provides the core logic for starting, ending, and retrieving spans.
|
||||
"""
|
||||
|
||||
name = ComponentType.TRACER.value
|
||||
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
super().__init__(system_app)
|
||||
self.system_app = system_app # Application context
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the tracer with the given application context."""
|
||||
self.system_app = system_app
|
||||
|
||||
@abstractmethod
|
||||
def append_span(self, span: Span):
|
||||
"""Append the given span to storage. This needs to be implemented by subclasses."""
|
||||
|
||||
@abstractmethod
|
||||
def start_span(
|
||||
self,
|
||||
operation_name: str,
|
||||
parent_span_id: str = None,
|
||||
span_type: SpanType = None,
|
||||
metadata: Dict = None,
|
||||
) -> Span:
|
||||
"""Begin a new span for the given operation. If provided, the span will be
|
||||
a child of the span with the given parent_span_id.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def end_span(self, span: Span, **kwargs):
|
||||
"""
|
||||
End the given span.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_current_span(self) -> Optional[Span]:
|
||||
"""
|
||||
Retrieve the span that is currently being traced.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_current_storage(self) -> SpanStorage:
|
||||
"""
|
||||
Get the storage mechanism currently in use for storing spans.
|
||||
This needs to be implemented by subclasses.
|
||||
"""
|
||||
|
||||
def _new_uuid(self) -> str:
|
||||
"""
|
||||
Generate a new unique identifier.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TracerContext:
|
||||
span_id: Optional[str] = None
|
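To illustrate the abstractions above, here is a minimal sketch (not part of the commit) that plugs a list-backed storage into a manually created span; it assumes Span, SpanType and SpanStorage are imported from dbgpt.util.tracer:

>>> class ListSpanStorage(SpanStorage):
...     def __init__(self):
...         self.spans = []
...     def append_span(self, span: Span):
...         self.spans.append(span)
>>> storage = ListSpanStorage()
>>> with Span("trace-1", "trace-1:span-1", SpanType.BASE, end_caller=storage.append_span):
...     pass  # the span ends when the with-block exits, which triggers the end caller
>>> len(storage.spans)
1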
150
dbgpt/util/tracer/span_storage.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import threading
|
||||
import queue
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util.tracer.base import Span, SpanStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemorySpanStorage(SpanStorage):
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
super().__init__(system_app)
|
||||
self.spans = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def append_span(self, span: Span):
|
||||
with self._lock:
|
||||
self.spans.append(span)
|
||||
|
||||
|
||||
class SpanStorageContainer(SpanStorage):
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp | None = None,
|
||||
batch_size=10,
|
||||
flush_interval=10,
|
||||
executor: Executor = None,
|
||||
):
|
||||
super().__init__(system_app)
|
||||
if not executor:
|
||||
executor = ThreadPoolExecutor(thread_name_prefix="trace_storage_sync_")
|
||||
self.executor = executor
|
||||
self.storages: List[SpanStorage] = []
|
||||
self.last_date = (
|
||||
datetime.datetime.now().date()
|
||||
) # Store the current date for checking date changes
|
||||
self.queue = queue.Queue()
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval
|
||||
self.last_flush_time = time.time()
|
||||
self.flush_signal_queue = queue.Queue()
|
||||
self.flush_thread = threading.Thread(
|
||||
target=self._flush_to_storages, daemon=True
|
||||
)
|
||||
self.flush_thread.start()
|
||||
|
||||
def append_storage(self, storage: SpanStorage):
|
||||
"""Append sotrage to container
|
||||
|
||||
Args:
|
||||
storage ([`SpanStorage`]): The storage to be appended to the current container
|
||||
"""
|
||||
self.storages.append(storage)
|
||||
|
||||
def append_span(self, span: Span):
|
||||
self.queue.put(span)
|
||||
if self.queue.qsize() >= self.batch_size:
|
||||
try:
|
||||
self.flush_signal_queue.put_nowait(True)
|
||||
except queue.Full:
|
||||
pass # If the signal queue is full, it's okay. The flush thread will handle it.
|
||||
|
||||
def _flush_to_storages(self):
|
||||
while True:
|
||||
interval = time.time() - self.last_flush_time
|
||||
if interval < self.flush_interval:
|
||||
try:
|
||||
self.flush_signal_queue.get(
|
||||
block=True, timeout=self.flush_interval - interval
|
||||
)
|
||||
except Exception:
|
||||
# Timeout
|
||||
pass
|
||||
|
||||
spans_to_write = []
|
||||
while not self.queue.empty():
|
||||
spans_to_write.append(self.queue.get())
|
||||
for s in self.storages:
|
||||
|
||||
def append_and_ignore_error(
|
||||
storage: SpanStorage, spans_to_write: List[Span]
|
||||
):
|
||||
try:
|
||||
storage.append_span_batch(spans_to_write)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Append spans to storage {str(storage)} failed: {str(e)}, span_data: {spans_to_write}"
|
||||
)
|
||||
|
||||
self.executor.submit(append_and_ignore_error, s, spans_to_write)
|
||||
self.last_flush_time = time.time()
|
||||
|
||||
|
||||
class FileSpanStorage(SpanStorage):
|
||||
def __init__(self, filename: str):
|
||||
super().__init__()
|
||||
self.filename = filename
|
||||
# Split filename into prefix and suffix
|
||||
self.filename_prefix, self.filename_suffix = os.path.splitext(filename)
|
||||
if not self.filename_suffix:
|
||||
self.filename_suffix = ".log"
|
||||
self.last_date = (
|
||||
datetime.datetime.now().date()
|
||||
) # Store the current date for checking date changes
|
||||
self.queue = queue.Queue()
|
||||
|
||||
if not os.path.exists(filename):
|
||||
# Create the file (and its parent directory) if it does not exist
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
with open(filename, "a"):
|
||||
pass
|
||||
|
||||
def append_span(self, span: Span):
|
||||
self._write_to_file([span])
|
||||
|
||||
def append_span_batch(self, spans: List[Span]):
|
||||
self._write_to_file(spans)
|
||||
|
||||
def _get_dated_filename(self, date: datetime.date) -> str:
|
||||
"""Return the filename based on a specific date."""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
return f"{self.filename_prefix}_{date_str}{self.filename_suffix}"
|
||||
|
||||
def _roll_over_if_needed(self):
|
||||
"""Checks if a day has changed since the last write, and if so, renames the current file."""
|
||||
current_date = datetime.datetime.now().date()
|
||||
if current_date != self.last_date:
|
||||
if os.path.exists(self.filename):
|
||||
os.rename(self.filename, self._get_dated_filename(self.last_date))
|
||||
self.last_date = current_date
|
||||
|
||||
def _write_to_file(self, spans: List[Span]):
|
||||
self._roll_over_if_needed()
|
||||
|
||||
with open(self.filename, "a") as file:
|
||||
for span in spans:
|
||||
span_data = span.to_dict()
|
||||
try:
|
||||
file.write(json.dumps(span_data, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Write span to file failed: {str(e)}, span_data: {span_data}"
|
||||
)
|
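Putting the pieces above together, a sketch of the intended wiring (illustrative only; the /tmp path and the one-second flush interval are assumptions, not part of the commit): the container buffers spans in a queue and a background thread flushes them to every registered backend.

>>> from dbgpt.util.tracer import Span, SpanType, FileSpanStorage, SpanStorageContainer
>>> container = SpanStorageContainer(batch_size=2, flush_interval=1)
>>> container.append_storage(FileSpanStorage("/tmp/dbgpt_spans.jsonl"))
>>> container.append_span(Span("t1", "t1:s1", SpanType.BASE, operation_name="op1"))
>>> container.append_span(Span("t1", "t1:s2", SpanType.BASE, operation_name="op2"))
>>> # once the batch size is reached, both spans are written to the file as JSON lines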
0
dbgpt/util/tracer/tests/__init__.py
Normal file
131
dbgpt/util/tracer/tests/test_base.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Dict
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from dbgpt.util.tracer import Span, SpanType, SpanStorage, Tracer
|
||||
|
||||
|
||||
# Mock implementations
|
||||
|
||||
|
||||
class MockSpanStorage(SpanStorage):
|
||||
def __init__(self):
|
||||
self.spans = []
|
||||
|
||||
def append_span(self, span: Span):
|
||||
self.spans.append(span)
|
||||
|
||||
|
||||
class MockTracer(Tracer):
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
super().__init__(system_app)
|
||||
self.current_span = None
|
||||
self.storage = MockSpanStorage()
|
||||
|
||||
def append_span(self, span: Span):
|
||||
self.storage.append_span(span)
|
||||
|
||||
def start_span(
|
||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
||||
) -> Span:
|
||||
trace_id = (
|
||||
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
|
||||
)
|
||||
span_id = f"{trace_id}:{self._new_uuid()}"
|
||||
span = Span(
|
||||
trace_id, span_id, SpanType.BASE, parent_span_id, operation_name, metadata
|
||||
)
|
||||
self.current_span = span
|
||||
return span
|
||||
|
||||
def end_span(self, span: Span):
|
||||
span.end()
|
||||
self.append_span(span)
|
||||
|
||||
def get_current_span(self) -> Span:
|
||||
return self.current_span
|
||||
|
||||
def _get_current_storage(self) -> SpanStorage:
|
||||
return self.storage
|
||||
|
||||
|
||||
# Tests
|
||||
|
||||
|
||||
def test_span_creation():
|
||||
span = Span(
|
||||
"trace_id",
|
||||
"span_id",
|
||||
SpanType.BASE,
|
||||
"parent_span_id",
|
||||
"operation",
|
||||
{"key": "value"},
|
||||
)
|
||||
assert span.trace_id == "trace_id"
|
||||
assert span.span_id == "span_id"
|
||||
assert span.parent_span_id == "parent_span_id"
|
||||
assert span.operation_name == "operation"
|
||||
assert span.metadata == {"key": "value"}
|
||||
|
||||
|
||||
def test_span_end():
|
||||
span = Span("trace_id", "span_id")
|
||||
assert span.end_time is None
|
||||
span.end()
|
||||
assert span.end_time is not None
|
||||
|
||||
|
||||
def test_mock_tracer_start_span():
|
||||
tracer = MockTracer()
|
||||
span = tracer.start_span("operation")
|
||||
assert span.operation_name == "operation"
|
||||
assert tracer.get_current_span() == span
|
||||
|
||||
|
||||
def test_mock_tracer_end_span():
|
||||
tracer = MockTracer()
|
||||
span = tracer.start_span("operation")
|
||||
tracer.end_span(span)
|
||||
assert span in tracer._get_current_storage().spans
|
||||
|
||||
|
||||
def test_mock_tracer_append_span():
|
||||
tracer = MockTracer()
|
||||
span = Span("trace_id", "span_id")
|
||||
tracer.append_span(span)
|
||||
assert span in tracer._get_current_storage().spans
|
||||
|
||||
|
||||
def test_parent_child_span_relation():
|
||||
tracer = MockTracer()
|
||||
|
||||
# Start a parent span
|
||||
parent_span = tracer.start_span("parent_operation")
|
||||
|
||||
# Start a child span with parent span's ID
|
||||
child_span = tracer.start_span(
|
||||
"child_operation", parent_span_id=parent_span.span_id
|
||||
)
|
||||
|
||||
# Assert the relationships
|
||||
assert child_span.parent_span_id == parent_span.span_id
|
||||
assert (
|
||||
child_span.trace_id == parent_span.trace_id
|
||||
) # Assuming children share the same trace ID
|
||||
|
||||
# End spans
|
||||
tracer.end_span(child_span)
|
||||
tracer.end_span(parent_span)
|
||||
|
||||
# Assert they are in the storage
|
||||
assert child_span in tracer._get_current_storage().spans
|
||||
assert parent_span in tracer._get_current_storage().spans
|
||||
|
||||
|
||||
# This test checks if unique UUIDs are being generated.
|
||||
# Note: This is a simple test and doesn't guarantee uniqueness for large numbers of UUIDs.
|
||||
|
||||
|
||||
def test_new_uuid_unique():
|
||||
tracer = MockTracer()
|
||||
uuid_set = {tracer._new_uuid() for _ in range(1000)}
|
||||
assert len(uuid_set) == 1000
|
174
dbgpt/util/tracer/tests/test_span_storage.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import os
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from dbgpt.util.tracer import (
|
||||
SpanStorage,
|
||||
FileSpanStorage,
|
||||
Span,
|
||||
SpanType,
|
||||
SpanStorageContainer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage(request):
|
||||
if not request or not hasattr(request, "param"):
|
||||
file_does_not_exist = False
|
||||
else:
|
||||
file_does_not_exist = request.param.get("file_does_not_exist", False)
|
||||
|
||||
if file_does_not_exist:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
filename = os.path.join(tmp_dir, "non_existent_file.jsonl")
|
||||
storage_instance = FileSpanStorage(filename)
|
||||
yield storage_instance
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
|
||||
filename = tmp_file.name
|
||||
storage_instance = FileSpanStorage(filename)
|
||||
yield storage_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage_container(request):
|
||||
if not request or not hasattr(request, "param"):
|
||||
batch_size = 10
|
||||
flush_interval = 10
|
||||
else:
|
||||
batch_size = request.param.get("batch_size", 10)
|
||||
flush_interval = request.param.get("flush_interval", 10)
|
||||
storage_container = SpanStorageContainer(
|
||||
batch_size=batch_size, flush_interval=flush_interval
|
||||
)
|
||||
yield storage_container
|
||||
|
||||
|
||||
def read_spans_from_file(filename):
|
||||
with open(filename, "r") as f:
|
||||
return [json.loads(line) for line in f.readlines()]
|
||||
|
||||
|
||||
def test_write_span(storage: SpanStorage):
|
||||
span = Span("1", "a", SpanType.BASE, "b", "op1")
|
||||
storage.append_span(span)
|
||||
time.sleep(0.1)
|
||||
|
||||
spans_in_file = read_spans_from_file(storage.filename)
|
||||
assert len(spans_in_file) == 1
|
||||
assert spans_in_file[0]["trace_id"] == "1"
|
||||
|
||||
|
||||
def test_incremental_write(storage: SpanStorage):
|
||||
span1 = Span("1", "a", SpanType.BASE, "b", "op1")
|
||||
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
|
||||
|
||||
storage.append_span(span1)
|
||||
storage.append_span(span2)
|
||||
time.sleep(0.1)
|
||||
|
||||
spans_in_file = read_spans_from_file(storage.filename)
|
||||
assert len(spans_in_file) == 2
|
||||
|
||||
|
||||
def test_sync_and_async_append(storage: SpanStorage):
|
||||
span = Span("1", "a", SpanType.BASE, "b", "op1")
|
||||
|
||||
storage.append_span(span)
|
||||
|
||||
async def async_append():
|
||||
storage.append_span(span)
|
||||
|
||||
asyncio.run(async_append())
|
||||
|
||||
time.sleep(0.1)
|
||||
spans_in_file = read_spans_from_file(storage.filename)
|
||||
assert len(spans_in_file) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage", [{"file_does_not_exist": True}], indirect=True)
|
||||
def test_non_existent_file(storage: SpanStorage):
|
||||
span = Span("1", "a", SpanType.BASE, "b", "op1")
|
||||
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
|
||||
storage.append_span(span)
|
||||
time.sleep(0.1)
|
||||
|
||||
spans_in_file = read_spans_from_file(storage.filename)
|
||||
assert len(spans_in_file) == 1
|
||||
|
||||
storage.append_span(span2)
|
||||
time.sleep(0.1)
|
||||
spans_in_file = read_spans_from_file(storage.filename)
|
||||
assert len(spans_in_file) == 2
|
||||
assert spans_in_file[0]["trace_id"] == "1"
|
||||
assert spans_in_file[1]["trace_id"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage", [{"file_does_not_exist": True}], indirect=True)
|
||||
def test_log_rollover(storage: SpanStorage):
|
||||
# mock start date
|
||||
mock_start_date = datetime(2023, 10, 18, 23, 59)
|
||||
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value = mock_start_date
|
||||
|
||||
span1 = Span("1", "a", SpanType.BASE, "b", "op1")
|
||||
storage.append_span(span1)
|
||||
time.sleep(0.1)
|
||||
|
||||
# mock new day
|
||||
mock_datetime.now.return_value = mock_start_date + timedelta(minutes=1)
|
||||
|
||||
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
|
||||
storage.append_span(span2)
|
||||
time.sleep(0.1)
|
||||
|
||||
# The original filename should still exist
|
||||
assert os.path.exists(storage.filename)
|
||||
|
||||
# Build the rolled-over filename
|
||||
dated_filename = os.path.join(
|
||||
os.path.dirname(storage.filename),
|
||||
f"{os.path.basename(storage.filename).split('.')[0]}_2023-10-18.jsonl",
|
||||
)
|
||||
|
||||
assert os.path.exists(dated_filename)
|
||||
|
||||
# Check that the original file contains only the second span
|
||||
spans_in_original_file = read_spans_from_file(storage.filename)
|
||||
assert len(spans_in_original_file) == 1
|
||||
assert spans_in_original_file[0]["trace_id"] == "2"
|
||||
|
||||
# Check that the rolled-over file contains only the first span
|
||||
spans_in_dated_file = read_spans_from_file(dated_filename)
|
||||
assert len(spans_in_dated_file) == 1
|
||||
assert spans_in_dated_file[0]["trace_id"] == "1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("storage_container", [{"batch_size": 5}], indirect=True)
|
||||
async def test_container_flush_policy(
|
||||
storage_container: SpanStorageContainer, storage: FileSpanStorage
|
||||
):
|
||||
storage_container.append_storage(storage)
|
||||
span = Span("1", "a", SpanType.BASE, "b", "op1")
|
||||
|
||||
filename = storage.filename
|
||||
|
||||
for _ in range(storage_container.batch_size - 1):
|
||||
storage_container.append_span(span)
|
||||
|
||||
spans_in_file = read_spans_from_file(filename)
|
||||
assert len(spans_in_file) == 0
|
||||
|
||||
# Trigger batch write
|
||||
storage_container.append_span(span)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
spans_in_file = read_spans_from_file(filename)
|
||||
assert len(spans_in_file) == storage_container.batch_size
|
103
dbgpt/util/tracer/tests/test_tracer_impl.py
Normal file
103
dbgpt/util/tracer/tests/test_tracer_impl.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import pytest
|
||||
from dbgpt.util.tracer import (
|
||||
Span,
|
||||
SpanStorageType,
|
||||
SpanStorage,
|
||||
DefaultTracer,
|
||||
TracerManager,
|
||||
Tracer,
|
||||
MemorySpanStorage,
|
||||
)
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def system_app():
|
||||
return SystemApp()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage(system_app: SystemApp):
|
||||
ms = MemorySpanStorage(system_app)
|
||||
system_app.register_instance(ms)
|
||||
return ms
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracer(request, system_app: SystemApp):
|
||||
if not request or not hasattr(request, "param"):
|
||||
return DefaultTracer(system_app)
|
||||
else:
|
||||
span_storage_type = request.param.get(
|
||||
"span_storage_type", SpanStorageType.ON_CREATE_END
|
||||
)
|
||||
return DefaultTracer(system_app, span_storage_type=span_storage_type)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracer_manager(system_app: SystemApp, tracer: Tracer):
|
||||
system_app.register_instance(tracer)
|
||||
manager = TracerManager()
|
||||
manager.initialize(system_app)
|
||||
return manager
|
||||
|
||||
|
||||
def test_start_and_end_span(tracer: Tracer):
|
||||
span = tracer.start_span("operation")
|
||||
assert isinstance(span, Span)
|
||||
assert span.operation_name == "operation"
|
||||
|
||||
tracer.end_span(span)
|
||||
assert span.end_time is not None
|
||||
|
||||
stored_span = tracer._get_current_storage().spans[0]
|
||||
assert stored_span == span
|
||||
|
||||
|
||||
def test_start_and_end_span_with_tracer_manager(tracer_manager: TracerManager):
|
||||
span = tracer_manager.start_span("operation")
|
||||
assert isinstance(span, Span)
|
||||
assert span.operation_name == "operation"
|
||||
|
||||
tracer_manager.end_span(span)
|
||||
assert span.end_time is not None
|
||||
|
||||
|
||||
def test_parent_child_span_relation(tracer: Tracer):
|
||||
parent_span = tracer.start_span("parent_operation")
|
||||
child_span = tracer.start_span(
|
||||
"child_operation", parent_span_id=parent_span.span_id
|
||||
)
|
||||
|
||||
assert child_span.parent_span_id == parent_span.span_id
|
||||
assert child_span.trace_id == parent_span.trace_id
|
||||
|
||||
tracer.end_span(child_span)
|
||||
tracer.end_span(parent_span)
|
||||
|
||||
assert parent_span in tracer._get_current_storage().spans
|
||||
assert child_span in tracer._get_current_storage().spans
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tracer, expected_count, after_create_inc_count",
|
||||
[
|
||||
({"span_storage_type": SpanStorageType.ON_CREATE}, 1, 1),
|
||||
({"span_storage_type": SpanStorageType.ON_END}, 1, 0),
|
||||
({"span_storage_type": SpanStorageType.ON_CREATE_END}, 2, 1),
|
||||
],
|
||||
indirect=["tracer"],
|
||||
)
|
||||
def test_tracer_span_storage_type_and_with(
|
||||
tracer: Tracer,
|
||||
expected_count: int,
|
||||
after_create_inc_count: int,
|
||||
storage: SpanStorage,
|
||||
):
|
||||
span = tracer.start_span("new_span")
|
||||
span.end()
|
||||
assert len(storage.spans) == expected_count
|
||||
|
||||
with tracer.start_span("with_span") as ws:
|
||||
assert len(storage.spans) == expected_count + after_create_inc_count
|
||||
assert len(storage.spans) == expected_count + expected_count
|
597
dbgpt/util/tracer/tracer_cli.py
Normal file
@@ -0,0 +1,597 @@
|
||||
import os
|
||||
import click
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Dict, Callable
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
from dbgpt.util.tracer import SpanType, SpanTypeRunName
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
||||
|
||||
_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl")
|
||||
|
||||
|
||||
@click.group("trace")
|
||||
def trace_cli_group():
|
||||
"""Analyze and visualize trace spans."""
|
||||
pass
|
||||
|
||||
|
||||
@trace_cli_group.command()
|
||||
@click.option(
|
||||
"--trace_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the trace ID to list",
|
||||
)
|
||||
@click.option(
|
||||
"--span_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the Span ID to list.",
|
||||
)
|
||||
@click.option(
|
||||
"--span_type",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the Span Type to list.",
|
||||
)
|
||||
@click.option(
|
||||
"--parent_span_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the Parent Span ID to list.",
|
||||
)
|
||||
@click.option(
|
||||
"--search",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.",
|
||||
)
|
||||
@click.option(
|
||||
"-l",
|
||||
"--limit",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Limit the number of recent span displayed.",
|
||||
)
|
||||
@click.option(
|
||||
"--start_time",
|
||||
type=str,
|
||||
help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"',
|
||||
)
|
||||
@click.option(
|
||||
"--end_time", type=str, help='Filter by end time. Format: "YYYY-MM-DD HH:MM:SS.mmm"'
|
||||
)
|
||||
@click.option(
|
||||
"--desc",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Whether to use reverse sorting. By default, sorting is based on start time.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
required=False,
|
||||
type=click.Choice(["text", "html", "csv", "latex", "json"]),
|
||||
default="text",
|
||||
help="The output format",
|
||||
)
|
||||
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
|
||||
def list(
|
||||
trace_id: str,
|
||||
span_id: str,
|
||||
span_type: str,
|
||||
parent_span_id: str,
|
||||
search: str,
|
||||
limit: int,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
desc: bool,
|
||||
output: str,
|
||||
files=None,
|
||||
):
|
||||
"""List your trace spans"""
|
||||
from prettytable import PrettyTable
|
||||
|
||||
# If no files are explicitly specified, use the default pattern to get them
|
||||
spans = read_spans_from_files(files)
|
||||
|
||||
if trace_id:
|
||||
spans = filter(lambda s: s["trace_id"] == trace_id, spans)
|
||||
if span_id:
|
||||
spans = filter(lambda s: s["span_id"] == span_id, spans)
|
||||
if span_type:
|
||||
spans = filter(lambda s: s["span_type"] == span_type, spans)
|
||||
if parent_span_id:
|
||||
spans = filter(lambda s: s["parent_span_id"] == parent_span_id, spans)
|
||||
# Filter spans based on the start and end times
|
||||
if start_time:
|
||||
start_dt = _parse_datetime(start_time)
|
||||
spans = filter(
|
||||
lambda span: _parse_datetime(span["start_time"]) >= start_dt, spans
|
||||
)
|
||||
|
||||
if end_time:
|
||||
end_dt = _parse_datetime(end_time)
|
||||
spans = filter(
|
||||
lambda span: _parse_datetime(span["start_time"]) <= end_dt, spans
|
||||
)
|
||||
|
||||
if search:
|
||||
spans = filter(_new_search_span_func(search), spans)
|
||||
|
||||
# Sort spans based on the start time
|
||||
spans = sorted(
|
||||
spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc
|
||||
)[:limit]
|
||||
|
||||
table = PrettyTable(
|
||||
["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
|
||||
)
|
||||
|
||||
for sp in spans:
|
||||
conv_uid = None
|
||||
if "metadata" in sp and sp:
|
||||
metadata = sp["metadata"]
|
||||
if isinstance(metadata, dict):
|
||||
conv_uid = metadata.get("conv_uid")
|
||||
table.add_row(
|
||||
[
|
||||
sp.get("trace_id"),
|
||||
sp.get("span_id"),
|
||||
# sp.get("parent_span_id"),
|
||||
sp.get("operation_name"),
|
||||
conv_uid,
|
||||
]
|
||||
)
|
||||
out_kwargs = {"ensure_ascii": False} if output == "json" else {}
|
||||
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
|
||||
|
||||
@trace_cli_group.command()
|
||||
@click.option(
|
||||
"--trace_id",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Specify the trace ID to list",
|
||||
)
|
||||
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
|
||||
def tree(trace_id: str, files):
|
||||
"""Display trace links as a tree"""
|
||||
hierarchy = _view_trace_hierarchy(trace_id, files)
|
||||
if not hierarchy:
|
||||
_print_empty_message(files)
|
||||
return
|
||||
_print_trace_hierarchy(hierarchy)
|
||||
|
||||
|
||||
@trace_cli_group.command()
|
||||
@click.option(
|
||||
"--trace_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify the trace ID to analyze. If None, show latest conversation details",
|
||||
)
|
||||
@click.option(
|
||||
"--tree",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Display trace spans as a tree",
|
||||
)
|
||||
@click.option(
|
||||
"--hide_conv",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Hide your conversation details",
|
||||
)
|
||||
@click.option(
|
||||
"--hide_run_params",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Hide run params",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
required=False,
|
||||
type=click.Choice(["text", "html", "csv", "latex", "json"]),
|
||||
default="text",
|
||||
help="The output format",
|
||||
)
|
||||
@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True))
|
||||
def chat(
|
||||
trace_id: str,
|
||||
tree: bool,
|
||||
hide_conv: bool,
|
||||
hide_run_params: bool,
|
||||
output: str,
|
||||
files,
|
||||
):
|
||||
"""Show conversation details"""
|
||||
from prettytable import PrettyTable
|
||||
|
||||
spans = read_spans_from_files(files)
|
||||
|
||||
# Sort by start time
|
||||
spans = sorted(
|
||||
spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=True
|
||||
)
|
||||
spans = list(spans)  # materialize the generator so it can be iterated more than once
|
||||
if not spans:
|
||||
_print_empty_message(files)
|
||||
return
|
||||
service_spans = {}
|
||||
service_names = set(SpanTypeRunName.values())
|
||||
found_trace_id = None
|
||||
for sp in spans:
|
||||
span_type = sp["span_type"]
|
||||
metadata = sp.get("metadata")
|
||||
if span_type == SpanType.RUN:
|
||||
service_name = metadata["run_service"]
|
||||
service_spans[service_name] = sp.copy()
|
||||
if set(service_spans.keys()) == service_names and found_trace_id:
|
||||
break
|
||||
elif span_type == SpanType.CHAT and not found_trace_id:
|
||||
if not trace_id:
|
||||
found_trace_id = sp["trace_id"]
|
||||
if trace_id and trace_id == sp["trace_id"]:
|
||||
found_trace_id = trace_id
|
||||
|
||||
service_tables = {}
|
||||
system_infos_table = {}
|
||||
out_kwargs = {"ensure_ascii": False} if output == "json" else {}
|
||||
for service_name, sp in service_spans.items():
|
||||
metadata = sp["metadata"]
|
||||
table = PrettyTable(["Config Key", "Config Value"], title=service_name)
|
||||
for k, v in metadata["params"].items():
|
||||
table.add_row([k, v])
|
||||
service_tables[service_name] = table
|
||||
sys_infos = metadata.get("sys_infos")
|
||||
if sys_infos and isinstance(sys_infos, dict):
|
||||
sys_table = PrettyTable(
|
||||
["System Config Key", "System Config Value"],
|
||||
title=f"{service_name} System information",
|
||||
)
|
||||
for k, v in sys_infos.items():
|
||||
sys_table.add_row([k, v])
|
||||
system_infos_table[service_name] = sys_table
|
||||
|
||||
if not hide_run_params:
|
||||
merged_table1 = merge_tables_horizontally(
|
||||
[
|
||||
service_tables.get(SpanTypeRunName.WEBSERVER.value),
|
||||
service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value),
|
||||
]
|
||||
)
|
||||
merged_table2 = merge_tables_horizontally(
|
||||
[
|
||||
service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
|
||||
service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
|
||||
]
|
||||
)
|
||||
sys_table = system_infos_table.get(SpanTypeRunName.WORKER_MANAGER.value)
|
||||
if system_infos_table:
|
||||
for k, v in system_infos_table.items():
|
||||
sys_table = v
|
||||
break
|
||||
if output == "text":
|
||||
print(merged_table1)
|
||||
print(merged_table2)
|
||||
else:
|
||||
for service_name, table in service_tables.items():
|
||||
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
if sys_table:
|
||||
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
|
||||
if not found_trace_id:
|
||||
print(f"Can't found conversation with trace_id: {trace_id}")
|
||||
return
|
||||
trace_id = found_trace_id
|
||||
|
||||
trace_spans = [span for span in spans if span["trace_id"] == trace_id]
|
||||
trace_spans = [s for s in reversed(trace_spans)]
|
||||
hierarchy = _build_trace_hierarchy(trace_spans)
|
||||
if tree:
|
||||
print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
|
||||
_print_trace_hierarchy(hierarchy)
|
||||
|
||||
if hide_conv:
|
||||
return
|
||||
|
||||
trace_spans = _get_ordered_trace_from(hierarchy)
|
||||
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
|
||||
split_long_text = output == "text"
|
||||
|
||||
for sp in trace_spans:
|
||||
op = sp["operation_name"]
|
||||
metadata = sp.get("metadata")
|
||||
if op == "get_chat_instance" and not sp["end_time"]:
|
||||
table.add_row(["trace_id", trace_id])
|
||||
table.add_row(["span_id", sp["span_id"]])
|
||||
table.add_row(["conv_uid", metadata.get("conv_uid")])
|
||||
table.add_row(["user_input", metadata.get("user_input")])
|
||||
table.add_row(["chat_mode", metadata.get("chat_mode")])
|
||||
table.add_row(["select_param", metadata.get("select_param")])
|
||||
table.add_row(["model_name", metadata.get("model_name")])
|
||||
if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
|
||||
if not sp["end_time"]:
|
||||
table.add_row(["temperature", metadata.get("temperature")])
|
||||
table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
|
||||
table.add_row(["echo", metadata.get("echo")])
|
||||
elif "error" in metadata:
|
||||
table.add_row(["BaseChat Error", metadata.get("error")])
|
||||
if op == "BaseChat.do_action" and not sp["end_time"]:
|
||||
if "model_output" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat model_output",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("model_output").get("text"),
|
||||
split=split_long_text,
|
||||
),
|
||||
]
|
||||
)
|
||||
if "ai_response_text" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat ai_response_text",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("ai_response_text"), split=split_long_text
|
||||
),
|
||||
]
|
||||
)
|
||||
if "prompt_define_response" in metadata:
|
||||
prompt_define_response = metadata.get("prompt_define_response") or ""
|
||||
if isinstance(prompt_define_response, dict) or isinstance(
|
||||
prompt_define_response, type([])
|
||||
):
|
||||
prompt_define_response = json.dumps(
|
||||
prompt_define_response, ensure_ascii=False
|
||||
)
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat prompt_define_response",
|
||||
split_string_by_terminal_width(
|
||||
prompt_define_response,
|
||||
split=split_long_text,
|
||||
),
|
||||
]
|
||||
)
|
||||
if op == "DefaultModelWorker_call.generate_stream_func":
|
||||
if not sp["end_time"]:
|
||||
table.add_row(["llm_adapter", metadata.get("llm_adapter")])
|
||||
table.add_row(
|
||||
[
|
||||
"User prompt",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("prompt"), split=split_long_text
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
table.add_row(
|
||||
[
|
||||
"Model output",
|
||||
split_string_by_terminal_width(metadata.get("output")),
|
||||
]
|
||||
)
|
||||
if (
|
||||
op
|
||||
in [
|
||||
"DefaultModelWorker.async_generate_stream",
|
||||
"DefaultModelWorker.generate_stream",
|
||||
]
|
||||
and metadata
|
||||
and "error" in metadata
|
||||
):
|
||||
table.add_row(["Model Error", metadata.get("error")])
|
||||
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
|
||||
|
||||
def read_spans_from_files(files=None) -> Iterable[Dict]:
|
||||
"""
|
||||
Reads spans from multiple files based on the provided file paths.
|
||||
"""
|
||||
if not files:
|
||||
files = [_DEFAULT_FILE_PATTERN]
|
||||
|
||||
for filepath in files:
|
||||
for filename in glob.glob(filepath):
|
||||
with open(filename, "r") as file:
|
||||
for line in file:
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
def _print_empty_message(files=None):
|
||||
if not files:
|
||||
files = [_DEFAULT_FILE_PATTERN]
|
||||
file_names = ",".join(files)
|
||||
print(f"No trace span records found in your tracer files: {file_names}")
|
||||
|
||||
|
||||
def _new_search_span_func(search: str):
|
||||
def func(span: Dict) -> bool:
|
||||
items = [span["trace_id"], span["span_id"], span["parent_span_id"]]
|
||||
if "operation_name" in span:
|
||||
items.append(span["operation_name"])
|
||||
if "metadata" in span:
|
||||
metadata = span["metadata"]
|
||||
if isinstance(metadata, dict):
|
||||
for k, v in metadata.items():
|
||||
items.append(k)
|
||||
items.append(v)
|
||||
return any(search in str(item) for item in items if item)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _parse_datetime(dt_str):
|
||||
"""Parse a datetime string to a datetime object."""
|
||||
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
|
||||
def _build_trace_hierarchy(spans, parent_span_id=None, indent=0):
|
||||
# Current spans
|
||||
current_level_spans = [
|
||||
span
|
||||
for span in spans
|
||||
if span["parent_span_id"] == parent_span_id and span["end_time"] is None
|
||||
]
|
||||
|
||||
hierarchy = []
|
||||
|
||||
for start_span in current_level_spans:
|
||||
# Find end span
|
||||
end_span = next(
|
||||
(
|
||||
span
|
||||
for span in spans
|
||||
if span["span_id"] == start_span["span_id"]
|
||||
and span["end_time"] is not None
|
||||
),
|
||||
None,
|
||||
)
|
||||
entry = {
|
||||
"operation_name": start_span["operation_name"],
|
||||
"parent_span_id": start_span["parent_span_id"],
|
||||
"span_id": start_span["span_id"],
|
||||
"start_time": start_span["start_time"],
|
||||
"end_time": start_span["end_time"],
|
||||
"metadata": start_span["metadata"],
|
||||
"children": _build_trace_hierarchy(
|
||||
spans, start_span["span_id"], indent + 1
|
||||
),
|
||||
}
|
||||
hierarchy.append(entry)
|
||||
|
||||
# Append end span
|
||||
if end_span:
|
||||
entry_end = {
|
||||
"operation_name": end_span["operation_name"],
|
||||
"parent_span_id": end_span["parent_span_id"],
|
||||
"span_id": end_span["span_id"],
|
||||
"start_time": end_span["start_time"],
|
||||
"end_time": end_span["end_time"],
|
||||
"metadata": end_span["metadata"],
|
||||
"children": [],
|
||||
}
|
||||
hierarchy.append(entry_end)
|
||||
|
||||
return hierarchy
|
||||
|
||||
|
||||
def _view_trace_hierarchy(trace_id, files=None):
|
||||
"""Find and display the calls of the entire link based on the given trace_id"""
|
||||
spans = read_spans_from_files(files)
|
||||
trace_spans = [span for span in spans if span["trace_id"] == trace_id]
|
||||
if not trace_spans:
|
||||
return None
|
||||
hierarchy = _build_trace_hierarchy(trace_spans)
|
||||
return hierarchy
|
||||
|
||||
|
||||
def _print_trace_hierarchy(hierarchy, indent=0):
|
||||
"""Print link hierarchy"""
|
||||
for entry in hierarchy:
|
||||
print(
|
||||
" " * indent
|
||||
+ f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})"
|
||||
)
|
||||
_print_trace_hierarchy(entry["children"], indent + 1)
|
||||
|
||||
|
||||
def _get_ordered_trace_from(hierarchy):
|
||||
traces = []
|
||||
|
||||
def func(items):
|
||||
for item in items:
|
||||
traces.append(item)
|
||||
func(item["children"])
|
||||
|
||||
func(hierarchy)
|
||||
return traces
|
||||
|
||||
|
||||
def _print(service_spans: Dict):
|
||||
for names in [
|
||||
[SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL],
|
||||
[SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER],
|
||||
]:
|
||||
pass
|
||||
|
||||
|
||||
def merge_tables_horizontally(tables):
|
||||
from prettytable import PrettyTable
|
||||
|
||||
if not tables:
|
||||
return None
|
||||
|
||||
tables = [t for t in tables if t]
|
||||
if not tables:
|
||||
return None
|
||||
|
||||
max_rows = max(len(table._rows) for table in tables)
|
||||
|
||||
merged_table = PrettyTable()
|
||||
|
||||
new_field_names = []
|
||||
for table in tables:
|
||||
new_field_names.extend(
|
||||
[
|
||||
f"{name} ({table.title})" if table.title else f"{name}"
|
||||
for name in table.field_names
|
||||
]
|
||||
)
|
||||
|
||||
merged_table.field_names = new_field_names
|
||||
|
||||
for i in range(max_rows):
|
||||
merged_row = []
|
||||
for table in tables:
|
||||
if i < len(table._rows):
|
||||
merged_row.extend(table._rows[i])
|
||||
else:
|
||||
# Fill empty cells for shorter tables
|
||||
merged_row.extend([""] * len(table.field_names))
|
||||
merged_table.add_row(merged_row)
|
||||
|
||||
return merged_table
|
||||
|
||||
|
||||
def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"):
|
||||
"""
|
||||
Split a string into substrings based on the current terminal width.
|
||||
|
||||
Parameters:
|
||||
- s: the input string
|
||||
- split: if False, return the string unchanged
|
||||
- max_len: maximum characters per chunk; defaults to 80% of the terminal width (100 if it cannot be determined)
|
||||
- sp: separator inserted between chunks (defaults to a newline)
|
||||
"""
|
||||
if not split:
|
||||
return s
|
||||
if not max_len:
|
||||
try:
|
||||
max_len = int(os.get_terminal_size().columns * 0.8)
|
||||
except OSError:
|
||||
# Default to 100 columns if the terminal size can't be determined
|
||||
max_len = 100
|
||||
return sp.join([s[i : i + max_len] for i in range(0, len(s), max_len)])
|
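Assuming trace_cli_group is mounted under the project's dbgpt command-line entry point (the registration itself is not part of this file), typical invocations would look like:

    dbgpt trace list --span_type chat --limit 10
    dbgpt trace tree --trace_id <trace_id>
    dbgpt trace chat --tree --hide_run_params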
235
dbgpt/util/tracer/tracer_impl.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from typing import Dict, Optional
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
|
||||
from dbgpt.component import SystemApp, ComponentType
|
||||
from dbgpt.util.tracer.base import (
|
||||
SpanType,
|
||||
Span,
|
||||
Tracer,
|
||||
SpanStorage,
|
||||
SpanStorageType,
|
||||
TracerContext,
|
||||
)
|
||||
from dbgpt.util.tracer.span_storage import MemorySpanStorage
|
||||
from dbgpt.util.module_utils import import_from_checked_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultTracer(Tracer):
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp | None = None,
|
||||
default_storage: SpanStorage = None,
|
||||
span_storage_type: SpanStorageType = SpanStorageType.ON_CREATE_END,
|
||||
):
|
||||
super().__init__(system_app)
|
||||
self._span_stack_var = ContextVar("span_stack", default=[])
|
||||
|
||||
if not default_storage:
|
||||
default_storage = MemorySpanStorage(system_app)
|
||||
self._default_storage = default_storage
|
||||
self._span_storage_type = span_storage_type
|
||||
|
||||
def append_span(self, span: Span):
|
||||
self._get_current_storage().append_span(span)
|
||||
|
||||
def start_span(
|
||||
self,
|
||||
operation_name: str,
|
||||
parent_span_id: str = None,
|
||||
span_type: SpanType = None,
|
||||
metadata: Dict = None,
|
||||
) -> Span:
|
||||
trace_id = (
|
||||
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
|
||||
)
|
||||
span_id = f"{trace_id}:{self._new_uuid()}"
|
||||
span = Span(
|
||||
trace_id,
|
||||
span_id,
|
||||
span_type,
|
||||
parent_span_id,
|
||||
operation_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if self._span_storage_type in [
|
||||
SpanStorageType.ON_END,
|
||||
SpanStorageType.ON_CREATE_END,
|
||||
]:
|
||||
span.add_end_caller(self.append_span)
|
||||
|
||||
if self._span_storage_type in [
|
||||
SpanStorageType.ON_CREATE,
|
||||
SpanStorageType.ON_CREATE_END,
|
||||
]:
|
||||
self.append_span(span)
|
||||
current_stack = self._span_stack_var.get()
|
||||
current_stack.append(span)
|
||||
self._span_stack_var.set(current_stack)
|
||||
|
||||
span.add_end_caller(self._remove_from_stack_top)
|
||||
return span
|
||||
|
||||
def end_span(self, span: Span, **kwargs):
|
||||
""""""
|
||||
span.end(**kwargs)
|
||||
|
||||
def _remove_from_stack_top(self, span: Span):
|
||||
current_stack = self._span_stack_var.get()
|
||||
if current_stack:
|
||||
current_stack.pop()
|
||||
self._span_stack_var.set(current_stack)
|
||||
|
||||
def get_current_span(self) -> Optional[Span]:
|
||||
current_stack = self._span_stack_var.get()
|
||||
return current_stack[-1] if current_stack else None
|
||||
|
||||
def _get_current_storage(self) -> SpanStorage:
|
||||
return self.system_app.get_component(
|
||||
ComponentType.TRACER_SPAN_STORAGE, SpanStorage, self._default_storage
|
||||
)
|
||||
|
||||
|
||||
class TracerManager:
    """The manager of the current tracer"""

    def __init__(self) -> None:
        self._system_app: Optional[SystemApp] = None
        self._trace_context_var: ContextVar[TracerContext] = ContextVar(
            "trace_context",
            default=TracerContext(),
        )

    def initialize(
        self, system_app: SystemApp, trace_context_var: ContextVar[TracerContext] = None
    ) -> None:
        self._system_app = system_app
        if trace_context_var:
            self._trace_context_var = trace_context_var

    def _get_tracer(self) -> Tracer:
        if not self._system_app:
            return None
        return self._system_app.get_component(ComponentType.TRACER, Tracer, None)

    def start_span(
        self,
        operation_name: str,
        parent_span_id: str = None,
        span_type: SpanType = None,
        metadata: Dict = None,
    ) -> Span:
        """Start a new span with operation_name.

        This method must never raise an exception and should avoid blocking as much
        as possible.
        """
        tracer = self._get_tracer()
        if not tracer:
            return Span("empty_span", "empty_span")
        if not parent_span_id:
            parent_span_id = self.get_current_span_id()
        return tracer.start_span(
            operation_name, parent_span_id, span_type=span_type, metadata=metadata
        )

    def end_span(self, span: Span, **kwargs):
        tracer = self._get_tracer()
        if not tracer or not span:
            return
        tracer.end_span(span, **kwargs)

    def get_current_span(self) -> Optional[Span]:
        tracer = self._get_tracer()
        if not tracer:
            return None
        return tracer.get_current_span()

    def get_current_span_id(self) -> Optional[str]:
        current_span = self.get_current_span()
        if current_span:
            return current_span.span_id
        ctx = self._trace_context_var.get()
        return ctx.span_id if ctx else None


root_tracer: TracerManager = TracerManager()


def trace(operation_name: Optional[str] = None, **trace_kwargs):
    def decorator(func):
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            name = (
                operation_name if operation_name else _parse_operation_name(func, *args)
            )
            with root_tracer.start_span(name, **trace_kwargs):
                return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            name = (
                operation_name if operation_name else _parse_operation_name(func, *args)
            )
            with root_tracer.start_span(name, **trace_kwargs):
                return await func(*args, **kwargs)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


def _parse_operation_name(func, *args):
    self_name = None
    if inspect.signature(func).parameters.get("self"):
        self_name = args[0].__class__.__name__
    func_name = func.__name__
    if self_name:
        return f"{self_name}.{func_name}"
    return func_name


def initialize_tracer(
    system_app: SystemApp,
    tracer_filename: str,
    root_operation_name: str = "DB-GPT-Web-Entry",
    tracer_storage_cls: str = None,
):
    if not system_app:
        return
    from dbgpt.util.tracer.span_storage import FileSpanStorage, SpanStorageContainer

    trace_context_var = ContextVar(
        "trace_context",
        default=TracerContext(),
    )
    tracer = DefaultTracer(system_app)

    storage_container = SpanStorageContainer(system_app)
    storage_container.append_storage(FileSpanStorage(tracer_filename))

    if tracer_storage_cls:
        logger.info(f"Begin parse storage class {tracer_storage_cls}")
        storage = import_from_checked_string(tracer_storage_cls, SpanStorage)
        storage_container.append_storage(storage())

    system_app.register_instance(storage_container)
    system_app.register_instance(tracer)
    root_tracer.initialize(system_app, trace_context_var)
    if system_app.app:
        from dbgpt.util.tracer.tracer_middleware import TraceIDMiddleware

        system_app.app.add_middleware(
            TraceIDMiddleware,
            trace_context_var=trace_context_var,
            tracer=tracer,
            root_operation_name=root_operation_name,
        )
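A minimal usage sketch for the tracer above (illustrative only, not part of this commit). It assumes dbgpt.util.tracer re-exports root_tracer and trace (otherwise import them from dbgpt.util.tracer.tracer_impl) and that initialize_tracer() has already been called elsewhere with a SystemApp and a writable trace file path.

# Illustrative usage sketch; not part of tracer_impl.py.
from dbgpt.util.tracer import root_tracer, trace


@trace("demo.compute")
def compute(x: int) -> int:
    # Each call runs inside an automatically created span named "demo.compute"
    return x * 2


def handle_request() -> int:
    # Manually managed span around a block of work; the metadata is arbitrary
    span = root_tracer.start_span("demo.handle_request", metadata={"source": "example"})
    try:
        return compute(21)
    finally:
        root_tracer.end_span(span)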
45
dbgpt/util/tracer/tracer_middleware.py
Normal file
@@ -0,0 +1,45 @@
import uuid
from contextvars import ContextVar

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp
from dbgpt.util.tracer import TracerContext, Tracer


_DEFAULT_EXCLUDE_PATHS = ["/api/controller/heartbeat"]


class TraceIDMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        trace_context_var: ContextVar[TracerContext],
        tracer: Tracer,
        root_operation_name: str = "DB-GPT-Web-Entry",
        include_prefix: str = "/api",
        exclude_paths=_DEFAULT_EXCLUDE_PATHS,
    ):
        super().__init__(app)
        self.trace_context_var = trace_context_var
        self.tracer = tracer
        self.root_operation_name = root_operation_name
        self.include_prefix = include_prefix
        self.exclude_paths = exclude_paths

    async def dispatch(self, request: Request, call_next):
        if request.url.path in self.exclude_paths or not request.url.path.startswith(
            self.include_prefix
        ):
            return await call_next(request)

        span_id = request.headers.get("DBGPT_TRACER_SPAN_ID")
        # if not span_id:
        #     span_id = str(uuid.uuid4())
        #     self.trace_context_var.set(TracerContext(span_id=span_id))

        with self.tracer.start_span(
            self.root_operation_name, span_id, metadata={"path": request.url.path}
        ):
            response = await call_next(request)
        return response
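An illustrative client-side sketch (not part of the middleware): a caller can forward its current span id in the DBGPT_TRACER_SPAN_ID header so that TraceIDMiddleware links the remote request into the same trace. The httpx dependency and the target URL are assumptions for the example.

# Illustrative only: propagate the current span id to a downstream DB-GPT service.
import httpx

from dbgpt.util.tracer import root_tracer


def call_downstream(url: str) -> httpx.Response:
    headers = {}
    span_id = root_tracer.get_current_span_id()
    if span_id:
        # TraceIDMiddleware reads this header and uses it as the parent span id
        headers["DBGPT_TRACER_SPAN_ID"] = span_id
    return httpx.get(url, headers=headers)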
229
dbgpt/util/utils.py
Normal file
@@ -0,0 +1,229 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import logging
import logging.handlers
from typing import Any, List

import os
import sys
import asyncio

from dbgpt.configs.model_config import LOGDIR

server_error_msg = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)

handler = None


def _get_logging_level() -> str:
    return os.getenv("DBGPT_LOG_LEVEL", "INFO")


def setup_logging_level(logging_level=None, logger_name: str = None):
    if not logging_level:
        logging_level = _get_logging_level()
    if type(logging_level) is str:
        logging_level = logging.getLevelName(logging_level.upper())
    if logger_name:
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging_level)
    else:
        logging.basicConfig(level=logging_level, encoding="utf-8")


def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
    if not logging_level:
        logging_level = _get_logging_level()
    logger = _build_logger(logger_name, logging_level, logger_filename)
    try:
        import coloredlogs

        color_level = logging_level if logging_level else "INFO"
        coloredlogs.install(level=color_level, logger=logger)
    except ImportError:
        pass


def get_gpu_memory(max_gpus=None):
    import torch

    gpu_memory = []
    num_gpus = (
        torch.cuda.device_count()
        if max_gpus is None
        else min(max_gpus, torch.cuda.device_count())
    )

    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            device = torch.cuda.current_device()
            gpu_properties = torch.cuda.get_device_properties(device)
            total_memory = gpu_properties.total_memory / (1024**3)
            allocated_memory = torch.cuda.memory_allocated() / (1024**3)
            available_memory = total_memory - allocated_memory
            gpu_memory.append(available_memory)
    return gpu_memory


def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        setup_logging_level(logging_level=logging_level)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    # stdout_logger = logging.getLogger("stdout")
    # stdout_logger.setLevel(logging.INFO)
    # sl_1 = StreamToLogger(stdout_logger, logging.INFO)
    # sys.stdout = sl_1
    #
    # stderr_logger = logging.getLogger("stderr")
    # stderr_logger.setLevel(logging.ERROR)
    # sl = StreamToLogger(stderr_logger, logging.ERROR)
    # sys.stderr = sl

    # Add a file handler for all loggers
    if handler is None and logger_filename:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True
        )
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)
    # Get logger
    logger = logging.getLogger(logger_name)
    setup_logging_level(logging_level=logging_level, logger_name=logger_name)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ""
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   on output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == "\n":
                encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, encoded_message.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != "":
            encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, encoded_message.rstrip())
        self.linebuf = ""


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"


def get_or_create_event_loop() -> asyncio.BaseEventLoop:
    loop = None
    try:
        loop = asyncio.get_event_loop()
        assert loop is not None
        return loop
    except RuntimeError as e:
        if "no running event loop" not in str(e) and "no current event loop" not in str(
            e
        ):
            raise e
        logging.warning("Cannot get running event loop, creating a new event loop now")
        return asyncio.get_event_loop_policy().new_event_loop()


def logging_str_to_uvicorn_level(log_level_str):
    level_str_mapping = {
        "CRITICAL": "critical",
        "ERROR": "error",
        "WARNING": "warning",
        "INFO": "info",
        "DEBUG": "debug",
        "NOTSET": "info",
    }
    return level_str_mapping.get(log_level_str.upper(), "info")


class EndpointFilter(logging.Filter):
    """Disable the access log for a given endpoint

    source: https://github.com/encode/starlette/issues/864#issuecomment-1254987630
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self._path = path

    def filter(self, record: logging.LogRecord) -> bool:
        return record.getMessage().find(self._path) == -1


def setup_http_service_logging(exclude_paths: List[str] = None):
    """Set up HTTP service logging

    Currently this just disables some noisy logs.

    Args:
        exclude_paths (List[str]): The paths for which to disable access logs
    """
    if not exclude_paths:
        # Do not show heartbeat logs
        exclude_paths = ["/api/controller/heartbeat"]
    uvicorn_logger = logging.getLogger("uvicorn.access")
    if uvicorn_logger:
        for path in exclude_paths:
            uvicorn_logger.addFilter(EndpointFilter(path=path))
    httpx_logger = logging.getLogger("httpx")
    if httpx_logger:
        httpx_logger.setLevel(logging.WARNING)
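A small bootstrap sketch tying the utilities above together (illustrative only; the logger name and log file name are made up for the example, and LOGDIR must be writable).

# Illustrative usage of the helpers above; names and paths are examples only.
from dbgpt.util.utils import (
    setup_logging,
    setup_http_service_logging,
    get_or_create_event_loop,
    logging_str_to_uvicorn_level,
)


def bootstrap_example():
    # Configure a named logger with colored console output and a daily-rotating file in LOGDIR
    setup_logging("dbgpt", logging_level="INFO", logger_filename="dbgpt_example.log")
    # Silence heartbeat access logs and verbose httpx logs
    setup_http_service_logging()
    # Reuse the current event loop if possible, otherwise create a new one
    loop = get_or_create_event_loop()
    return loop, logging_str_to_uvicorn_level("INFO")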