mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
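The headline changes here are the rename of ZeroOptimizer to GeminiOptimizer and the new static placement policy for Gemini (#4443). A minimal usage sketch follows, assuming the policy is selected through the GeminiPlugin's placement_policy argument; the keyword names and defaults below are assumptions drawn from this changelog, not taken from the diff that follows.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Sketch only: placement_policy="static" and offload_optim_frac are assumed
# keyword names for the static placement policy added in #4443.
colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(
    placement_policy="static",  # keep parameters at a fixed device placement
    offload_optim_frac=0.0,     # assumed knob: fraction of optimizer states offloaded to CPU
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)  # the returned optimizer is the Gemini-wrapped one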
@@ -22,7 +22,7 @@ from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
+from colossalai.zero import GeminiOptimizer


 def main():
@@ -46,7 +46,7 @@ def main():
         args.local_rank = -1
         args.log_interval = 1
     else:
-        colossalai.launch_from_torch(config={}) #args.colossal_config
+        colossalai.launch_from_torch(config={}) # args.colossal_config
         args.local_rank = int(os.environ["LOCAL_RANK"])
         logger.info(
             f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
@@ -123,7 +123,8 @@ def main():
     get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

     # 144003367 is is the length of the entire dataset
-    steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
+    # len(dataloader)
+    steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
     total_steps = steps_per_epoch * args.epoch

     lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
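As a quick sanity check of the hard-coded steps_per_epoch formula in the hunk above: the dataset length 144003367 comes from the diff, while the world size and batch settings below are purely illustrative assumptions.

# Illustrative values; only dataset_len (144003367) is taken from the diff above.
dataset_len = 144003367
world_size = 8
train_micro_batch_size_per_gpu = 32
gradient_accumulation_steps = 1
refresh_bucket_size = 1

steps_per_epoch = (dataset_len // world_size // train_micro_batch_size_per_gpu
                   // gradient_accumulation_steps // refresh_bucket_size)
print(steps_per_epoch)  # 562513 optimizer steps per epoch with these settings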