[example] llama2 add fine-tune example (#4673)

* [shardformer] update shardformer readme

[shardformer] update shardformer readme

[shardformer] update shardformer readme

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] change dataset

* [shardformer] change dataset

* [shardformer] fix CI

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

[example] update opt example

[example] resolve comments

fix

fix

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* fix

* update llama2 example

* update llama2 example

* fix

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* Update requirements.txt

* update llama2 example

* update llama2 example

* update llama2 example
This commit is contained in:
flybird11111
2023-09-15 18:45:44 +08:00
committed by GitHub
parent ac2797996b
commit 4c4482f3ad
8 changed files with 402 additions and 35 deletions

View File

@@ -21,7 +21,7 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -65,9 +65,10 @@ def format_numel_str(numel: int) -> str:
return f'{numel}'
def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
texts = [sample['text'] for sample in batch]
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
data = {k: v.cuda() for k, v in data.items()}
data['labels'] = data['input_ids'].clone()
return data
@@ -104,6 +105,10 @@ def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler:
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
def _criterion(outputs, inputs):
return outputs.loss
def main():
# ==============================
# Parse Arguments
@@ -112,7 +117,7 @@ def main():
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
parser.add_argument('-p',
'--plugin',
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
default='gemini',
help='Choose which plugin to use')
parser.add_argument('-d',
@@ -142,13 +147,6 @@ def main():
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# ==============================
# Initialize Tensorboard
# ==============================
if coordinator.is_master():
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ==============================
# Initialize Booster
# ==============================
@@ -170,11 +168,32 @@ def main():
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip)
elif args.plugin == 'hybrid_parallel':
# modify the param accordingly, default configuration is for llama2-7b
plugin = HybridParallelPlugin(tp_size=4,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_jit_fused=False,
zero_stage=0,
precision='fp32',
initial_scale=1)
else:
raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin)
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
# ==============================
# Initialize Tensorboard
# ==============================
if print_flag:
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ==============================
# Initialize Tokenizer, Dataset and Dataloader
# ==============================
@@ -188,12 +207,15 @@ def main():
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
collate_fn=partial(tokenize_batch_for_pretrain,
tokenizer=tokenizer,
max_length=args.max_length))
# ==============================
# Initialize Model, Optimizer and LR Scheduler
# ==============================
config = MODEL_CONFIGS[args.config]
# use lazy init when using GeminiPlugin
init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
@@ -236,27 +258,42 @@ def main():
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
num_steps_per_epoch = len(dataloader)
# if resume training, set the sampler start index to the correct value
dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch)
with tqdm(enumerate(dataloader),
step_nums = num_steps_per_epoch - start_step
dataloader_iter = iter(dataloader)
with tqdm(range(step_nums),
desc=f'Epoch {epoch}',
disable=not coordinator.is_master(),
disable=not print_flag,
total=num_steps_per_epoch,
initial=start_step) as pbar:
for step, batch in pbar:
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
for step in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(dataloader_iter,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
all_reduce_mean(loss)
pbar.set_postfix({'loss': loss.item()})
if coordinator.is_master():
if not use_pipeline:
all_reduce_mean(loss)
if print_flag:
pbar.set_postfix({'loss': loss.item()})
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
if args.save_interval > 0 and (step + 1) % args.save_interval == 0: