mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
feat(model): Add more support model list and modify LLM benchmarks (#823)
This commit is contained in:
parent
e4c91d161f
commit
89dcccc642
16
README.md
16
README.md
@ -115,16 +115,26 @@ At present, we have introduced several key features to showcase our current capa
|
||||
- **SMMF(Service-oriented Multi-model Management Framework)**
|
||||
|
||||
We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
|
||||
|
||||
- [Vicuna](https://huggingface.co/Tribbiani/vicuna-13b)
|
||||
- [vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5)
|
||||
- [LLama2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
- [baichuan2-13b](https://huggingface.co/baichuan-inc)
|
||||
- [baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
||||
- [baichuan2-13b](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat)
|
||||
- [baichuan2-7b](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
|
||||
- [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
||||
- [chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
||||
- [chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b)
|
||||
- [falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
|
||||
- [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
|
||||
- [Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
|
||||
- [internlm-chat-20b](https://huggingface.co/internlm/internlm-chat-20b)
|
||||
- [qwen-7b-chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
|
||||
- [qwen-14b-chat](https://huggingface.co/Qwen/Qwen-14B-Chat)
|
||||
- [wizardlm-13b](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)
|
||||
- [orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b)
|
||||
- [orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b)
|
||||
- [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
|
||||
- [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
|
||||
- [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||
|
||||
- Support API Proxy LLMs
|
||||
- [x] [ChatGPT](https://api.openai.com/)
|
||||
|
20
README.zh.md
20
README.zh.md
@ -112,15 +112,27 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
|
||||
|
||||
- **多模型支持与管理**
|
||||
|
||||
海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。
|
||||
- 支持多种大语言模型, 当前已支持如下模型:
|
||||
海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。当前已支持如下模型:
|
||||
|
||||
- [Vicuna](https://huggingface.co/Tribbiani/vicuna-13b)
|
||||
- [vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5)
|
||||
- [LLama2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
- [baichuan2-13b](https://huggingface.co/baichuan-inc)
|
||||
- [baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
||||
- [baichuan2-13b](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat)
|
||||
- [baichuan2-7b](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
|
||||
- [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
||||
- [chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
||||
- [chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b)
|
||||
- [falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
|
||||
- [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
|
||||
- [internlm-chat-20b](https://huggingface.co/internlm/internlm-chat-20b)
|
||||
- [qwen-7b-chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
|
||||
- [qwen-14b-chat](https://huggingface.co/Qwen/Qwen-14B-Chat)
|
||||
- [wizardlm-13b](https://huggingface.co/WizardLM/WizardLM-13B-V1.2)
|
||||
- [orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b)
|
||||
- [orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b)
|
||||
- [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
|
||||
- [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
|
||||
- [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||
|
||||
- 支持在线代理模型
|
||||
- [x] [ChatGPT](https://api.openai.com/)
|
||||
|
@ -53,6 +53,8 @@ LLM_MODEL_CONFIG = {
|
||||
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
||||
"chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
|
||||
"chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
|
||||
# https://huggingface.co/THUDM/chatglm3-6b
|
||||
"chatglm3-6b": os.path.join(MODEL_PATH, "chatglm3-6b"),
|
||||
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
||||
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
|
||||
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
|
||||
@ -74,6 +76,18 @@ LLM_MODEL_CONFIG = {
|
||||
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
|
||||
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
|
||||
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-7B-Chat
|
||||
"qwen-7b-chat": os.path.join(MODEL_PATH, "Qwen-7B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-7B-Chat-Int8
|
||||
"qwen-7b-chat-int8": os.path.join(MODEL_PATH, "Qwen-7B-Chat-Int8"),
|
||||
# https://huggingface.co/Qwen/Qwen-7B-Chat-Int4
|
||||
"qwen-7b-chat-int4": os.path.join(MODEL_PATH, "Qwen-7B-Chat-Int4"),
|
||||
# https://huggingface.co/Qwen/Qwen-14B-Chat
|
||||
"qwen-14b-chat": os.path.join(MODEL_PATH, "Qwen-14B-Chat"),
|
||||
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int8
|
||||
"qwen-14b-chat-int8": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int8"),
|
||||
# https://huggingface.co/Qwen/Qwen-14B-Chat-Int4
|
||||
"qwen-14b-chat-int4": os.path.join(MODEL_PATH, "Qwen-14B-Chat-Int4"),
|
||||
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
||||
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
||||
# wget https://huggingface.co/TheBloke/vicuna-13B-v1.5-GGUF/resolve/main/vicuna-13b-v1.5.Q4_K_M.gguf -O models/ggml-model-q4_0.gguf
|
||||
@ -88,6 +102,30 @@ LLM_MODEL_CONFIG = {
|
||||
"codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
|
||||
# For test now
|
||||
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
|
||||
# https://huggingface.co/microsoft/Orca-2-7b
|
||||
"orca-2-7b": os.path.join(MODEL_PATH, "Orca-2-7b"),
|
||||
# https://huggingface.co/microsoft/Orca-2-13b
|
||||
"orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
|
||||
# https://huggingface.co/openchat/openchat_3.5
|
||||
"openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
|
||||
# https://huggingface.co/hfl/chinese-alpaca-2-7b
|
||||
"chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
|
||||
# https://huggingface.co/hfl/chinese-alpaca-2-13b
|
||||
"chinese-alpaca-2-13b": os.path.join(MODEL_PATH, "chinese-alpaca-2-13b"),
|
||||
# https://huggingface.co/THUDM/codegeex2-6b
|
||||
"codegeex2-6b": os.path.join(MODEL_PATH, "codegeex2-6b"),
|
||||
# https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
|
||||
"zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
|
||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
"mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
|
||||
# https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
|
||||
"mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
|
||||
# https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
|
||||
"xwin-lm-7b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-7B-V0.1"),
|
||||
# https://huggingface.co/Xwin-LM/Xwin-LM-13B-V0.1
|
||||
"xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
|
||||
# https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
|
||||
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
|
||||
}
|
||||
|
||||
EMBEDDING_MODEL_CONFIG = {
|
||||
|
@ -8,6 +8,7 @@ from dataclasses import dataclass, asdict
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pilot.utils.parameter_utils import ParameterDescription
|
||||
from pilot.utils.model_utils import GPUInfo
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
@ -53,6 +54,8 @@ class WorkerApplyType(str, Enum):
|
||||
class ModelInferenceMetrics:
|
||||
"""A class to represent metrics for assessing the inference performance of a LLM."""
|
||||
|
||||
collect_index: Optional[int] = 0
|
||||
|
||||
start_time_ms: Optional[int] = None
|
||||
"""The timestamp (in milliseconds) when the model inference starts."""
|
||||
|
||||
@ -83,6 +86,12 @@ class ModelInferenceMetrics:
|
||||
speed_per_second: Optional[float] = None
|
||||
"""The average number of tokens generated per second."""
|
||||
|
||||
current_gpu_infos: Optional[List[GPUInfo]] = None
|
||||
"""Current gpu information, all devices"""
|
||||
|
||||
avg_gpu_infos: Optional[List[GPUInfo]] = None
|
||||
"""Average memory usage across all collection points"""
|
||||
|
||||
@staticmethod
|
||||
def create_metrics(
|
||||
last_metrics: Optional["ModelInferenceMetrics"] = None,
|
||||
@ -99,6 +108,8 @@ class ModelInferenceMetrics:
|
||||
completion_tokens = last_metrics.completion_tokens if last_metrics else None
|
||||
total_tokens = last_metrics.total_tokens if last_metrics else None
|
||||
speed_per_second = last_metrics.speed_per_second if last_metrics else None
|
||||
current_gpu_infos = last_metrics.current_gpu_infos if last_metrics else None
|
||||
avg_gpu_infos = last_metrics.avg_gpu_infos if last_metrics else None
|
||||
|
||||
if not start_time_ms:
|
||||
start_time_ms = time.time_ns() // 1_000_000
|
||||
@ -116,6 +127,8 @@ class ModelInferenceMetrics:
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
speed_per_second=speed_per_second,
|
||||
current_gpu_infos=current_gpu_infos,
|
||||
avg_gpu_infos=avg_gpu_infos,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
|
@ -11,7 +11,7 @@ from pilot.model.base import ModelOutput, ModelInferenceMetrics
|
||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.utils.model_utils import _clear_model_cache
|
||||
from pilot.utils.model_utils import _clear_model_cache, _get_current_cuda_memory
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
|
||||
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
|
||||
from pilot.utils.system_utils import get_system_info
|
||||
@ -363,6 +363,7 @@ def _new_metrics_from_model_output(
|
||||
usage: Optional[Dict] = None,
|
||||
) -> ModelInferenceMetrics:
|
||||
metrics = ModelInferenceMetrics.create_metrics(last_metric)
|
||||
metrics.collect_index = last_metric.collect_index + 1
|
||||
if is_first_generate:
|
||||
logger.info(f"is_first_generate, usage: {usage}")
|
||||
metrics.first_completion_time_ms = time.time_ns() // 1_000_000
|
||||
@ -385,6 +386,13 @@ def _new_metrics_from_model_output(
|
||||
metrics.first_completion_tokens = completion_tokens
|
||||
if completion_tokens == 1:
|
||||
metrics.first_token_time_ms = metrics.first_completion_time_ms
|
||||
if (
|
||||
not is_first_generate
|
||||
and metrics.first_token_time_ms is None
|
||||
and completion_tokens == 1
|
||||
):
|
||||
# Case: first generate has 0 token, and second generate has 1 token
|
||||
metrics.first_token_time_ms = time.time_ns() // 1_000_000
|
||||
|
||||
if prompt_tokens:
|
||||
metrics.prompt_tokens = prompt_tokens
|
||||
@ -400,4 +408,28 @@ def _new_metrics_from_model_output(
|
||||
# time cost(seconds)
|
||||
duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
|
||||
metrics.speed_per_second = total_tokens / duration
|
||||
|
||||
current_gpu_infos = _get_current_cuda_memory()
|
||||
metrics.current_gpu_infos = current_gpu_infos
|
||||
if not metrics.avg_gpu_infos:
|
||||
metrics.avg_gpu_infos = current_gpu_infos
|
||||
elif current_gpu_infos:
|
||||
for i, last_avg in enumerate(metrics.avg_gpu_infos):
|
||||
allocated_memory_gb = (
|
||||
last_avg.allocated_memory_gb * (metrics.collect_index - 1)
|
||||
+ current_gpu_infos[i].allocated_memory_gb
|
||||
)
|
||||
metrics.avg_gpu_infos[i].allocated_memory_gb = (
|
||||
allocated_memory_gb / metrics.collect_index
|
||||
)
|
||||
metrics.avg_gpu_infos[i].total_memory_gb = current_gpu_infos[
|
||||
i
|
||||
].total_memory_gb
|
||||
metrics.avg_gpu_infos[i].cached_memory_gb = current_gpu_infos[
|
||||
i
|
||||
].cached_memory_gb
|
||||
metrics.avg_gpu_infos[i].available_memory_gb = current_gpu_infos[
|
||||
i
|
||||
].available_memory_gb
|
||||
|
||||
return metrics
|
||||
|
@ -42,7 +42,7 @@ def generate_stream(
|
||||
params: Dict,
|
||||
device: str,
|
||||
context_len: int,
|
||||
stream_interval: int = 2,
|
||||
stream_interval: int = 1,
|
||||
judge_sent_end: bool = False,
|
||||
):
|
||||
if hasattr(model, "device"):
|
||||
|
@ -8,6 +8,7 @@ import argparse
|
||||
import logging
|
||||
import traceback
|
||||
from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from pilot.model.cluster.worker.manager import (
|
||||
run_worker_manager,
|
||||
@ -53,13 +54,20 @@ prompt_file_map = {
|
||||
METRICS_HEADERS = [
|
||||
# Params
|
||||
"model_name",
|
||||
"gpu_nums",
|
||||
"parallel_nums",
|
||||
"input_length",
|
||||
"output_length",
|
||||
# Merge parallel result
|
||||
"test_time_cost_ms",
|
||||
"test_total_tokens",
|
||||
"test_speed_per_second", # (tokens / s)
|
||||
# avg_test_speed_per_second: (tokens / s), test_total_tokens / (test_time_cost_ms / 1000.0)
|
||||
"avg_test_speed_per_second(tokens/s)",
|
||||
# avg_first_token_latency_ms: sum(first_token_time_ms) / parallel_nums
|
||||
"avg_first_token_latency_ms",
|
||||
# avg_latency_ms: sum(end_time_ms - start_time_ms) / parallel_nums
|
||||
"avg_latency_ms",
|
||||
"gpu_mem(GiB)",
|
||||
# Detail for each task
|
||||
"start_time_ms",
|
||||
"end_time_ms",
|
||||
@ -106,7 +114,11 @@ def build_param(
|
||||
|
||||
|
||||
async def run_batch(
|
||||
wh, input_len: int, output_len: int, parallel_num: int, output_file: str
|
||||
wh: WorkerManager,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
parallel_num: int,
|
||||
output_file: str,
|
||||
):
|
||||
tasks = []
|
||||
prompt = read_prompt_from_file("11k")
|
||||
@ -117,6 +129,10 @@ async def run_batch(
|
||||
max_input_str_len *= 2
|
||||
prompt = prompt[-max_input_str_len:]
|
||||
|
||||
# Warmup first
|
||||
params = build_param(input_len, output_len, prompt, system_prompt="")
|
||||
await wh.generate(params)
|
||||
|
||||
for _ in range(parallel_num):
|
||||
params = build_param(input_len, output_len, prompt, system_prompt="")
|
||||
tasks.append(wh.generate(params))
|
||||
@ -129,6 +145,10 @@ async def run_batch(
|
||||
|
||||
test_time_cost_ms = end_time_ms - start_time_ms
|
||||
test_total_tokens = 0
|
||||
first_token_latency_ms = 0
|
||||
latency_ms = 0
|
||||
gpu_nums = 0
|
||||
avg_gpu_mem = 0
|
||||
rows = []
|
||||
for r in results:
|
||||
metrics = r.metrics
|
||||
@ -136,9 +156,22 @@ async def run_batch(
|
||||
metrics = ModelInferenceMetrics(**metrics)
|
||||
print(r)
|
||||
test_total_tokens += metrics.total_tokens
|
||||
first_token_latency_ms += metrics.first_token_time_ms - metrics.start_time_ms
|
||||
latency_ms += metrics.end_time_ms - metrics.start_time_ms
|
||||
row_data = metrics.to_dict()
|
||||
del row_data["collect_index"]
|
||||
if "avg_gpu_infos" in row_data:
|
||||
avg_gpu_infos = row_data["avg_gpu_infos"]
|
||||
gpu_nums = len(avg_gpu_infos)
|
||||
avg_gpu_mem = (
|
||||
sum(i["allocated_memory_gb"] for i in avg_gpu_infos) / gpu_nums
|
||||
)
|
||||
del row_data["avg_gpu_infos"]
|
||||
del row_data["current_gpu_infos"]
|
||||
rows.append(row_data)
|
||||
test_speed_per_second = test_total_tokens / (test_time_cost_ms / 1000.0)
|
||||
avg_test_speed_per_second = test_total_tokens / (test_time_cost_ms / 1000.0)
|
||||
avg_first_token_latency_ms = first_token_latency_ms / len(results)
|
||||
avg_latency_ms = latency_ms / len(results)
|
||||
|
||||
with open(output_file, "a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=METRICS_HEADERS)
|
||||
@ -152,7 +185,11 @@ async def run_batch(
|
||||
row["output_length"] = output_len
|
||||
row["test_time_cost_ms"] = test_time_cost_ms
|
||||
row["test_total_tokens"] = test_total_tokens
|
||||
row["test_speed_per_second"] = test_speed_per_second
|
||||
row["avg_test_speed_per_second(tokens/s)"] = avg_test_speed_per_second
|
||||
row["avg_first_token_latency_ms"] = avg_first_token_latency_ms
|
||||
row["avg_latency_ms"] = avg_latency_ms
|
||||
row["gpu_nums"] = gpu_nums
|
||||
row["gpu_mem(GiB)"] = avg_gpu_mem
|
||||
writer.writerow(row)
|
||||
print(
|
||||
f"input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"
|
||||
@ -164,7 +201,9 @@ async def run_model(wh: WorkerManager) -> None:
|
||||
if not result_csv_file:
|
||||
result_csv_file = get_result_csv_file()
|
||||
if os.path.exists(result_csv_file):
|
||||
os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
|
||||
now = datetime.now()
|
||||
now_str = now.strftime("%Y-%m-%d")
|
||||
os.rename(result_csv_file, f"{result_csv_file}.bak_{now_str}.csv")
|
||||
for parallel_num in parallel_nums:
|
||||
for input_len, output_len in zip(input_lens, output_lens):
|
||||
try:
|
||||
@ -176,6 +215,8 @@ async def run_model(wh: WorkerManager) -> None:
|
||||
logging.error(
|
||||
f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
|
||||
)
|
||||
if "torch.cuda.OutOfMemoryError" in msg:
|
||||
return
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -37,3 +39,46 @@ def _clear_torch_cache(device="cuda"):
|
||||
torch.cuda.ipc_collect()
|
||||
else:
|
||||
logger.info("No cuda or mps, not support clear torch cache yet")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUInfo:
|
||||
total_memory_gb: float
|
||||
allocated_memory_gb: float
|
||||
cached_memory_gb: float
|
||||
available_memory_gb: float
|
||||
|
||||
|
||||
def _get_current_cuda_memory() -> List[GPUInfo]:
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
logger.warn("Torch not installed")
|
||||
return []
|
||||
if torch.cuda.is_available():
|
||||
num_gpus = torch.cuda.device_count()
|
||||
gpu_infos = []
|
||||
for gpu_id in range(num_gpus):
|
||||
with torch.cuda.device(gpu_id):
|
||||
device = torch.cuda.current_device()
|
||||
gpu_properties = torch.cuda.get_device_properties(device)
|
||||
total_memory = round(gpu_properties.total_memory / (1.0 * 1024**3), 2)
|
||||
allocated_memory = round(
|
||||
torch.cuda.memory_allocated() / (1.0 * 1024**3), 2
|
||||
)
|
||||
cached_memory = round(
|
||||
torch.cuda.memory_reserved() / (1.0 * 1024**3), 2
|
||||
)
|
||||
available_memory = total_memory - allocated_memory
|
||||
gpu_infos.append(
|
||||
GPUInfo(
|
||||
total_memory_gb=total_memory,
|
||||
allocated_memory_gb=allocated_memory,
|
||||
cached_memory_gb=cached_memory,
|
||||
available_memory_gb=available_memory,
|
||||
)
|
||||
)
|
||||
return gpu_infos
|
||||
else:
|
||||
logger.warn("CUDA is not available.")
|
||||
return []
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
default_input_lens="64,64,64,512,1024,1024,2048"
|
||||
default_output_lens="256,512,1024,1024,1024,2048,2048"
|
||||
default_input_lens="8,8,256,1024"
|
||||
default_output_lens="256,512,1024,1024"
|
||||
default_parallel_nums="1,2,4,16,32"
|
||||
|
||||
input_lens=${1:-$default_input_lens}
|
||||
|
Loading…
Reference in New Issue
Block a user