Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 17:17:05 +00:00
[feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -5,6 +5,8 @@ from typing import Dict, List
 import torch.distributed as dist
 from colossal_eval import dataset, models, utils
+from colossal_eval.dataset.base import DistributedDataset
+from torch.utils.data import DataLoader, DistributedSampler
 
 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -13,6 +15,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig
 
 logger = get_dist_logger()
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 def rm_and_merge(
@@ -54,7 +57,8 @@ def rm_and_merge(
                     )
                 else:
                     rank_answers = utils.jload(directory)
-                    answers["data"].extend(rank_answers["data"])
+                    deduplidate_answers = [x for x in rank_answers["data"] if x not in answers["data"]]
+                    answers["data"].extend(deduplidate_answers)
                     answers["inference_kwargs"] = rank_answers["inference_kwargs"]
 
             for r in range(dp_size):
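Note on the change above: torch.utils.data.DistributedSampler (introduced further down in this diff) pads its index list so that every data-parallel rank draws the same number of samples, so the trailing entries of a dataset can appear in more than one rank's answer file. The new deduplidate_answers filter drops those repeats when the per-rank files are merged. A minimal sketch of the merge step, assuming per_rank_answers is a hypothetical list holding the utils.jload(directory) result from each DP rank:

    answers = {"data": []}
    for rank_answers in per_rank_answers:  # hypothetical: one loaded answer file per DP rank
        # Keep only entries that an earlier rank's file has not already contributed.
        deduplidate_answers = [x for x in rank_answers["data"] if x not in answers["data"]]
        answers["data"].extend(deduplidate_answers)
        answers["inference_kwargs"] = rank_answers["inference_kwargs"]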
@@ -65,7 +69,7 @@ def rm_and_merge(
                     os.remove(directory)
                 except Exception as e:
                     print(e)
-
+            print(len(answers["data"]))
             all_answers[category] = answers
 
     all_answers_with_dataset_class["inference_results"] = all_answers
@@ -108,7 +112,12 @@ def main(args):
     tp_rank = coordinates[TP_AXIS]
 
     shard_config = (
-        ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
+        ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            enable_tensor_parallelism=args.tp_size > 1,
+            parallel_output=False,
+            enable_all_optimization=True,
+        )
         if args.tp_size > 1
         else None
     )
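For context, tp_group and coordinates above come from the process-group mesh that main() sets up before this hunk. A hedged sketch of that setup, assuming the usual two-axis layout with colossalai.cluster.ProcessGroupMesh (the axis constants and the dp_size formula are assumptions, not shown in this diff):

    import torch.distributed as dist

    import colossalai
    from colossalai.cluster import ProcessGroupMesh

    DP_AXIS, TP_AXIS = 0, 1

    colossalai.launch_from_torch()  # assumes a torchrun-style launch; older releases require config={}
    dp_size = dist.get_world_size() // args.tp_size
    pg_mesh = ProcessGroupMesh(dp_size, args.tp_size)
    tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
    coordinates = pg_mesh.coordinate()
    dp_rank, tp_rank = coordinates[DP_AXIS], coordinates[TP_AXIS]

The two new ShardConfig flags matter for evaluation: parallel_output=False gathers the output logits across the tensor-parallel group instead of leaving them sharded along the vocabulary dimension, and enable_all_optimization turns on Shardformer optimizations such as fused normalization and flash attention where available.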
@@ -183,6 +192,7 @@ def main(args):
         model_name = model_parameter["name"]
         model_class = eval(f"models.{model_parameter['model_class']}")
         paramerters = model_parameter["parameters"]
+        batch_size = paramerters["batch_size"]
         paramerters.update({"logger": logger})
         paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
         paramerters.update({"shard_config": shard_config})
@@ -192,7 +202,6 @@ def main(args):
             raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
 
         for dataset_name, split_data in inference_data.items():
-            start = 0
            prev_questions = None
            for category, category_data in split_data.items():
                num_turn = category_data["inference_kwargs"].get("turns", 1)
@@ -201,26 +210,33 @@ def main(args):
                     raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
 
                 answers_to_dump = copy.deepcopy(category_data)
-                partition_size = len(category_data["data"]) // dp_size
-                redundant = len(category_data["data"]) % dp_size
-
-                # Ensure that the amount of data for inference is as consistent as possible across different processes.
-                lengths = [partition_size for _ in range(dp_size)]
-                for j in range(redundant):
-                    lengths[(j + start) % dp_size] += 1
-
-                start = (start + redundant) % dp_size
 
                 for turn in range(num_turn):
                     if turn == 0:
-                        questions = category_data["data"][
-                            sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
-                        ]
+                        dist_dataset = DistributedDataset(category_data["data"])
                     else:
-                        questions = prev_questions
+                        dist_dataset = DistributedDataset(prev_questions)
 
+                    sampler = DistributedSampler(
+                        dist_dataset,
+                        num_replicas=pg_mesh.size(DP_AXIS),
+                        rank=pg_mesh.coordinate(DP_AXIS),
+                        shuffle=False,
+                    )
+                    questions_loader = DataLoader(
+                        dist_dataset,
+                        batch_size=batch_size,
+                        sampler=sampler,
+                        num_workers=8,
+                        pin_memory=True,
+                        collate_fn=lambda x: x,
+                    )
                     category_data["inference_kwargs"]["dataset"] = dataset_name
                     category_data["inference_kwargs"]["category"] = category
 
                     answers_per_rank = model_.inference(
-                        questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+                        data_loader=questions_loader,
+                        inference_kwargs=category_data["inference_kwargs"],
+                        debug=debug_args[dataset_name],
                     )
                     prev_questions = answers_per_rank
 
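For reference, the deleted lengths/start bookkeeping split len(category_data["data"]) into dp_size near-equal slices and rotated the remainder across categories; DistributedSampler reaches the same even split by padding its index list up to a multiple of num_replicas. A self-contained sketch of that behavior (ListDataset is a hypothetical stand-in for colossal_eval's DistributedDataset; passing num_replicas and rank explicitly avoids needing an initialized process group):

    from torch.utils.data import DataLoader, Dataset, DistributedSampler

    class ListDataset(Dataset):  # hypothetical stand-in for DistributedDataset
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    data = [{"id": i} for i in range(10)]
    for rank in range(3):  # pretend dp_size == 3
        sampler = DistributedSampler(ListDataset(data), num_replicas=3, rank=rank, shuffle=False)
        loader = DataLoader(ListDataset(data), batch_size=2, sampler=sampler, collate_fn=lambda x: x)
        print(rank, [item["id"] for batch in loader for item in batch])
    # prints: 0 [0, 3, 6, 9], then 1 [1, 4, 7, 0], then 2 [2, 5, 8, 1]

With shuffle=False each rank reads a fixed strided slice of the original order, so results stay deterministic across runs. Note that ids 0 and 1 are served twice because 10 samples are padded up to 12; that duplication is exactly what the new deduplidate_answers filter in rm_and_merge removes when the rank files are merged.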