diff --git a/.dockerignore b/.dockerignore
index e5b067a78..efded29b9 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,2 +1,5 @@
 models/
 plugins/
+pilot/data
+pilot/message
+logs/
\ No newline at end of file
diff --git a/README.md b/README.md
index 2e0f6eaf5..88b488aea 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
 - [features](#features)
 - [contribution](#contribution)
 - [roadmap](#roadmap)
-- [contract](#contact-information)
+- [contact](#contact-information)
 
 [DB-GPT Youtube Video](https://www.youtube.com/watch?v=f5_g0OObZBQ)
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index d07c3009c..df4712129 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile
index 7c6bbf598..63486e260 100644
--- a/docker/base/Dockerfile
+++ b/docker/base/Dockerfile
@@ -3,7 +3,7 @@ ARG BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
 FROM ${BASE_IMAGE}
 ARG BASE_IMAGE
 
-RUN apt-get update && apt-get install -y git python3 pip wget sqlite3 \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget sqlite3 tzdata \
     && apt-get clean
 
 ARG BUILD_LOCAL_CODE="false"
@@ -44,11 +44,6 @@ ARG BUILD_LOCAL_CODE="false"
 # COPY the rest of the app
 COPY . /app
 
-# TODO:Need to find a better way to determine whether to build docker image with local code.
-RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \
-    then rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \
-    fi;)
-
 ARG LOAD_EXAMPLES="true"
 
 RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
@@ -57,6 +52,11 @@ RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
     && sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/test_case_info_sqlite.sql; \
     fi;)
 
+RUN (if [ "${LANGUAGE}" = "zh" ]; \
+    then ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
+    && echo "Asia/Shanghai" > /etc/timezone; \
+    fi;)
+
 ENV PYTHONPATH "/app:$PYTHONPATH"
 EXPOSE 5000
diff --git a/pilot/model/cluster/embedding/loader.py b/pilot/model/cluster/embedding/loader.py
index caf4bda9a..258e3ec2d 100644
--- a/pilot/model/cluster/embedding/loader.py
+++ b/pilot/model/cluster/embedding/loader.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from pilot.model.parameter import BaseEmbeddingModelParameters
 from pilot.utils.parameter_utils import _get_dict_from_obj
 from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
+from pilot.utils.system_utils import get_system_info
 
 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings
@@ -21,6 +22,7 @@ class EmbeddingLoader:
             "model_name": model_name,
             "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
             "params": _get_dict_from_obj(param),
+            "sys_infos": _get_dict_from_obj(get_system_info()),
         }
         with root_tracer.start_span(
             "EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata
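The new `sys_infos` entry serializes the `SystemInfo` dataclass returned by `get_system_info()`, which is introduced later in this patch in pilot/utils/system_utils.py. A rough, illustrative sketch of what that entry amounts to:

    from pilot.utils.parameter_utils import _get_dict_from_obj
    from pilot.utils.system_utils import get_system_info

    # get_system_info() is cached (functools.cache), so calling it once per service at startup is cheap.
    sys_infos = _get_dict_from_obj(get_system_info())
    print(sys_infos["device"], sys_infos["cpu_avx"], sys_infos["memory"])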
diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 378fee2ea..04b47cbdb 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -11,6 +11,7 @@ from pilot.model.cluster.worker_base import ModelWorker
 from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
 from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
+from pilot.utils.system_utils import get_system_info
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +103,7 @@ class DefaultModelWorker(ModelWorker):
             "llm_adapter": str(self.llm_adapter),
             "run_service": SpanTypeRunName.MODEL_WORKER,
             "params": _get_dict_from_obj(model_params),
+            "sys_infos": _get_dict_from_obj(get_system_info()),
         }
         with root_tracer.start_span(
             "DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index 5648c8e01..a85ee0ed7 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -40,6 +40,7 @@ from pilot.utils.parameter_utils import (
 )
 from pilot.utils.utils import setup_logging
 from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
+from pilot.utils.system_utils import get_system_info
 
 logger = logging.getLogger(__name__)
 
@@ -838,6 +839,7 @@ def _start_local_worker(
         metadata={
             "run_service": SpanTypeRunName.WORKER_MANAGER,
             "params": _get_dict_from_obj(worker_params),
+            "sys_infos": _get_dict_from_obj(get_system_info()),
         },
     ):
         worker = _build_worker(worker_params)
@@ -974,6 +976,7 @@ def run_worker_manager(
         os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
         root_operation_name="DB-GPT-WorkerManager-Entry",
     )
+    _start_local_worker(worker_manager, worker_params)
     _start_local_embedding_worker(
         worker_manager, embedding_model_name, embedding_model_path
     )
@@ -985,11 +988,13 @@
     if not embedded_mod:
         import uvicorn
 
-        loop = asyncio.get_event_loop()
-        loop.run_until_complete(worker_manager.start())
         uvicorn.run(
             app, host=worker_params.host, port=worker_params.port, log_level="info"
         )
+    else:
+        # Embedded mode: start the worker manager in the caller's event loop
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(worker_manager.start())
 
 
 if __name__ == "__main__":
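Since every run service now records `sys_infos` in its root span, the environment of a run can be read back from the JSON-lines trace files written under LOGDIR. A minimal sketch, assuming the default file layout configured above:

    import json

    # Path assumed: <LOGDIR>/dbgpt_model_worker_manager_tracer.jsonl, one span per line.
    with open("logs/dbgpt_model_worker_manager_tracer.jsonl") as f:
        for line in f:
            span = json.loads(line)
            sys_infos = (span.get("metadata") or {}).get("sys_infos")
            if sys_infos:
                print(sys_infos["platform"], sys_infos["device"], sys_infos["memory"])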
+ """ + return False + def model_type(self) -> str: return ModelType.HF @@ -169,6 +175,9 @@ class OldLLMModelAdaperWrapper(LLMModelAdaper): self._adapter = adapter self._chat_adapter = chat_adapter + def use_fast_tokenizer(self) -> bool: + return self._adapter.use_fast_tokenizer() + def model_type(self) -> str: return self._adapter.model_type() @@ -200,6 +209,9 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper): def __init__(self, adapter: "BaseModelAdapter") -> None: self._adapter = adapter + def use_fast_tokenizer(self) -> bool: + return self._adapter.use_fast_tokenizer + def load(self, model_path: str, from_pretrained_kwargs: dict): return self._adapter.load_model(model_path, from_pretrained_kwargs) diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index 271c6bd38..0ad048c24 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -282,22 +282,30 @@ class ProxyModelParameters(BaseModelParameters): "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions" }, ) + proxy_api_key: str = field( metadata={"tags": "privacy", "help": "The api key of current proxy LLM"}, ) - - proxy_app_id: Optional[str] = field( + + proxy_api_base: str = field( default=None, metadata={ - "help": "Appid for visitor proxy" + "help": "The base api address, such as: https://api.openai.com/v1. If None, we will use proxy_api_base first" }, ) - proxy_api_secret: Optional[str] = field( + proxy_api_type: Optional[str] = field( default=None, - metadata={"tags": "privacy", "help": "The api secret of current proxy LLM"}, + metadata={ + "help": "The api type of current proxy the current proxy model, if you use Azure, it can be: azure" + }, ) - + + proxy_api_version: Optional[str] = field( + default=None, + metadata={"help": "The api version of current proxy the current model"}, + ) + http_proxy: Optional[str] = field( default=os.environ.get("http_proxy") or os.environ.get("https_proxy"), metadata={"help": "The http or https proxy to use openai"}, diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py index 70dea67f0..a2aff0b86 100644 --- a/pilot/model/proxy/llms/chatgpt.py +++ b/pilot/model/proxy/llms/chatgpt.py @@ -1,31 +1,63 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import json -import requests +import os from typing import List -from pilot.scene.base_message import ModelMessage, ModelMessageRoleType +import logging + +import openai + from pilot.model.proxy.llms.proxy_model import ProxyModel +from pilot.model.parameter import ProxyModelParameters +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType + +logger = logging.getLogger(__name__) -def chatgpt_generate_stream( - model: ProxyModel, tokenizer, params, device, context_len=2048 -): +def _initialize_openai(params: ProxyModelParameters): + api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai") + + api_base = params.proxy_api_base or os.getenv( + "OPENAI_API_TYPE", + os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None, + ) + api_key = params.proxy_api_key or os.getenv( + "OPENAI_API_KEY", + os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None, + ) + api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION") + + if not api_base and params.proxy_server_url: + # Adapt previous proxy_server_url configuration + api_base = params.proxy_server_url.split("/chat/completions")[0] + if api_type: + openai.api_type = api_type + if api_base: + openai.api_base = api_base + if api_key: + openai.api_key = api_key + if 
diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
index 70dea67f0..a2aff0b86 100644
--- a/pilot/model/proxy/llms/chatgpt.py
+++ b/pilot/model/proxy/llms/chatgpt.py
@@ -1,31 +1,63 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import json
-import requests
+import os
 from typing import List
-from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+import logging
+
+import openai
+
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.model.parameter import ProxyModelParameters
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+logger = logging.getLogger(__name__)
 
 
-def chatgpt_generate_stream(
-    model: ProxyModel, tokenizer, params, device, context_len=2048
-):
+def _initialize_openai(params: ProxyModelParameters):
+    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
+
+    api_base = params.proxy_api_base or os.getenv(
+        "OPENAI_API_BASE",
+        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
+    )
+    api_key = params.proxy_api_key or os.getenv(
+        "OPENAI_API_KEY",
+        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
+    )
+    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
+
+    if not api_base and params.proxy_server_url:
+        # Adapt previous proxy_server_url configuration
+        api_base = params.proxy_server_url.split("/chat/completions")[0]
+
+    if api_type:
+        openai.api_type = api_type
+    if api_base:
+        openai.api_base = api_base
+    if api_key:
+        openai.api_key = api_key
+    if api_version:
+        openai.api_version = api_version
+    if params.http_proxy:
+        openai.proxy = params.http_proxy
+
+    openai_params = {
+        "api_type": api_type,
+        "api_base": api_base,
+        "api_version": api_version,
+        "proxy": params.http_proxy,
+    }
+
+    return openai_params
+
+
+def _build_request(model: ProxyModel, params):
     history = []
 
     model_params = model.get_params()
-    print(f"Model: {model}, model_params: {model_params}")
+    logger.info(f"Model: {model}, model_params: {model_params}")
 
-    proxy_api_key = model_params.proxy_api_key
-    proxy_server_url = model_params.proxy_server_url
-    proxyllm_backend = model_params.proxyllm_backend
-    if not proxyllm_backend:
-        proxyllm_backend = "gpt-3.5-turbo"
-
-    headers = {
-        "Authorization": "Bearer " + proxy_api_key,
-        "Token": proxy_api_key,
-    }
+    openai_params = _initialize_openai(model_params)
 
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
@@ -51,29 +83,51 @@
     history.append(last_user_input)
 
     payloads = {
-        "model": proxyllm_backend,  # just for test, remove this later
-        "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
         "stream": True,
     }
+    proxyllm_backend = model_params.proxyllm_backend
 
-    res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True)
+    if openai_params["api_type"] == "azure":
+        # For Azure, `engine` is the deployment name
+        proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
+        payloads["engine"] = proxyllm_backend
+    else:
+        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
+        payloads["model"] = proxyllm_backend
 
-    print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}")
+    logger.info(
+        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
+    )
+    return history, payloads
+
+
+def chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = openai.ChatCompletion.create(messages=history, **payloads)
 
     text = ""
-    for line in res.iter_lines():
-        if line:
-            if not line.startswith(b"data: "):
-                error_message = line.decode("utf-8")
-                yield error_message
-            else:
-                json_data = line.split(b": ", 1)[1]
-                decoded_line = json_data.decode("utf-8")
-                if decoded_line.lower() != "[DONE]".lower():
-                    obj = json.loads(json_data)
-                    if obj["choices"][0]["delta"].get("content") is not None:
-                        content = obj["choices"][0]["delta"]["content"]
-                        text += content
-                        yield text
+    for r in res:
+        if r["choices"][0]["delta"].get("content") is not None:
+            content = r["choices"][0]["delta"]["content"]
+            text += content
+            yield text
+
+
+async def async_chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = await openai.ChatCompletion.acreate(messages=history, **payloads)
+
+    text = ""
+    async for r in res:
+        if r["choices"][0]["delta"].get("content") is not None:
+            content = r["choices"][0]["delta"]["content"]
+            text += content
+            yield text
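Both generators yield the accumulated text rather than per-chunk deltas. A rough usage sketch; the `ProxyModel` constructor signature and `ModelMessage` fields are assumptions not shown in this patch:

    from pilot.model.proxy.llms.chatgpt import chatgpt_generate_stream
    from pilot.model.proxy.llms.proxy_model import ProxyModel
    from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

    # `params` is the ProxyModelParameters sketch shown earlier.
    proxy_model = ProxyModel(params)  # constructor signature assumed
    request_params = {
        "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        "temperature": 0.7,
        "max_new_tokens": 256,
    }
    for partial_text in chatgpt_generate_stream(proxy_model, None, request_params, device="cpu"):
        print(partial_text)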
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index ed170c50e..591ac39af 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -5,7 +5,6 @@ from typing import List
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
-import signal
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR
 from pilot.component import SystemApp
@@ -41,6 +40,7 @@ from pilot.utils.utils import (
 )
 from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
 from pilot.utils.parameter_utils import _get_dict_from_obj
+from pilot.utils.system_utils import get_system_info
 
 from pilot.base_modules.agent.controller import router as agent_route
 
@@ -84,19 +84,6 @@ app.include_router(knowledge_router, tags=["Knowledge"])
 app.include_router(prompt_router, tags=["Prompt"])
 
 
-@app.get("/openapi.json")
-async def get_openapi_endpoint():
-    return get_openapi(
-        title="Your API title",
-        version="1.0.0",
-        description="Your API description",
-        routes=app.routes,
-    )
-
-@app.get("/docs")
-async def get_docs():
-    return get_swagger_ui_html(openapi_url="/openapi.json", title="API docs")
-
 def mount_static_files(app):
     os.makedirs(static_message_img_path, exist_ok=True)
     app.mount(
@@ -112,17 +99,20 @@ def mount_static_files(app):
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
 
 
+def _get_webserver_params(args: List[str] = None):
+    from pilot.utils.parameter_utils import EnvArgumentParser
+
+    parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
+        WebWerverParameters
+    )
+    return WebWerverParameters(**vars(parser.parse_args(args=args)))
+
+
 def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
     """Initialize app
     If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
     """
     if not param:
-        from pilot.utils.parameter_utils import EnvArgumentParser
-
-        parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
-            WebWerverParameters
-        )
-        param = WebWerverParameters(**vars(parser.parse_args(args=args)))
+        param = _get_webserver_params(args)
 
     if not param.log_level:
         param.log_level = _get_logging_level()
@@ -141,7 +131,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
     model_start_listener = _create_model_start_listener(system_app)
     initialize_components(param, system_app, embedding_model_name, embedding_model_path)
 
-    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
+    model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
     if not param.light:
         print("Model Unified Deployment Mode!")
         if not param.remote_embedding:
@@ -188,8 +178,21 @@ def run_uvicorn(param: WebWerverParameters):
 
 
 def run_webserver(param: WebWerverParameters = None):
-    param = initialize_app(param)
-    run_uvicorn(param)
+    if not param:
+        param = _get_webserver_params()
+    initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
+
+    with root_tracer.start_span(
+        "run_webserver",
+        span_type=SpanType.RUN,
+        metadata={
+            "run_service": SpanTypeRunName.WEBSERVER,
+            "params": _get_dict_from_obj(param),
+            "sys_infos": _get_dict_from_obj(get_system_info()),
+        },
+    ):
+        param = initialize_app(param)
+        run_uvicorn(param)
 
 
 if __name__ == "__main__":
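The `initialize_app` docstring above notes that it can be invoked from gunicorn's `on_starting` hook. A sketch of that setup; the config file name is illustrative:

    # gunicorn_conf.py (illustrative)
    from pilot.server.dbgpt_server import initialize_app

    def on_starting(server):
        # Initialize DB-GPT components once in the gunicorn master before workers start.
        initialize_app()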
diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py
index 6ba28afe7..76c72ec1b 100644
--- a/pilot/summary/db_summary_client.py
+++ b/pilot/summary/db_summary_client.py
@@ -22,8 +22,10 @@ chat_factory = ChatFactory()
 
 
 class DBSummaryClient:
-    """db summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
+    """DB Summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
     , get_similar_tables method(get user query related tables info)
+    Args:
+        system_app (SystemApp): Main System Application class that manages the lifecycle and registration of components.
     """
 
     def __init__(self, system_app: SystemApp):
@@ -160,6 +162,12 @@
         )
 
     def init_db_profile(self, db_summary_client, dbname, embeddings):
+        """db profile initialization
+        Args:
+            db_summary_client(DBSummaryClient): DB Summary Client
+            dbname(str): dbname
+            embeddings(SourceEmbedding): embedding used to read the string document
+        """
         from pilot.embedding_engine.string_embedding import StringEmbedding
 
         vector_store_name = dbname + "_profile"
@@ -176,9 +184,15 @@
         docs = []
         docs.extend(embedding.read_batch())
         for table_summary in db_summary_client.table_info_json():
+            from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=len(table_summary), chunk_overlap=100
+            )
             embedding = StringEmbedding(
-                table_summary,
-                profile_store_config,
+                file_path=table_summary,
+                vector_store_config=profile_store_config,
+                text_splitter=text_splitter,
             )
             docs.extend(embedding.read_batch())
         embedding.index_to_store(docs)
diff --git a/pilot/utils/system_utils.py b/pilot/utils/system_utils.py
new file mode 100644
index 000000000..50b54be0d
--- /dev/null
+++ b/pilot/utils/system_utils.py
@@ -0,0 +1,272 @@
+from dataclasses import dataclass, asdict
+from enum import Enum
+from typing import Tuple, Dict
+import os
+import platform
+import subprocess
+import re
+from functools import cache
+
+
+@dataclass
+class SystemInfo:
+    platform: str
+    distribution: str
+    python_version: str
+    cpu: str
+    cpu_avx: str
+    memory: str
+    torch_version: str
+    device: str
+    device_version: str
+    device_count: int
+    device_other: str
+
+    def to_dict(self) -> Dict:
+        return asdict(self)
+
+
+class AVXType(Enum):
+    BASIC = "basic"
+    AVX = "AVX"
+    AVX2 = "AVX2"
+    AVX512 = "AVX512"
+
+    @staticmethod
+    def of_type(avx: str):
+        for item in AVXType:
+            if item.value == avx:
+                return item
+        return None
+
+
+class OSType(str, Enum):
+    WINDOWS = "win"
+    LINUX = "linux"
+    DARWIN = "darwin"
+    OTHER = "other"
+
+
+def get_cpu_avx_support() -> Tuple[OSType, AVXType, str]:
+    system = platform.system()
+    os_type = OSType.OTHER
+    cpu_avx = AVXType.BASIC
+    env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX"))
+    distribution = "Unknown Distribution"
+    if "windows" in system.lower():
+        os_type = OSType.WINDOWS
+        output = "avx2"
+        distribution = "Windows " + platform.release()
+        print("Current platform is windows, use avx2 as default cpu architecture")
+    elif system == "Linux":
+        os_type = OSType.LINUX
+        result = subprocess.run(
+            ["lscpu"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode()
+        distribution = get_linux_distribution()
+    elif system == "Darwin":
+        os_type = OSType.DARWIN
+        result = subprocess.run(
+            ["sysctl", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        distribution = "Mac OS " + platform.mac_ver()[0]
+        output = result.stdout.decode()
+    else:
+        os_type = OSType.OTHER
+        print("Unsupported OS to get cpu avx, use default")
+        return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
+
+    if "avx512" in output.lower():
+        cpu_avx = AVXType.AVX512
+    elif "avx2" in output.lower():
+        cpu_avx = AVXType.AVX2
+    elif "avx " in output.lower():
+        # cpu_avx = AVXType.AVX
+        pass
+    return os_type, env_cpu_avx if env_cpu_avx else cpu_avx, distribution
+
+
+def get_device() -> str:
+    try:
+        import torch
+
+        return (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
+    except ModuleNotFoundError:
+        return "cpu"
+
+
+def get_device_info() -> Tuple[str, str, str, int, str]:
+    torch_version, device, device_version, device_count, device_other = (
+        None,
+        "cpu",
+        None,
+        0,
+        "",
+    )
+    try:
+        import torch
+
+        torch_version = torch.__version__
+        if torch.cuda.is_available():
+            device = "cuda"
+            device_version = torch.version.cuda
+            device_count = torch.cuda.device_count()
+        elif torch.backends.mps.is_available():
+            device = "mps"
+    except ModuleNotFoundError:
+        pass
+
+    if not device_version:
+        device_version = (
+            get_cuda_version_from_nvcc() or get_cuda_version_from_nvidia_smi()
+        )
+    if device == "cuda":
+        try:
+            output = subprocess.check_output(
+                [
+                    "nvidia-smi",
+                    "--query-gpu=name,driver_version,memory.total,memory.free,memory.used",
+                    "--format=csv",
+                ]
+            )
+            device_other = output.decode("utf-8")
+        except Exception:
+            pass
+    return torch_version, device, device_version, device_count, device_other
+
+
+def get_cuda_version_from_nvcc():
+    try:
+        output = subprocess.check_output(["nvcc", "--version"])
+        version_line = [
+            line for line in output.decode("utf-8").split("\n") if "release" in line
+        ][0]
+        return version_line.split("release")[-1].strip().split(",")[0]
+    except Exception:
+        return None
+
+
+def get_cuda_version_from_nvidia_smi():
+    try:
+        output = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
+        match = re.search(r"CUDA Version:\s+(\d+\.\d+)", output)
+        if match:
+            return match.group(1)
+        else:
+            return None
+    except Exception:
+        return None
+
+
+def get_linux_distribution():
+    """Get distribution of Linux"""
+    if os.path.isfile("/etc/os-release"):
+        with open("/etc/os-release", "r") as f:
+            info = {}
+            for line in f:
+                key, _, value = line.partition("=")
+                info[key] = value.strip().strip('"')
+            return f"{info.get('NAME', 'Unknown')} {info.get('VERSION_ID', '')}".strip()
+    return "Unknown Linux Distribution"
+
+
+def get_cpu_info():
+    # Getting platform
+    os_type, avx_type, distribution = get_cpu_avx_support()
+
+    # Getting CPU information
+    cpu_info = "Unknown CPU"
+    if os_type == OSType.LINUX:
+        try:
+            output = subprocess.check_output(["lscpu"]).decode("utf-8")
+            match = re.search(r".*Model name:\s*(.+)", output)
+            if match:
+                cpu_info = match.group(1).strip()
+            # lscpu output when the system locale is Chinese
+            match = re.search(r".*型号名称:\s*(.+)", output)
+            if match:
+                cpu_info = match.group(1).strip()
+        except Exception:
+            pass
+    elif os_type == OSType.DARWIN:
+        try:
+            output = subprocess.check_output(
+                ["sysctl", "machdep.cpu.brand_string"]
+            ).decode("utf-8")
+            match = re.search(r"machdep.cpu.brand_string:\s*(.+)", output)
+            if match:
+                cpu_info = match.group(1).strip()
+        except Exception:
+            pass
+    elif os_type == OSType.WINDOWS:
+        try:
+            output = subprocess.check_output("wmic cpu get Name", shell=True).decode(
+                "utf-8"
+            )
+            lines = output.splitlines()
+            cpu_info = lines[2].split(":")[-1].strip()
+        except Exception:
+            pass
+
+    return os_type, avx_type, cpu_info, distribution
+
+
+def get_memory_info(os_type: OSType) -> str:
+    memory = "Unknown Memory"
+    try:
+        import psutil
+
+        memory = f"{psutil.virtual_memory().total // (1024 ** 3)} GB"
+    except ImportError:
+        pass
+    if os_type == OSType.LINUX:
+        try:
+            with open("/proc/meminfo", "r") as f:
+                mem_info = f.readlines()
+            for line in mem_info:
+                if "MemTotal" in line:
+                    memory = line.split(":")[1].strip()
+                    break
+        except Exception:
+            pass
+    return memory
+
+
+@cache
+def get_system_info() -> SystemInfo:
+    """Get System information"""
+
+    os_type, avx_type, cpu_info, distribution = get_cpu_info()
+
+    # Getting Python version
+    python_version = platform.python_version()
+
+    memory = get_memory_info(os_type)
+
+    (
+        torch_version,
+        device,
+        device_version,
+        device_count,
+        device_other,
+    ) = get_device_info()
+
+    return SystemInfo(
+        platform=os_type.value,
+        distribution=distribution,
+        python_version=python_version,
+        cpu=cpu_info,
+        cpu_avx=avx_type.value,
+        memory=memory,
+        torch_version=torch_version,
+        device=device,
+        device_version=device_version,
+        device_count=device_count,
+        device_other=device_other,
+    )
diff --git a/pilot/utils/tracer/span_storage.py b/pilot/utils/tracer/span_storage.py
index 8967f9ee5..914aa0126 100644
--- a/pilot/utils/tracer/span_storage.py
+++ b/pilot/utils/tracer/span_storage.py
@@ -34,7 +34,9 @@ class FileSpanStorage(SpanStorage):
         self.flush_signal_queue = queue.Queue()
 
         if not os.path.exists(filename):
-            with open(filename, "w") as _:
+            # Create a new file if it does not exist
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            with open(filename, "a"):
                 pass
         self.flush_thread = threading.Thread(target=self._flush_to_file, daemon=True)
         self.flush_thread.start()
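With the `os.makedirs` call, a span file in a directory that does not exist yet is created on first use. A small sketch, assuming the storage's batching arguments keep their defaults:

    import os
    from pilot.utils.tracer import FileSpanStorage

    # Parent directories such as logs/traces/ are now created automatically.
    storage = FileSpanStorage(os.path.join("logs", "traces", "dbgpt_webserver_tracer.jsonl"))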
diff --git a/pilot/utils/tracer/tests/test_base.py b/pilot/utils/tracer/tests/test_base.py
index 4a8061cb1..445120e8a 100644
--- a/pilot/utils/tracer/tests/test_base.py
+++ b/pilot/utils/tracer/tests/test_base.py
@@ -1,7 +1,7 @@
 from typing import Dict
 
 from pilot.component import SystemApp
-from pilot.utils.tracer import Span, SpanStorage, Tracer
+from pilot.utils.tracer import Span, SpanType, SpanStorage, Tracer
 
 
 # Mock implementations
@@ -31,7 +31,9 @@ class MockTracer(Tracer):
             self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
         )
         span_id = f"{trace_id}:{self._new_uuid()}"
-        span = Span(trace_id, span_id, parent_span_id, operation_name, metadata)
+        span = Span(
+            trace_id, span_id, SpanType.BASE, parent_span_id, operation_name, metadata
+        )
         self.current_span = span
         return span
 
@@ -50,7 +52,14 @@
 
 
 def test_span_creation():
-    span = Span("trace_id", "span_id", "parent_span_id", "operation", {"key": "value"})
+    span = Span(
+        "trace_id",
+        "span_id",
+        SpanType.BASE,
+        "parent_span_id",
+        "operation",
+        {"key": "value"},
+    )
     assert span.trace_id == "trace_id"
     assert span.span_id == "span_id"
     assert span.parent_span_id == "parent_span_id"
diff --git a/pilot/utils/tracer/tests/test_span_storage.py b/pilot/utils/tracer/tests/test_span_storage.py
index 5ad518b15..0c63992a6 100644
--- a/pilot/utils/tracer/tests/test_span_storage.py
+++ b/pilot/utils/tracer/tests/test_span_storage.py
@@ -5,7 +5,7 @@ import json
 import tempfile
 import time
 
-from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span
+from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span, SpanType
 
 
 @pytest.fixture
@@ -44,7 +44,7 @@ def read_spans_from_file(filename):
     "storage", [{"batch_size": 1, "flush_interval": 5}], indirect=True
 )
 def test_write_span(storage: SpanStorage):
-    span = Span("1", "a", "b", "op1")
+    span = Span("1", "a", SpanType.BASE, "b", "op1")
     storage.append_span(span)
     time.sleep(0.1)
 
@@ -57,8 +57,8 @@
     "storage", [{"batch_size": 1, "flush_interval": 5}], indirect=True
 )
 def test_incremental_write(storage: SpanStorage):
-    span1 = Span("1", "a", "b", "op1")
-    span2 = Span("2", "c", "d", "op2")
+    span1 = Span("1", "a", SpanType.BASE, "b", "op1")
+    span2 = Span("2", "c", SpanType.BASE, "d", "op2")
 
     storage.append_span(span1)
     storage.append_span(span2)
@@ -72,7 +72,7 @@
     "storage", [{"batch_size": 2, "flush_interval": 5}], indirect=True
 )
 def test_sync_and_async_append(storage: SpanStorage):
-    span = Span("1", "a", "b", "op1")
+    span = Span("1", "a", SpanType.BASE, "b", "op1")
 
     storage.append_span(span)
 
@@ -88,7 +88,7 @@
 
 @pytest.mark.asyncio
 async def test_flush_policy(storage: SpanStorage):
-    span = Span("1", "a", "b", "op1")
+    span = Span("1", "a", SpanType.BASE, "b", "op1")
     for _ in range(storage.batch_size - 1):
         storage.append_span(span)
 
@@ -108,8 +108,8 @@
     "storage", [{"batch_size": 2, "file_does_not_exist": True}], indirect=True
 )
 def test_non_existent_file(storage: SpanStorage):
-    span = Span("1", "a", "b", "op1")
-    span2 = Span("2", "c", "d", "op2")
+    span = Span("1", "a", SpanType.BASE, "b", "op1")
+    span2 = Span("2", "c", SpanType.BASE, "d", "op2")
     storage.append_span(span)
     time.sleep(0.1)
 
diff --git a/pilot/utils/tracer/tracer_cli.py b/pilot/utils/tracer/tracer_cli.py
index 822b039ee..7df18f516 100644
--- a/pilot/utils/tracer/tracer_cli.py
+++ b/pilot/utils/tracer/tracer_cli.py
@@ -259,6 +259,7 @@ def chat(
     found_trace_id = trace_id
 
     service_tables = {}
+    system_infos_table = {}
     out_kwargs = {"ensure_ascii": False} if output == "json" else {}
     for service_name, sp in service_spans.items():
         metadata = sp["metadata"]
@@ -266,6 +267,15 @@
         for k, v in metadata["params"].items():
             table.add_row([k, v])
         service_tables[service_name] = table
+        sys_infos = metadata.get("sys_infos")
+        if sys_infos and isinstance(sys_infos, dict):
+            sys_table = PrettyTable(
+                ["System Config Key", "System Config Value"],
+                title=f"{service_name} System information",
+            )
+            for k, v in sys_infos.items():
+                sys_table.add_row([k, v])
+            system_infos_table[service_name] = sys_table
 
     if not hide_run_params:
         merged_table1 = merge_tables_horizontally(
@@ -276,16 +286,23 @@
         )
         merged_table2 = merge_tables_horizontally(
             [
-                service_tables.get(SpanTypeRunName.MODEL_WORKER),
-                service_tables.get(SpanTypeRunName.WORKER_MANAGER),
+                service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
+                service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
             ]
         )
+        sys_table = system_infos_table.get(SpanTypeRunName.WORKER_MANAGER.value)
+        if system_infos_table:
+            for k, v in system_infos_table.items():
+                sys_table = v
+                break
         if output == "text":
             print(merged_table1)
             print(merged_table2)
         else:
             for service_name, table in service_tables.items():
                 print(table.get_formatted_string(out_format=output, **out_kwargs))
+        if sys_table:
+            print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
 
     if hide_conv:
         return