Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
  * [gemini] remove process group dependency
  * [gemini] remove tp part from colo tensor
  * [gemini] patch inplace op
  * [gemini] fix param op hook and update tests
  * [test] remove useless tests
  * [test] remove useless tests
  * [misc] fix requirements
  * [test] fix model zoo
  * [test] fix model zoo
  * [test] fix model zoo
  * [test] fix model zoo
  * [test] fix model zoo
  * [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
  * [gemini] update optimizer interface
  * [gemini] renaming gemini optimizer
  * [gemini] refactor gemini ddp class
  * [example] update gemini related example
  * [example] update gemini related example
  * [plugin] fix gemini plugin args
  * [test] update gemini ckpt tests
  * [gemini] fix checkpoint io
  * [example] fix opt example requirements
  * [example] fix opt example
  * [example] fix opt example
  * [example] fix opt example
* [gemini] add static placement policy (#4443)
  * [gemini] add static placement policy
  * [gemini] fix param offload
  * [test] update gemini tests
  * [plugin] update gemini plugin
  * [plugin] update gemini plugin docstr
  * [misc] fix flash attn requirement
  * [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
  * [example] update gpt example
  * [example] update dreambooth example
  * [example] update vit
  * [example] update opt
  * [example] update palm
  * [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
  * [hotfix] fix bert in model zoo
  * [test] remove chatglm gemini test
  * [test] remove sam gemini test
  * [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
  * [hotfix] fix opt tutorial example
  * [hotfix] fix opt tutorial example
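The headline feature is the static placement policy for Gemini. Below is a minimal sketch of how a model might be wrapped under the new API; placement_policy="static" is an assumed keyword name for this release, while offload_optim_frac=1.0 and pin_memory=True are taken verbatim from the diff further down:

    # Minimal sketch, assuming the 0.3.1-era GeminiDDP keyword API; run under
    # torchrun so that launch_from_torch can pick up the process group.
    import colossalai
    import torch.nn as nn
    from colossalai.zero import GeminiDDP

    colossalai.launch_from_torch(config={})
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
    # "static" fixes up front which fraction of parameters and optimizer
    # states live on CPU, instead of migrating chunks at runtime as "auto" does.
    model = GeminiDDP(model, placement_policy="static", offload_optim_frac=1.0, pin_memory=True)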
@@ -3,5 +3,5 @@ torch >= 1.8.1
 datasets >= 1.8.0
 sentencepiece != 0.1.92
 protobuf
-accelerate == 0.13.2
+accelerate
 transformers
@@ -30,7 +30,7 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
 import transformers
+import transformers.utils.logging as logging
 from accelerate.utils import set_seed
-from context import barrier_context
 from datasets import load_dataset
@@ -57,7 +57,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device, get_dataloader
-from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
+from colossalai.zero import GeminiOptimizer

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

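These import changes are the migration in miniature: ProcessGroup and the ColoInitContext/ZeroDDP/ZeroOptimizer trio go away, and GeminiOptimizer (with GeminiDDP, added further down) replaces them. A hedged before/after sketch of the wrapping sequence; the old call shapes are recalled from pre-0.3.1 ColossalAI rather than shown on this page:

    # Hedged migration sketch; assumes colossalai.launch_from_torch(config={})
    # has already set up the process group.
    import torch.nn as nn
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import GeminiDDP, GeminiOptimizer

    # Old (removed above):
    #   with ColoInitContext(device='cuda'):
    #       model = build_model()
    #   model = ZeroDDP(model, gemini_manager)
    #   optimizer = ZeroOptimizer(HybridAdam(...), model, initial_scale=2**14)
    # New (added above):
    model = GeminiDDP(nn.Linear(512, 512), offload_optim_frac=1.0, pin_memory=True)
    optimizer = HybridAdam(model.parameters(), lr=5e-5)
    optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)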
@@ -292,10 +292,10 @@ def main():
|
||||
|
||||
if is_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
logging.set_verbosity_error()
|
||||
|
||||
if args.mem_cap > 0:
|
||||
colo_memory_cap(args.mem_cap)
|
||||
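Note that logging in these calls is not Python's standard-library logging module: it is the transformers.utils.logging alias added in the import hunk above, which is what allows the fully qualified transformers.utils.logging.set_verbosity_* calls to shrink to logging.set_verbosity_*.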
@@ -391,16 +391,28 @@ def main():
     else:
         init_dev = get_current_device()

+    cai_version = colossalai.__version__
+    logger.info(f'using Colossal-AI version {cai_version}')
     # build model
+    if version.parse(cai_version) >= version.parse("0.3.1"):
+        from contextlib import nullcontext
+
+        from colossalai.lazy import LazyInitContext
+        ctx = LazyInitContext(
+            default_device=init_dev
+        ) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
+    else:
+        from colossalai.zero import ColoInitContext
+        ctx = ColoInitContext(device=init_dev)
     if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
         # currently, there has a bug in pretrained opt-13b
         # we can not import it until huggingface fix it
         logger.info("Train a new model from scratch", ranks=[0])
-        with ColoInitContext(device=init_dev):
+        with ctx:
             model = OPTForCausalLM(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
-        with ColoInitContext(device=init_dev):
+        with ctx:
             model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                    from_tf=bool(".ckpt" in args.model_name_or_path),
                                                    config=config,
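The interesting branch is the lazy one: on ColossalAI >= 0.3.1 the model is constructed under LazyInitContext, so parameters are materialized only when the Gemini wrapper later takes ownership, instead of being allocated eagerly on a single device. A minimal sketch of that pattern, using only what the hunk shows plus standard transformers calls:

    # Sketch of the >= 0.3.1 initialization path from the hunk above.
    from contextlib import nullcontext

    from colossalai.lazy import LazyInitContext
    from colossalai.utils import get_current_device
    from transformers import AutoConfig, OPTForCausalLM

    config = AutoConfig.from_pretrained("facebook/opt-125m")
    train_from_scratch = True   # the script derives this from model_name_or_path
    ctx = LazyInitContext(default_device=get_current_device()) if train_from_scratch else nullcontext()
    with ctx:
        # Parameters stay unmaterialized here until GeminiDDP shards them,
        # so even a 13B-parameter model is cheap to construct.
        model = OPTForCausalLM(config)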
@@ -410,9 +422,10 @@ def main():
     model.gradient_checkpointing_enable()

     PLACEMENT_POLICY = 'auto'
-    cai_version = colossalai.__version__
-    logger.info(f'using Colossal-AI version {cai_version}')
-    if version.parse(cai_version) > version.parse("0.1.10"):
+    if version.parse(cai_version) >= version.parse("0.3.1"):
+        from colossalai.zero import GeminiDDP
+        model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
+    elif version.parse(cai_version) > version.parse("0.1.10"):
         try:
             from colossalai.nn.parallel import GeminiDDP
         except ImportError:
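One reading of the new call: offload_optim_frac=1.0 keeps all optimizer states on the host, which under the static policy roughly reproduces the old CPU-offload behavior, while fractional values should split the states between device and host; that tunability is presumably the point of the static policy. The cai_version assignment and log line disappear here only because the previous hunk moved them up to the model-construction section.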
@@ -536,7 +549,6 @@ def main():
     ]

     optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
-    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -551,6 +563,7 @@ def main():
         num_warmup_steps=args.num_warmup_steps,
         num_training_steps=args.max_train_steps,
     )
+    optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
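The net effect of the last two hunks: ZeroOptimizer is dropped and GeminiOptimizer now wraps HybridAdam after the scheduler is created. A hedged sketch of the resulting order; only the GeminiOptimizer line is verbatim from the diff, and the scheduler factory is assumed to be the usual transformers helper:

    # Setup order after this commit: HybridAdam -> scheduler -> Gemini wrap.
    import colossalai
    import torch.nn as nn
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import GeminiDDP, GeminiOptimizer
    from transformers import get_scheduler

    colossalai.launch_from_torch(config={})   # run under torchrun
    model = GeminiDDP(nn.Linear(512, 512), offload_optim_frac=1.0, pin_memory=True)
    optimizer = HybridAdam(model.parameters(), lr=5e-5)
    lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=1000)
    # Wrapping last means the scheduler drives the inner HybridAdam directly,
    # while GeminiOptimizer layers loss scaling (initial_scale) and
    # chunk-aware state management on top.
    optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)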
@@ -4,9 +4,9 @@ set -xue

 pip install -r requirements.txt

-BS=8
+BS=4
 MEMCAP=0
-GPUNUM=2
+GPUNUM=4
 MODLE="facebook/opt-125m"

 torchrun \
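Two details of the launcher change worth spelling out: assuming BS is the per-GPU batch size, the global batch is unchanged (old: 8 × 2 GPUs = 16; new: 4 × 4 GPUs = 16), so the tweak redistributes the same workload across more devices rather than shrinking it. And MODLE is spelled that way in the script itself, so the misspelled variable is what the truncated torchrun command expands.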