[example] llama2 add fine-tune example (#4673)
* [shardformer] update shardformer readme
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] change dataset
* [shardformer] fix CI
* [shardformer] fix
* [example] update opt example
* [example] resolve comments
* [example] llama2 add finetune example
* fix
* update llama2 example
* Update requirements.txt
@@ -21,7 +21,7 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
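For orientation, the newly imported HybridParallelPlugin is what the rest of this diff wires into the Booster. A minimal sketch of that wiring, assuming ColossalAI's Booster API; the parallel sizes here are illustrative, not the example's defaults:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch({})
# tp_size * pp_size must divide the world size; sizes here are assumed for illustration.
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, microbatch_size=1)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler)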
@@ -65,9 +65,10 @@ def format_numel_str(numel: int) -> str:
     return f'{numel}'
 
 
-def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
+def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
     texts = [sample['text'] for sample in batch]
     data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
+    data = {k: v.cuda() for k, v in data.items()}
     data['labels'] = data['input_ids'].clone()
     return data
 
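Note what the renamed collator does: every sample is padded to max_length and the labels are a clone of the input ids, the standard causal-LM objective (the model shifts labels internally). The added .cuda() line moves the batch to the GPU inside the collator, which is why the training loop further down no longer does it per step. A rough standalone equivalent; the tokenizer checkpoint and pad-token workaround here are assumptions, not the example's code:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')  # assumed checkpoint
tokenizer.pad_token = tokenizer.unk_token  # LLaMA tokenizers ship without a pad token; one common workaround

batch = [{'text': 'hello world'}, {'text': 'colossal ai'}]
data = tokenizer([s['text'] for s in batch], return_tensors='pt',
                 padding='max_length', truncation=True, max_length=16)
data['labels'] = data['input_ids'].clone()  # causal LM: the model shifts these by one internally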
@@ -104,6 +105,10 @@ def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler:
     return running_states['epoch'], running_states['step'], running_states['sample_start_index']
 
 
+def _criterion(outputs, inputs):
+    return outputs.loss
+
+
 def main():
     # ==============================
     # Parse Arguments
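The new _criterion exists because booster.execute_pipeline needs an explicit (outputs, inputs) -> loss callable: under pipeline parallelism only the last stage holds the model outputs, and the schedule invokes the criterion there for each micro-batch. Since the HuggingFace model already computes the loss from the labels in the batch, the criterion just unwraps it. For comparison, a sketch of a criterion that computes the causal-LM loss explicitly (a hypothetical alternative, not the example's code):

import torch.nn.functional as F

def _criterion_manual(outputs, inputs):
    # Shift so position t predicts token t+1, then flatten for cross-entropy.
    logits = outputs.logits[:, :-1, :].contiguous()
    labels = inputs['labels'][:, 1:].contiguous()
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))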
@@ -112,7 +117,7 @@ def main():
     parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
     parser.add_argument('-p',
                         '--plugin',
-                        choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
+                        choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
                         default='gemini',
                         help='Choose which plugin to use')
     parser.add_argument('-d',
@@ -142,13 +147,6 @@ def main():
     colossalai.launch_from_torch({})
     coordinator = DistCoordinator()
 
-    # ==============================
-    # Initialize Tensorboard
-    # ==============================
-    if coordinator.is_master():
-        os.makedirs(args.tensorboard_dir, exist_ok=True)
-        writer = SummaryWriter(args.tensorboard_dir)
-
     # ==============================
     # Initialize Booster
     # ==============================
@@ -170,11 +168,32 @@ def main():
                                      initial_scale=2**16,
                                      cpu_offload=True,
                                      max_norm=args.grad_clip)
+    elif args.plugin == 'hybrid_parallel':
+        # modify the param accordingly, default configuration is for llama2-7b
+        plugin = HybridParallelPlugin(tp_size=4,
+                                      pp_size=2,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_jit_fused=False,
+                                      zero_stage=0,
+                                      precision='fp32',
+                                      initial_scale=1)
     else:
         raise ValueError(f'Unknown plugin {args.plugin}')
 
     booster = Booster(plugin=plugin)
 
+    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+
+    # ==============================
+    # Initialize Tensorboard
+    # ==============================
+    if print_flag:
+        os.makedirs(args.tensorboard_dir, exist_ok=True)
+        writer = SummaryWriter(args.tensorboard_dir)
+
     # ==============================
     # Initialize Tokenizer, Dataset and Dataloader
     # ==============================
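Two things are worth spelling out here. First, the default tp_size=4, pp_size=2 layout consumes 8 GPUs per data-parallel replica; a sketch of the device-mesh arithmetic, with the world size assumed:

tp_size, pp_size, world_size = 4, 2, 8          # defaults above; world size assumed
assert world_size % (tp_size * pp_size) == 0, 'world size must be divisible by tp * pp'
dp_size = world_size // (tp_size * pp_size)     # == 1: a single TP x PP replica, no data parallelism

Second, print_flag replaces coordinator.is_master() for logging: with a pipeline schedule the loss only materializes on the last pipeline stage, so rank 0 (the first stage) would have nothing to report. That is also why the TensorBoard initialization moved from its earlier position to below this block.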
@@ -188,12 +207,15 @@ def main():
                               batch_size=args.batch_size,
                               shuffle=True,
                               drop_last=True,
-                              collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
+                              collate_fn=partial(tokenize_batch_for_pretrain,
+                                                 tokenizer=tokenizer,
+                                                 max_length=args.max_length))
 
     # ==============================
     # Initialize Model, Optimizer and LR Scheduler
     # ==============================
     config = MODEL_CONFIGS[args.config]
     # use lazy init when using GeminiPlugin
     init_ctx = LazyInitContext(
         default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
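The LazyInitContext line is unchanged context but worth noting: under GeminiPlugin the model is built with lazy (meta) parameters so the full weights are never materialized on a single device, and booster.boost later materializes them directly into their sharded placement. A sketch of the pattern, assuming the example instantiates a LlamaForCausalLM from the chosen config and that get_current_device lives in colossalai.utils:

from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device  # assumed import path

init_ctx = (LazyInitContext(default_device=get_current_device())
            if isinstance(plugin, GeminiPlugin) else nullcontext())
with init_ctx:
    model = LlamaForCausalLM(config)  # parameters stay lazy until booster.boost(model, ...)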
@@ -236,27 +258,42 @@ def main():
     coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
 
     num_steps_per_epoch = len(dataloader)
 
     # if resume training, set the sampler start index to the correct value
     dataloader.sampler.set_start_index(sampler_start_idx)
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch)
-        with tqdm(enumerate(dataloader),
+        step_nums = num_steps_per_epoch - start_step
+        dataloader_iter = iter(dataloader)
+
+        with tqdm(range(step_nums),
                   desc=f'Epoch {epoch}',
-                  disable=not coordinator.is_master(),
+                  disable=not print_flag,
                   total=num_steps_per_epoch,
                   initial=start_step) as pbar:
-            for step, batch in pbar:
-                batch = {k: v.cuda() for k, v in batch.items()}
-                outputs = model(**batch)
-                loss = outputs[0]
-                booster.backward(loss, optimizer)
+            for step in pbar:
+                if use_pipeline:
+                    outputs = booster.execute_pipeline(dataloader_iter,
+                                                       model,
+                                                       _criterion,
+                                                       optimizer,
+                                                       return_loss=True,
+                                                       return_outputs=True)
+                    loss = outputs["loss"]
+                else:
+                    batch = next(dataloader_iter)
+                    outputs = model(**batch)
+                    loss = outputs[0]
+                    booster.backward(loss, optimizer)
 
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
-                all_reduce_mean(loss)
-                pbar.set_postfix({'loss': loss.item()})
-                if coordinator.is_master():
+                if not use_pipeline:
+                    all_reduce_mean(loss)
+                if print_flag:
+                    pbar.set_postfix({'loss': loss.item()})
                     writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
 
                 if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
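In the reworked loop the pipeline path hands the dataloader iterator to booster.execute_pipeline, which runs forward and backward over micro-batches internally (hence no separate booster.backward call), while the non-pipeline path keeps the usual forward/backward pair. The loss is averaged across data-parallel ranks only in the non-pipeline case, since the pipeline schedule already returns the loss on the last stage. A sketch of what an all_reduce_mean helper of this kind presumably does; the example imports its own, so the exact implementation here is an assumption:

import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Sum the scalar across all ranks in place, then divide by the world size.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor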