mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -14,7 +14,7 @@ from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.legacy.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn import LinearWarmupLR
|
||||
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
|
||||
from colossalai.utils import is_using_pp
|
||||
from colossalai.utils.timer import MultiTimer
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@ VOCAB_SIZE = 50257
|
||||
|
||||
def main():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument('--from_torch', default=False, action='store_true')
|
||||
parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
|
||||
parser.add_argument("--from_torch", default=False, action="store_true")
|
||||
parser.add_argument("--use_dummy_dataset", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
disable_existing_loggers()
|
||||
if args.from_torch:
|
||||
@@ -40,28 +40,27 @@ def main():
|
||||
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
|
||||
logger = get_dist_logger()
|
||||
|
||||
data_path = None if args.use_dummy_dataset else os.environ['DATA']
|
||||
logger.info(f'Build data loader from path {data_path}', ranks=[0])
|
||||
data_path = None if args.use_dummy_dataset else os.environ["DATA"]
|
||||
logger.info(f"Build data loader from path {data_path}", ranks=[0])
|
||||
|
||||
train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
|
||||
train_dataloader = utils.get_dataloader(train_ds,
|
||||
seed=42,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
train_dataloader = utils.get_dataloader(
|
||||
train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True
|
||||
)
|
||||
|
||||
logger.info('Build model', ranks=[0])
|
||||
logger.info("Build model", ranks=[0])
|
||||
use_pipeline = is_using_pp()
|
||||
use_interleaved = hasattr(gpc.config.model, 'num_chunks')
|
||||
use_zero3 = hasattr(gpc.config, 'zero')
|
||||
use_interleaved = hasattr(gpc.config.model, "num_chunks")
|
||||
use_zero3 = hasattr(gpc.config, "zero")
|
||||
ctx = contextlib.nullcontext()
|
||||
if use_zero3:
|
||||
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True)
|
||||
ctx = ZeroInitContext(
|
||||
target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True,
|
||||
)
|
||||
with ctx:
|
||||
model = gpc.config.model.pop('type')(**gpc.config.model)
|
||||
model = gpc.config.model.pop("type")(**gpc.config.model)
|
||||
if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
|
||||
model = nn.ModuleList([model])
|
||||
|
||||
@@ -70,25 +69,31 @@ def main():
|
||||
else:
|
||||
numel = calc_local_model_size(model)
|
||||
|
||||
tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
|
||||
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
|
||||
tflop = (
|
||||
numel
|
||||
* gpc.config.BATCH_SIZE
|
||||
* gpc.config.SEQ_LEN
|
||||
* gpc.get_world_size(ParallelMode.MODEL)
|
||||
* gpc.get_world_size(ParallelMode.DATA)
|
||||
* 8
|
||||
/ (1024**4)
|
||||
)
|
||||
|
||||
criterion = getattr(gpc.config, 'loss_fn', None)
|
||||
criterion = getattr(gpc.config, "loss_fn", None)
|
||||
if criterion is not None:
|
||||
criterion = criterion.type()
|
||||
else:
|
||||
criterion = GPTLMLoss()
|
||||
logger.info('Build optimizer', ranks=[0])
|
||||
optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
|
||||
logger.info("Build optimizer", ranks=[0])
|
||||
optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer)
|
||||
lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
|
||||
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
|
||||
optimizer,
|
||||
criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
global_batch_size = gpc.config.BATCH_SIZE * \
|
||||
gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
|
||||
logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
|
||||
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(
|
||||
model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler
|
||||
)
|
||||
global_batch_size = (
|
||||
gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
|
||||
)
|
||||
logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0])
|
||||
timier = MultiTimer()
|
||||
trainer = Trainer(engine=engine, logger=logger, timer=timier)
|
||||
hook_list = [
|
||||
@@ -98,16 +103,18 @@ def main():
|
||||
hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
|
||||
hooks.LogMetricByStepHook(),
|
||||
hooks.LogMemoryByEpochHook(logger),
|
||||
# hooks.LogMemoryByEpochHook(logger),
|
||||
# hooks.LogTimingByEpochHook(timer, logger),
|
||||
# hooks.LogMemoryByEpochHook(logger),
|
||||
# hooks.LogTimingByEpochHook(timer, logger),
|
||||
]
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
return_output_label=False)
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user