Merge branch 'main' into Agent_Hub_Dev

# Conflicts:
#	pilot/model/parameter.py
#	pilot/server/dbgpt_server.py
yhjun1026 2023-10-16 14:34:29 +08:00
commit 3e74a5e1cd
17 changed files with 493 additions and 90 deletions

View File

@ -1,2 +1,5 @@
models/
plugins/
pilot/data
pilot/message
logs/

View File

@ -49,7 +49,7 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
- [features](#features)
- [contribution](#contribution)
- [roadmap](#roadmap)
- [contract](#contact-information)
- [contact](#contact-information)
[DB-GPT Youtube Video](https://www.youtube.com/watch?v=f5_g0OObZBQ)

Binary file not shown. (Image changed: 186 KiB before → 165 KiB after.)

View File

@ -3,7 +3,7 @@ ARG BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
FROM ${BASE_IMAGE}
ARG BASE_IMAGE
RUN apt-get update && apt-get install -y git python3 pip wget sqlite3 \
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget sqlite3 tzdata \
&& apt-get clean
ARG BUILD_LOCAL_CODE="false"
@ -44,11 +44,6 @@ ARG BUILD_LOCAL_CODE="false"
# COPY the rest of the app
COPY . /app
# TODO: Need to find a better way to determine whether to build docker image with local code.
RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \
then rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \
fi;)
ARG LOAD_EXAMPLES="true"
RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
@ -57,6 +52,11 @@ RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
&& sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/test_case_info_sqlite.sql; \
fi;)
RUN (if [ "${LANGUAGE}" = "zh" ]; \
then ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone; \
fi;)
ENV PYTHONPATH "/app:$PYTHONPATH"
EXPOSE 5000

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from pilot.model.parameter import BaseEmbeddingModelParameters
from pilot.utils.parameter_utils import _get_dict_from_obj
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
from pilot.utils.system_utils import get_system_info
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@ -21,6 +22,7 @@ class EmbeddingLoader:
"model_name": model_name,
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
"params": _get_dict_from_obj(param),
"sys_infos": _get_dict_from_obj(get_system_info()),
}
with root_tracer.start_span(
"EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata

View File

@ -11,6 +11,7 @@ from pilot.model.cluster.worker_base import ModelWorker
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
from pilot.utils.system_utils import get_system_info
logger = logging.getLogger(__name__)
@ -102,6 +103,7 @@ class DefaultModelWorker(ModelWorker):
"llm_adapter": str(self.llm_adapter),
"run_service": SpanTypeRunName.MODEL_WORKER,
"params": _get_dict_from_obj(model_params),
"sys_infos": _get_dict_from_obj(get_system_info()),
}
with root_tracer.start_span(
"DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata

View File

@ -40,6 +40,7 @@ from pilot.utils.parameter_utils import (
)
from pilot.utils.utils import setup_logging
from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
from pilot.utils.system_utils import get_system_info
logger = logging.getLogger(__name__)
@ -838,6 +839,7 @@ def _start_local_worker(
metadata={
"run_service": SpanTypeRunName.WORKER_MANAGER,
"params": _get_dict_from_obj(worker_params),
"sys_infos": _get_dict_from_obj(get_system_info()),
},
):
worker = _build_worker(worker_params)
@ -974,6 +976,7 @@ def run_worker_manager(
os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
root_operation_name="DB-GPT-WorkerManager-Entry",
)
_start_local_worker(worker_manager, worker_params)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
@ -985,11 +988,13 @@ def run_worker_manager(
if not embedded_mod:
import uvicorn
loop = asyncio.get_event_loop()
loop.run_until_complete(worker_manager.start())
uvicorn.run(
app, host=worker_params.host, port=worker_params.port, log_level="info"
)
else:
# Embedded mod, start worker manager
loop = asyncio.get_event_loop()
loop.run_until_complete(worker_manager.start())
if __name__ == "__main__":

View File

@ -51,6 +51,12 @@ _OLD_MODELS = [
class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
def use_fast_tokenizer(self) -> bool:
"""Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
for a given model.
"""
return False
def model_type(self) -> str:
return ModelType.HF
@ -169,6 +175,9 @@ class OldLLMModelAdaperWrapper(LLMModelAdaper):
self._adapter = adapter
self._chat_adapter = chat_adapter
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer()
def model_type(self) -> str:
return self._adapter.model_type()
@ -200,6 +209,9 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
def __init__(self, adapter: "BaseModelAdapter") -> None:
self._adapter = adapter
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
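How a loader consumes the new use_fast_tokenizer() hook is not shown in this diff; a hypothetical sketch assuming a standard Hugging Face load path:

from transformers import AutoTokenizer

def load_tokenizer(adapter, model_path: str):
    # Hypothetical consumer: honor the adapter's preference for the fast Rust tokenizer.
    return AutoTokenizer.from_pretrained(model_path, use_fast=adapter.use_fast_tokenizer())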

View File

@ -282,20 +282,28 @@ class ProxyModelParameters(BaseModelParameters):
"help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
},
)
proxy_api_key: str = field(
metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
)
proxy_app_id: Optional[str] = field(
proxy_api_base: str = field(
default=None,
metadata={
"help": "Appid for visitor proxy"
"help": "The base api address, such as: https://api.openai.com/v1. If None, we will use proxy_api_base first"
},
)
proxy_api_secret: Optional[str] = field(
proxy_api_type: Optional[str] = field(
default=None,
metadata={"tags": "privacy", "help": "The api secret of current proxy LLM"},
metadata={
"help": "The api type of current proxy the current proxy model, if you use Azure, it can be: azure"
},
)
proxy_api_version: Optional[str] = field(
default=None,
metadata={"help": "The api version of current proxy the current model"},
)
http_proxy: Optional[str] = field(
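A hypothetical Azure-style configuration built from the new fields (placeholder values; the other required ProxyModelParameters fields are omitted):

# Placeholder values; only the field names come from the diff above.
azure_proxy_config = dict(
    proxy_api_key="<azure-api-key>",
    proxy_api_base="https://<your-resource>.openai.azure.com/",
    proxy_api_type="azure",
    proxy_api_version="2023-05-15",
)
# Passed to ProxyModelParameters together with its remaining required fields.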

View File

@ -1,31 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import requests
import os
from typing import List
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
import logging
import openai
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
logger = logging.getLogger(__name__)
def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
def _initialize_openai(params: ProxyModelParameters):
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
api_base = params.proxy_api_base or os.getenv(
"OPENAI_API_TYPE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.proxy_api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
if not api_base and params.proxy_server_url:
# Adapt previous proxy_server_url configuration
api_base = params.proxy_server_url.split("/chat/completions")[0]
if api_type:
openai.api_type = api_type
if api_base:
openai.api_base = api_base
if api_key:
openai.api_key = api_key
if api_version:
openai.api_version = api_version
if params.http_proxy:
openai.proxy = params.http_proxy
openai_params = {
"api_type": api_type,
"api_base": api_base,
"api_version": api_version,
"proxy": params.http_proxy,
}
return openai_params
def _build_request(model: ProxyModel, params):
history = []
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
logger.info(f"Model: {model}, model_params: {model_params}")
proxy_api_key = model_params.proxy_api_key
proxy_server_url = model_params.proxy_server_url
proxyllm_backend = model_params.proxyllm_backend
if not proxyllm_backend:
proxyllm_backend = "gpt-3.5-turbo"
headers = {
"Authorization": "Bearer " + proxy_api_key,
"Token": proxy_api_key,
}
openai_params = _initialize_openai(model_params)
messages: List[ModelMessage] = params["messages"]
# Add history conversation
@ -51,29 +83,51 @@ def chatgpt_generate_stream(
history.append(last_user_input)
payloads = {
"model": proxyllm_backend, # just for test, remove this later
"messages": history,
"temperature": params.get("temperature"),
"max_tokens": params.get("max_new_tokens"),
"stream": True,
}
proxyllm_backend = model_params.proxyllm_backend
res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True)
if openai_params["api_type"] == "azure":
# engine = "deployment_name".
proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
payloads["engine"] = proxyllm_backend
else:
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}")
logger.info(
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
)
return history, payloads
def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
history, payloads = _build_request(model, params)
res = openai.ChatCompletion.create(messages=history, **payloads)
text = ""
for line in res.iter_lines():
if line:
if not line.startswith(b"data: "):
error_message = line.decode("utf-8")
yield error_message
else:
json_data = line.split(b": ", 1)[1]
decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data)
if obj["choices"][0]["delta"].get("content") is not None:
content = obj["choices"][0]["delta"]["content"]
text += content
yield text
for r in res:
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
text += content
yield text
async def async_chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
history, payloads = _build_request(model, params)
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
text = ""
async for r in res:
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
text += content
yield text
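A minimal consumption sketch for the two generators above (the params shape follows the code: a "messages" list plus sampling options; the values and the loaded model are illustrative):

from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

# `model` is assumed to be an already-loaded ProxyModel; the tokenizer argument is unused here.
request_params = {
    "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello!")],
    "temperature": 0.7,
    "max_new_tokens": 256,
}
for partial_text in chatgpt_generate_stream(model, None, request_params, device="cpu"):
    print(partial_text)  # yields the accumulated text so far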

View File

@ -5,7 +5,6 @@ from typing import List
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR
from pilot.component import SystemApp
@ -41,6 +40,7 @@ from pilot.utils.utils import (
)
from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
from pilot.utils.parameter_utils import _get_dict_from_obj
from pilot.utils.system_utils import get_system_info
from pilot.base_modules.agent.controller import router as agent_route
@ -84,19 +84,6 @@ app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])
@app.get("/openapi.json")
async def get_openapi_endpoint():
return get_openapi(
title="Your API title",
version="1.0.0",
description="Your API description",
routes=app.routes,
)
@app.get("/docs")
async def get_docs():
return get_swagger_ui_html(openapi_url="/openapi.json", title="API docs")
def mount_static_files(app):
os.makedirs(static_message_img_path, exist_ok=True)
app.mount(
@ -112,17 +99,20 @@ def mount_static_files(app):
app.add_exception_handler(RequestValidationError, validation_exception_handler)
def _get_webserver_params(args: List[str] = None):
from pilot.utils.parameter_utils import EnvArgumentParser
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
WebWerverParameters
)
return WebWerverParameters(**vars(parser.parse_args(args=args)))
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
"""Initialize app
If you use gunicorn as a process manager, initialize_app can be invoked in the `on_starting` hook.
"""
if not param:
from pilot.utils.parameter_utils import EnvArgumentParser
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
WebWerverParameters
)
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
param = _get_webserver_params(args)
if not param.log_level:
param.log_level = _get_logging_level()
@ -141,7 +131,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
@ -188,8 +178,21 @@ def run_uvicorn(param: WebWerverParameters):
def run_webserver(param: WebWerverParameters = None):
param = initialize_app(param)
run_uvicorn(param)
if not param:
param = _get_webserver_params()
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
with root_tracer.start_span(
"run_webserver",
span_type=SpanType.RUN,
metadata={
"run_service": SpanTypeRunName.WEBSERVER,
"params": _get_dict_from_obj(param),
"sys_infos": _get_dict_from_obj(get_system_info()),
},
):
param = initialize_app(param)
run_uvicorn(param)
if __name__ == "__main__":
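As the initialize_app docstring above notes, a gunicorn deployment can run the initialization in its on_starting hook; a hypothetical gunicorn_conf.py sketch:

# gunicorn_conf.py (hypothetical)
def on_starting(server):
    from pilot.server.dbgpt_server import initialize_app

    # Initialize components once in the gunicorn master before workers fork.
    initialize_app()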

View File

@ -22,8 +22,10 @@ chat_factory = ChatFactory()
class DBSummaryClient:
"""db summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
"""DB Summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
, get_similar_tables method(get user query related tables info)
Args:
system_app (SystemApp): Main System Application class that manages the lifecycle and registration of components.
"""
def __init__(self, system_app: SystemApp):
@ -160,6 +162,12 @@ class DBSummaryClient:
)
def init_db_profile(self, db_summary_client, dbname, embeddings):
"""db profile initialization
Args:
db_summary_client(DBSummaryClient): DB Summary Client
dbname(str): dbname
embeddings(SourceEmbedding): embedding for read string document
"""
from pilot.embedding_engine.string_embedding import StringEmbedding
vector_store_name = dbname + "_profile"
@ -176,9 +184,15 @@ class DBSummaryClient:
docs = []
docs.extend(embedding.read_batch())
for table_summary in db_summary_client.table_info_json():
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=len(table_summary), chunk_overlap=100
)
embedding = StringEmbedding(
table_summary,
profile_store_config,
file_path=table_summary,
vector_store_config=profile_store_config,
text_splitter=text_splitter,
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)
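A side note on the splitter arguments above: because chunk_size equals the full summary length, each table summary stays in a single chunk. A small sketch with a hypothetical summary (long enough that the 100-character overlap is valid):

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Hypothetical table summary; its length becomes the chunk size.
table_summary = (
    "table user(id BIGINT, name VARCHAR(100), email VARCHAR(255), created_at DATETIME): "
    "stores the registered users of the demo application together with their contact details."
)
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=len(table_summary), chunk_overlap=100
)
print(text_splitter.split_text(table_summary))  # -> a single chunk holding the whole summary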

pilot/utils/system_utils.py (new file, 272 lines)
View File

@ -0,0 +1,272 @@
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Tuple, Dict
import os
import platform
import subprocess
import re
from functools import cache
@dataclass
class SystemInfo:
platform: str
distribution: str
python_version: str
cpu: str
cpu_avx: str
memory: str
torch_version: str
device: str
device_version: str
device_count: int
device_other: str
def to_dict(self) -> Dict:
return asdict(self)
class AVXType(Enum):
BASIC = "basic"
AVX = "AVX"
AVX2 = "AVX2"
AVX512 = "AVX512"
@staticmethod
def of_type(avx: str):
for item in AVXType:
if item._value_ == avx:
return item
return None
class OSType(str, Enum):
WINDOWS = "win"
LINUX = "linux"
DARWIN = "darwin"
OTHER = "other"
def get_cpu_avx_support() -> Tuple[OSType, AVXType, str]:
system = platform.system()
os_type = OSType.OTHER
cpu_avx = AVXType.BASIC
env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX"))
distribution = "Unknown Distribution"
if "windows" in system.lower():
os_type = OSType.WINDOWS
output = "avx2"
distribution = "Windows " + platform.release()
print("Current platform is windows, use avx2 as default cpu architecture")
elif system == "Linux":
os_type = OSType.LINUX
result = subprocess.run(
["lscpu"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
output = result.stdout.decode()
distribution = get_linux_distribution()
elif system == "Darwin":
os_type = OSType.DARWIN
result = subprocess.run(
["sysctl", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
distribution = "Mac OS " + platform.mac_ver()[0]
output = result.stdout.decode()
else:
os_type = OSType.OTHER
print("Unsupported OS to get cpu avx, use default")
return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
if "avx512" in output.lower():
cpu_avx = AVXType.AVX512
elif "avx2" in output.lower():
cpu_avx = AVXType.AVX2
elif "avx " in output.lower():
# cpu_avx = AVXType.AVX
pass
return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
def get_device() -> str:
try:
import torch
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
except ModuleNotFoundError:
return "cpu"
def get_device_info() -> Tuple[str, str, str, int, str]:
torch_version, device, device_version, device_count, device_other = (
None,
"cpu",
None,
0,
"",
)
try:
import torch
torch_version = torch.__version__
if torch.cuda.is_available():
device = "cuda"
device_version = torch.version.cuda
device_count = torch.cuda.device_count()
elif torch.backends.mps.is_available():
device = "mps"
except ModuleNotFoundError:
pass
if not device_version:
device_version = (
get_cuda_version_from_nvcc() or get_cuda_version_from_nvidia_smi()
)
if device == "cuda":
try:
output = subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=name,driver_version,memory.total,memory.free,memory.used",
"--format=csv",
]
)
device_other = output.decode("utf-8")
except:
pass
return torch_version, device, device_version, device_count, device_other
def get_cuda_version_from_nvcc():
try:
output = subprocess.check_output(["nvcc", "--version"])
version_line = [
line for line in output.decode("utf-8").split("\n") if "release" in line
][0]
return version_line.split("release")[-1].strip().split(",")[0]
except:
return None
def get_cuda_version_from_nvidia_smi():
try:
output = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
match = re.search(r"CUDA Version:\s+(\d+\.\d+)", output)
if match:
return match.group(1)
else:
return None
except:
return None
def get_linux_distribution():
"""Get distribution of Linux"""
if os.path.isfile("/etc/os-release"):
with open("/etc/os-release", "r") as f:
info = {}
for line in f:
key, _, value = line.partition("=")
info[key] = value.strip().strip('"')
return f"{info.get('NAME', 'Unknown')} {info.get('VERSION_ID', '')}".strip()
return "Unknown Linux Distribution"
def get_cpu_info():
# Getting platform
os_type, avx_type, distribution = get_cpu_avx_support()
# Getting CPU information
cpu_info = "Unknown CPU"
if os_type == OSType.LINUX:
try:
output = subprocess.check_output(["lscpu"]).decode("utf-8")
match = re.search(r".*Model name:\s*(.+)", output)
if match:
cpu_info = match.group(1).strip()
match = re.search(f".*型号名称:\s*(.+)", output)
if match:
cpu_info = match.group(1).strip()
except:
pass
elif os_type == OSType.DARWIN:
try:
output = subprocess.check_output(
["sysctl", "machdep.cpu.brand_string"]
).decode("utf-8")
match = re.search(r"machdep.cpu.brand_string:\s*(.+)", output)
if match:
cpu_info = match.group(1).strip()
except:
pass
elif os_type == OSType.WINDOWS:
try:
output = subprocess.check_output("wmic cpu get Name", shell=True).decode(
"utf-8"
)
lines = output.splitlines()
cpu_info = lines[2].split(":")[-1].strip()
except:
pass
return os_type, avx_type, cpu_info, distribution
def get_memory_info(os_type: OSType) -> str:
memory = "Unknown Memory"
try:
import psutil
memory = f"{psutil.virtual_memory().total // (1024 ** 3)} GB"
except ImportError:
pass
if os_type == OSType.LINUX:
try:
with open("/proc/meminfo", "r") as f:
mem_info = f.readlines()
for line in mem_info:
if "MemTotal" in line:
memory = line.split(":")[1].strip()
break
except:
pass
return memory
@cache
def get_system_info() -> SystemInfo:
"""Get System information"""
os_type, avx_type, cpu_info, distribution = get_cpu_info()
# Getting Python version
python_version = platform.python_version()
memory = get_memory_info(os_type)
(
torch_version,
device,
device_version,
device_count,
device_other,
) = get_device_info()
return SystemInfo(
platform=os_type._value_,
distribution=distribution,
python_version=python_version,
cpu=cpu_info,
cpu_avx=avx_type._value_,
memory=memory,
torch_version=torch_version,
device=device,
device_version=device_version,
device_count=device_count,
device_other=device_other,
)
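A quick usage sketch for the new module (the printed values depend on the host):

from pilot.utils.system_utils import get_system_info

info = get_system_info()  # cached after the first call (functools.cache)
print(info.platform, info.cpu_avx, info.device)
print(info.to_dict())  # the dict that ends up in span metadata as "sys_infos"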

View File

@ -34,7 +34,9 @@ class FileSpanStorage(SpanStorage):
self.flush_signal_queue = queue.Queue()
if not os.path.exists(filename):
with open(filename, "w") as _:
# New file if not exist
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "a"):
pass
self.flush_thread = threading.Thread(target=self._flush_to_file, daemon=True)
self.flush_thread.start()
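With this change the storage creates missing parent directories and appends instead of truncating; a minimal sketch, assuming the filename is the only required constructor argument:

import os
from pilot.utils.tracer import FileSpanStorage

# The parent directory no longer has to exist beforehand.
storage = FileSpanStorage(os.path.join("logs", "nested", "dbgpt_tracer.jsonl"))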

View File

@ -1,7 +1,7 @@
from typing import Dict
from pilot.component import SystemApp
from pilot.utils.tracer import Span, SpanStorage, Tracer
from pilot.utils.tracer import Span, SpanType, SpanStorage, Tracer
# Mock implementations
@ -31,7 +31,9 @@ class MockTracer(Tracer):
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
)
span_id = f"{trace_id}:{self._new_uuid()}"
span = Span(trace_id, span_id, parent_span_id, operation_name, metadata)
span = Span(
trace_id, span_id, SpanType.BASE, parent_span_id, operation_name, metadata
)
self.current_span = span
return span
@ -50,7 +52,14 @@ class MockTracer(Tracer):
def test_span_creation():
span = Span("trace_id", "span_id", "parent_span_id", "operation", {"key": "value"})
span = Span(
"trace_id",
"span_id",
SpanType.BASE,
"parent_span_id",
"operation",
{"key": "value"},
)
assert span.trace_id == "trace_id"
assert span.span_id == "span_id"
assert span.parent_span_id == "parent_span_id"

View File

@ -5,7 +5,7 @@ import json
import tempfile
import time
from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span
from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span, SpanType
@pytest.fixture
@ -44,7 +44,7 @@ def read_spans_from_file(filename):
"storage", [{"batch_size": 1, "flush_interval": 5}], indirect=True
)
def test_write_span(storage: SpanStorage):
span = Span("1", "a", "b", "op1")
span = Span("1", "a", SpanType.BASE, "b", "op1")
storage.append_span(span)
time.sleep(0.1)
@ -57,8 +57,8 @@ def test_write_span(storage: SpanStorage):
"storage", [{"batch_size": 1, "flush_interval": 5}], indirect=True
)
def test_incremental_write(storage: SpanStorage):
span1 = Span("1", "a", "b", "op1")
span2 = Span("2", "c", "d", "op2")
span1 = Span("1", "a", SpanType.BASE, "b", "op1")
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
storage.append_span(span1)
storage.append_span(span2)
@ -72,7 +72,7 @@ def test_incremental_write(storage: SpanStorage):
"storage", [{"batch_size": 2, "flush_interval": 5}], indirect=True
)
def test_sync_and_async_append(storage: SpanStorage):
span = Span("1", "a", "b", "op1")
span = Span("1", "a", SpanType.BASE, "b", "op1")
storage.append_span(span)
@ -88,7 +88,7 @@ def test_sync_and_async_append(storage: SpanStorage):
@pytest.mark.asyncio
async def test_flush_policy(storage: SpanStorage):
span = Span("1", "a", "b", "op1")
span = Span("1", "a", SpanType.BASE, "b", "op1")
for _ in range(storage.batch_size - 1):
storage.append_span(span)
@ -108,8 +108,8 @@ async def test_flush_policy(storage: SpanStorage):
"storage", [{"batch_size": 2, "file_does_not_exist": True}], indirect=True
)
def test_non_existent_file(storage: SpanStorage):
span = Span("1", "a", "b", "op1")
span2 = Span("2", "c", "d", "op2")
span = Span("1", "a", SpanType.BASE, "b", "op1")
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
storage.append_span(span)
time.sleep(0.1)

View File

@ -259,6 +259,7 @@ def chat(
found_trace_id = trace_id
service_tables = {}
system_infos_table = {}
out_kwargs = {"ensure_ascii": False} if output == "json" else {}
for service_name, sp in service_spans.items():
metadata = sp["metadata"]
@ -266,6 +267,15 @@ def chat(
for k, v in metadata["params"].items():
table.add_row([k, v])
service_tables[service_name] = table
sys_infos = metadata.get("sys_infos")
if sys_infos and isinstance(sys_infos, dict):
sys_table = PrettyTable(
["System Config Key", "System Config Value"],
title=f"{service_name} System information",
)
for k, v in sys_infos.items():
sys_table.add_row([k, v])
system_infos_table[service_name] = sys_table
if not hide_run_params:
merged_table1 = merge_tables_horizontally(
@ -276,16 +286,23 @@ def chat(
)
merged_table2 = merge_tables_horizontally(
[
service_tables.get(SpanTypeRunName.MODEL_WORKER),
service_tables.get(SpanTypeRunName.WORKER_MANAGER),
service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
]
)
sys_table = system_infos_table.get(SpanTypeRunName.WORKER_MANAGER.value)
if system_infos_table:
for k, v in system_infos_table.items():
sys_table = v
break
if output == "text":
print(merged_table1)
print(merged_table2)
else:
for service_name, table in service_tables.items():
print(table.get_formatted_string(out_format=output, **out_kwargs))
if sys_table:
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
if hide_conv:
return