mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
This commit is contained in:
@@ -1,25 +1,20 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from transformers.utils.versions import require_version
|
||||
from args import parse_demo_args
|
||||
from data import NetflixDataset, netflix_collator
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
from args import parse_demo_args
|
||||
from data import NetflixDataset, netflix_collator
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
|
||||
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
|
||||
@@ -30,18 +25,18 @@ def move_to_cuda(batch, device):
|
||||
|
||||
|
||||
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
|
||||
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
|
||||
|
||||
|
||||
for batch in pbar:
|
||||
|
||||
# Forward
|
||||
optimizer.zero_grad()
|
||||
batch = move_to_cuda(batch, torch.cuda.current_device())
|
||||
|
||||
|
||||
outputs = model(use_cache=False, **batch)
|
||||
loss = outputs['loss']
|
||||
|
||||
@@ -72,7 +67,7 @@ def main():
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# Build OPT model
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
|
||||
@@ -88,43 +83,35 @@ def main():
|
||||
if args.plugin.startswith('torch_ddp'):
|
||||
plugin = TorchDDPPlugin()
|
||||
elif args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=True,
|
||||
initial_scale=2**5)
|
||||
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
|
||||
elif args.plugin == 'low_level_zero':
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
dataset = NetflixDataset(tokenizer)
|
||||
dataloader = plugin.prepare_dataloader(dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=netflix_collator)
|
||||
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(),
|
||||
lr=(args.learning_rate * world_size),
|
||||
weight_decay=args.weight_decay)
|
||||
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
|
||||
|
||||
# Set lr scheduler
|
||||
total_steps = len(dataloader) * args.num_epoch
|
||||
num_warmup_steps = int(args.warmup_ratio * total_steps)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=len(dataloader) * args.num_epoch
|
||||
)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=len(dataloader) * args.num_epoch)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=dataloader,
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
# Start finetuning
|
||||
|
Reference in New Issue
Block a user