Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[example] llama2 add fine-tune example (#4673)
* [shardformer] update shardformer readme [shardformer] update shardformer readme [shardformer] update shardformer readme
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] change dataset
* [shardformer] change dataset
* [shardformer] fix CI
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix [example] update opt example [example] resolve comments fix fix
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* fix
* update llama2 example
* update llama2 example
* fix
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* Update requirements.txt
* update llama2 example
* update llama2 example
* update llama2 example
@@ -129,14 +129,13 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
     total_step = len(train_dataloader)

     model.train()
     optimizer.zero_grad()
     train_dataloader_iter = iter(train_dataloader)
-    with tqdm(range(total_step),
-              desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
-              disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
+    with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar:
         # Forward pass
         for _ in pbar:
             if use_pipeline:
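The new `print_flag` confines the progress bar to the rank that actually observes the loss: the global master when pipeline parallelism is off, or the last pipeline stage when it is on. The following is a minimal, self-contained sketch of that gating logic only; the `SimpleNamespace` stand-ins for `booster.plugin` and `coordinator` are assumptions for illustration, and the `isinstance(..., HybridParallelPlugin)` check from the diff is simplified to a `pp_size` test.

```python
from types import SimpleNamespace

def should_print(plugin, coordinator) -> bool:
    """Mirror of the print_flag expression above (isinstance check simplified)."""
    use_pipeline = plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and plugin.stage_manager.is_last_stage()
    return (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)

# Plain data parallelism (pp_size == 1): only the global master rank drives tqdm.
master = SimpleNamespace(is_master=lambda: True)
worker = SimpleNamespace(is_master=lambda: False)
no_pp = SimpleNamespace(pp_size=1, stage_manager=None)
assert should_print(no_pp, master) and not should_print(no_pp, worker)

# Pipeline parallelism (pp_size > 1): only the last stage, which holds the loss, drives tqdm.
pp_last = SimpleNamespace(pp_size=2, stage_manager=SimpleNamespace(is_last_stage=lambda: True))
pp_mid = SimpleNamespace(pp_size=2, stage_manager=SimpleNamespace(is_last_stage=lambda: False))
assert should_print(pp_last, master) and not should_print(pp_mid, master)
```

Folding the condition into a single `print_flag` also lets the same flag gate both the `tqdm` bar and any metric printing later in the loop.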
@@ -192,13 +191,13 @@ def main():
         model_name = "albert-xxlarge-v2"
     else:
         raise RuntimeError

     # ==============================
     # Launch Distributed Environment
     # ==============================
     colossalai.launch_from_torch(config={}, seed=42)
     coordinator = DistCoordinator()

     # local_batch_size = BATCH_SIZE // coordinator.world_size
     lr = LEARNING_RATE * coordinator.world_size

     # ==============================
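The `main()` hunk shows the example's setup pattern: initialize the distributed environment, then scale the base learning rate linearly with the number of ranks, since the effective global batch size grows with data parallelism. Below is a minimal sketch of that pattern, not taken from the commit; it assumes the entry points visible above (`colossalai.launch_from_torch`, `DistCoordinator`), hypothetical `LEARNING_RATE` and `BATCH_SIZE` constants, and that the script is started with `torchrun` so the rank and world size can be read from the environment.

```python
import colossalai
from colossalai.cluster import DistCoordinator

LEARNING_RATE = 2.4e-5   # hypothetical base LR for a single rank
BATCH_SIZE = 32          # hypothetical per-rank batch size

def setup_distributed():
    # Initialize the distributed environment from torchrun's environment variables.
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # Linear LR scaling: with pure data parallelism the effective global batch
    # size is BATCH_SIZE * world_size, so the base LR is scaled by world_size.
    lr = LEARNING_RATE * coordinator.world_size
    global_batch_size = BATCH_SIZE * coordinator.world_size

    if coordinator.is_master():
        print(f"world_size={coordinator.world_size}, lr={lr:.2e}, "
              f"global_batch_size={global_batch_size}")
    return coordinator, lr
```

Launched as, for example, `torchrun --standalone --nproc_per_node 4 finetune.py`, each rank picks up its rank and world size from the environment, and only the master rank prints the resolved configuration.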