ColossalAI/colossalai/booster/plugin/gemini_plugin.py
flybird11111 a0ad587c24 [shardformer] refactor embedding resize (#5603)

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-18 16:10:18 +08:00


import gc
import logging
import os
import random
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import _get_default_group
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
    get_model_base_filenames,
    get_optimizer_base_filenames,
    load_shard_state_dict,
    save_config_file,
    save_state_dict,
    save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

from .dp_plugin_base import DPPluginBase

__all__ = ["GeminiPlugin"]

SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}

ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2


def get_param_info(optim: Optimizer):
    # Get a backup of necessary information of parameters for future use, which includes:
    # 1. A mapping from integer param_id to param32 shape.
    if optim is None:
        return {}
    param_info = {"id2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        for param_id, param in enumerate(group["params"], start_index):
            original_shape = param.shape if isinstance(param, torch.Tensor) else None
            param_info["id2shape"][param_id] = original_shape
        start_index += len(group["params"])
    return param_info


class GeminiCheckpointIO(GeneralCheckpointIO):
    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()

    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save the unsharded model to a checkpoint; only the master process writes the file.
        The model is expected to have been unwrapped via ModelWrapper.unwrap before reaching this method.
        As there is communication when getting the state dict, model.state_dict() must be called on all processes.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
        state_dict = model.state_dict(only_rank_0=True)
        if self.coordinator.is_master():
            save_state_dict(state_dict, checkpoint, use_safetensors)

    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
        """
        Load model from checkpoint with automatic unwrapping.
        The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
        super().load_unsharded_model(model, checkpoint, strict=strict)

    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
        """
        Save the unsharded optimizer state dict to a checkpoint.
        After calling optimizer.state_dict(), the complete optimizer states are collected on the master rank.
        As there is communication when getting the state dict, optimizer.state_dict() must be called on all processes.
        Only the master rank writes the file.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
        state_dict = optimizer.state_dict()
        if self.coordinator.is_master():
            save_state_dict(state_dict, checkpoint, use_safetensors=False)

    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
        """
        Load an unsharded optimizer from a checkpoint file.
        Each process only loads the optimizer states of the parameters it controls.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
        super().load_unsharded_optimizer(optimizer, checkpoint)

    def save_sharded_model(
        self,
        model: GeminiDDP,
        checkpoint_path: str,
        gather_dtensor: bool = False,
        prefix: Optional[str] = None,
        max_shard_size: int = 1024,
        use_safetensors: bool = False,
    ):
        """
        Save sharded model.
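As a rough, framework-free illustration of splitting a state dict into shards no larger than `max_shard_size`, the sketch below greedily packs tensors (represented only by a name and a size) into shards. It then builds an index that maps each name to its shard file. The helper `shard_by_size`, the file-name pattern, and the unit-less sizes are assumptions made for illustration. This is a hypothetical sketch, not the implementation of `model.state_dict_shard` or `save_state_dict_shards`.

```python
from typing import Dict, List, Tuple


def shard_by_size(
    sizes: List[Tuple[str, int]], max_shard_size: int
) -> Tuple[List[List[str]], Dict[str, str]]:
    # Greedily pack (name, size) pairs into shards whose total size stays
    # within max_shard_size; an item larger than the limit gets its own shard.
    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for name, size in sizes:
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)

    # Build an index (name -> shard file), similar in spirit to a weight map.
    index: Dict[str, str] = {}
    for i, names in enumerate(shards, start=1):
        filename = f"shard-{i:05d}-of-{len(shards):05d}.bin"  # hypothetical naming
        for name in names:
            index[name] = filename
    return shards, index


shards, index = shard_by_size([("w1", 600), ("w2", 300), ("w3", 500), ("w4", 2000)], max_shard_size=1024)
print(shards)  # [['w1', 'w2'], ['w3'], ['w4']]
```

Packing in parameter order keeps the index simple. A single tensor that exceeds the limit still has to be written somewhere, so it gets a shard of its own.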
        As there is communication when getting the state dict, model.state_dict() must be called on all processes.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
        if os.path.isfile(checkpoint_path):
            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
            return

        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint_path)

        # Save shards of model weights.
        is_master = self.coordinator.is_master()
        total_size = save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint_path,
            index_file=index_file,
            base_filename=weights_name,
            is_master=is_master,
            use_safetensors=use_safetensors,
        )

        # Only save the index file on the master rank.
        if self.coordinator.is_master():
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            save_config_file(model.unwrap(), checkpoint_path)
            logging.info(
                f"The model is split into checkpoint shards. "
                f"You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_model(
        self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
    ):
        """
        Load a sharded model from multiple checkpoint files.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

    def save_sharded_optimizer(
        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
    ):
        """
        Save sharded optimizer state dict to checkpoint folder.
        As there is communication when getting the state dict, this must be called on all processes.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"

        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        # Prepare file paths and the index file.
        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
        index_file = CheckpointIndexFile(checkpoint)
        index_file.append_meta_data("param_groups", param_group_file)

        # Store the information of param groups to param_group_file.
        if self.coordinator.is_master():
            group_file_path = os.path.join(checkpoint, param_group_file)
            param_groups = optimizer.get_param_groups_for_saving()
            torch.save(param_groups, group_file_path)

        # States are broken into shards of at most size_per_shard.
        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)

        # Save shards of optimizer states.
        total_size = save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint,
            index_file=index_file,
            base_filename=states_name,
            is_master=self.coordinator.is_master(),
            use_safetensors=False,
        )

        # Wrap up the index file. Only save it on the master rank.
        if self.coordinator.is_master():
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            logging.info(
                f"The optimizer has been split into checkpoint shards. "
                f"You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
        """
        Load a sharded optimizer from a checkpoint folder, given its index file.
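The sketch below shows one way to load shard by shard while keeping only the states that a given rank is responsible for. The round-robin rule in `round_robin_owned_ids` is a hypothetical stand-in; ColossalAI's actual assignment of parameters to ranks is not reproduced here. Both helper names are illustrative and are not ColossalAI APIs.

```python
from typing import Dict, Iterable, List, Set


def round_robin_owned_ids(num_params: int, rank: int, world_size: int) -> Set[int]:
    # Hypothetical ownership rule, for illustration only.
    return {pid for pid in range(num_params) if pid % world_size == rank}


def load_owned_states(shards: Iterable[Dict[int, dict]], owned_ids: Set[int]) -> Dict[int, dict]:
    # Visit every shard in turn but keep only the states of parameters
    # owned by this rank; everything else is skipped.
    loaded: Dict[int, dict] = {}
    for shard in shards:
        for param_id, state in shard.items():
            if param_id in owned_ids:
                loaded[param_id] = state
    return loaded


shard_files: List[Dict[int, dict]] = [
    {0: {"step": 10}, 1: {"step": 10}},
    {2: {"step": 10}, 3: {"step": 10}},
]
owned = round_robin_owned_ids(num_params=4, rank=1, world_size=2)
print(sorted(load_owned_states(shard_files, owned)))  # [1, 3]
```

In the method itself, each shard file is loaded, handed to the optimizer, and then released with `del` and `gc.collect()` before the next one is read. This keeps roughly one shard in host memory at a time, on top of the states that are retained.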
        Each process only loads the optimizer states of the parameters it controls.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
        if not os.path.isfile(checkpoint_index_file):
            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)

        # Load param_groups.
        param_group_path = ckpt_index_file.get_param_group_filename()
        if param_group_path is None:
            raise RuntimeError(
                f"Invalid index file path {checkpoint_index_file} for an optimizer. "
                "Lacking param group file under current directory."
            )
        saved_param_groups = torch.load(param_group_path)
        optimizer.load_param_groups(saved_param_groups)

        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

        # Load optimizer states from shard files under checkpoint path.
        # For each file, only load the states managed by the current process.
        for shard_file in checkpoint_files:
            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
            optimizer.load_param_states(state_dict_shard)
            del state_dict_shard
            gc.collect()

        optimizer.optimizer_loading_epilogue()

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save the learning rate scheduler to a checkpoint, only on the master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)


class GeminiPlugin(DPPluginBase):
    """
    Plugin for Gemini.

    ```python
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    model, train_dataset, optimizer, criterion = ...
    plugin = GeminiPlugin()

    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    booster = Booster(plugin=plugin)
    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
    ```
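The static-placement arguments documented under Args correspond to familiar configurations. `shard_param_frac=1.0` behaves like ZeRO-3, `shard_param_frac=0.0` like ZeRO-2, and setting all three fractions to 1.0 reproduces the old "cpu" placement. The sketch below keeps these presets in one place. The names `STATIC_PRESETS` and `static_placement_kwargs` are hypothetical, and the ZeRO-2 row assumes the two offload fractions keep their 0.0 defaults.

```python
from typing import Dict

# Hypothetical presets derived from the Args documentation of GeminiPlugin.
STATIC_PRESETS: Dict[str, Dict[str, float]] = {
    # shard_param_frac 0.0 ~ ZeRO-2 (offload fractions left at their 0.0 defaults)
    "zero2": {"shard_param_frac": 0.0, "offload_optim_frac": 0.0, "offload_param_frac": 0.0},
    # shard_param_frac 1.0 ~ ZeRO-3; with offload_optim_frac 0.0 this is also the old "cuda" placement
    "zero3": {"shard_param_frac": 1.0, "offload_optim_frac": 0.0, "offload_param_frac": 0.0},
    # all three fractions at 1.0 ~ the old "cpu" placement
    "cpu": {"shard_param_frac": 1.0, "offload_optim_frac": 1.0, "offload_param_frac": 1.0},
}


def static_placement_kwargs(preset: str) -> Dict[str, object]:
    # Build keyword arguments for GeminiPlugin with static placement.
    if preset not in STATIC_PRESETS:
        raise KeyError(f"unknown preset {preset!r}; choose from {sorted(STATIC_PRESETS)}")
    return {"placement_policy": "static", **STATIC_PRESETS[preset]}


kwargs = static_placement_kwargs("cpu")
print(kwargs)
# plugin = GeminiPlugin(**kwargs)  # needs ColossalAI and a launched distributed environment
```

Starting from one of these presets and then adjusting `shard_param_frac` first and `offload_optim_frac` second follows the tuning order recommended in the Args.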

    Args:
        chunk_config_dict (dict, optional): chunk configuration dictionary.
        chunk_init_device (torch.device, optional): device to initialize the chunk.
        placement_policy (str, optional): "static" or "auto". Defaults to "static".
        enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradients are stored after each backward pass. Defaults to False.
        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
            If `shard_param_frac` is 1.0, it is equivalent to ZeRO-3. If `shard_param_frac` is 0.0, it is equivalent to ZeRO-2. Defaults to 1.0.
        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it is equivalent to the old "cuda" placement. Defaults to 0.0.
        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
            If `shard_param_frac`, `offload_optim_frac` and `offload_param_frac` are all 1.0, it is equivalent to the old "cpu" placement.
            When using static placement, we recommend tuning `shard_param_frac` first and then `offload_optim_frac`.
            Defaults to 0.0.
        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
        steady_cuda_cap_ratio (float, optional): ratio of allowed CUDA capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
        precision (str, optional): precision. Supports 'fp16' and 'bf16'. Defaults to 'fp16'.
        master_weights (bool, optional): Whether to keep fp32 master parameter weights in the optimizer. Defaults to True.
        pin_memory (bool, optional): use pinned memory on CPU. Defaults to False.
        force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
        strict_ddp_mode (bool, optional): use strict DDP mode (data parallelism only, without other parallelism). Defaults to False.
        search_range_m (int, optional): chunk size search range, divided by 2^20. Defaults to 32.
        hidden_dim (int, optional): the hidden dimension of the DNN.
            Users can provide this argument to speed up searching.
            If it is not known before training, a default value of 1024 is used.
        min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
            If the aggregate size of parameters is still smaller than the minimum chunk size,
            all parameters will be compacted into one small chunk. Defaults to 32.
        memstats (MemStats, optional): the memory statistics collector by a runtime memory tracer.
        gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
            which will be used when using hybrid CPU optimizer.
            This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
            Defaults to 0.0.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. Do not call clip_grad_norm yourself
            when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping. Defaults to 0.0.
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Defaults to 2.0.
        tp_size (int, optional): If `tp_size` is greater than 1, the tensor parallelism strategy implemented in Shardformer is used,
            and `tp_size` determines the size of the tensor parallel process group. Defaults to 1.
        extra_dp_size (int, optional): If `extra_dp_size` is greater than 1, another process group is created to run a DDP-like strategy. Defaults to 1.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
            Currently this includes fused normalization, flash attention and JIT.
            Defaults to False.
        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
        enable_sequence_parallelism (bool, optional): Whether to turn on sequence parallelism in Shardformer.
            It only takes effect when `tp_size` is greater than 1. Defaults to False.
        enable_sequence_overlap (bool, optional): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        verbose (bool, optional): verbose mode. Debug info, including the chunk search result, will be printed. Defaults to False.
    """

    def __init__(
        self,
        chunk_config_dict: Optional[dict] = None,
        chunk_init_device: Optional[torch.device] = None,
        placement_policy: str = "static",
        enable_gradient_accumulation: bool = False,
        shard_param_frac: float = 1.0,  # only for static placement
        offload_optim_frac: float = 0.0,  # only for static placement
        offload_param_frac: float = 0.0,  # only for static placement
        warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
        steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
        precision: str = "fp16",
        master_weights: bool = True,
        pin_memory: bool = False,
        force_outputs_fp32: bool = False,
        strict_ddp_mode: bool = False,
        search_range_m: int = 32,
        hidden_dim: Optional[int] = None,
        min_chunk_size_m: float = 32,
        memstats: Optional[MemStats] = None,
        gpu_margin_mem_ratio: float = 0.0,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        tp_size: int = 1,
        extra_dp_size: int = 1,
        enable_all_optimization: bool = False,
        enable_fused_normalization: bool = False,
        enable_flash_attention: bool = False,
        enable_sequence_parallelism: bool = False,
        enable_jit_fused: bool = False,
        enable_sequence_overlap: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
        if get_accelerator().name == "npu":
            assert placement_policy == "static", "NPU only supports static placement policy"
        self.gemini_config = dict(
            chunk_config_dict=chunk_config_dict,
            chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
            placement_policy=placement_policy,
            enable_gradient_accumulation=enable_gradient_accumulation,
            shard_param_frac=shard_param_frac,
            offload_optim_frac=offload_optim_frac,
            offload_param_frac=offload_param_frac,
            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
            pin_memory=pin_memory,
            force_outputs_fp32=force_outputs_fp32,
            strict_ddp_mode=strict_ddp_mode,
            search_range_m=search_range_m,
            hidden_dim=hidden_dim,
            min_chunk_size_m=min_chunk_size_m,
            memstats=memstats,
            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
            master_weights=master_weights,
        )
        self.zero_optim_config = dict(
            gpu_margin_mem_ratio=gpu_margin_mem_ratio,
        )
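        # Worked example for the loss-scaling arguments below (illustrative only; it assumes
        # the usual dynamic loss-scaling scheme, not a confirmed description of GeminiOptimizer):
        # the scale starts at initial_scale, is multiplied by growth_factor after growth_interval
        # consecutive steps without overflow (capped at max_scale), and is multiplied by
        # backoff_factor once `hysteresis` overflows have been seen (floored at min_scale).
        # With the defaults:
        #   2**16 * 2   = 2**17  after 1000 overflow-free steps
        #   2**16 * 0.5 = 2**15  after 2 overflows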
        self.optim_kwargs = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
        )
        self.enable_tensor_parallelism = tp_size > 1
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_overlap = enable_sequence_overlap
        self.verbose = verbose
        self.tp_size = tp_size
        self.extra_dp_size = extra_dp_size
        world_size = dist.get_world_size()
        self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
        assert (
            world_size == (self.tp_size * self.extra_dp_size) * self.zero_size
        ), f"The world size ({world_size}) must be divisible by tp_size * extra_dp_size ({self.tp_size * self.extra_dp_size})."
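        # Worked example of the mesh below (illustrative only; it assumes ProcessGroupMesh
        # assigns coordinates in row-major order over the shape (zero, extra_dp, tp), so the
        # TP axis varies fastest). With world_size = 8, tp_size = 2 and extra_dp_size = 2:
        #   zero_size = 8 // (2 * 2) = 2, and global rank r maps to (r // 4, (r // 2) % 2, r % 2)
        #   TP groups:       {0, 1}, {2, 3}, {4, 5}, {6, 7}
        #   extra DP groups: {0, 2}, {1, 3}, {4, 6}, {5, 7}
        #   ZeRO groups:     {0, 4}, {1, 5}, {2, 6}, {3, 7}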
        self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
        self.zero_group = (
            self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
        )
        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None

        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            enable_tensor_parallelism=self.enable_tensor_parallelism,
            enable_all_optimization=self.enable_all_optimization,
            enable_fused_normalization=self.enable_fused_normalization,
            enable_flash_attention=self.enable_flash_attention,
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=self.enable_sequence_parallelism,
            enable_sequence_overlap=self.enable_sequence_overlap,
        )

    def __del__(self):
        """Destroy the process groups in ProcessGroupMesh"""
        self.pg_mesh.destroy_mesh_process_groups()

    def support_no_sync(self) -> bool:
        return False

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return SUPPORTED_PRECISION

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ["cuda", "npu"]

    def prepare_dataloader(
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        r"""
        Prepare a dataloader for distributed training. The dataset is partitioned across the ZeRO and
        extra data-parallel ranks by a distributed sampler (``torch.utils.data.DistributedSampler`` by default)
        and wrapped in a ``torch.utils.data.DataLoader``.

        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            batch_size (int): The number of samples per batch on each rank.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Random seed used to initialize every dataloader worker. Defaults to 1024.
            drop_last (bool, optional): Set to True to drop the last incomplete batch if the dataset size
                is not divisible by the batch size. If False and the size of the dataset is not divisible by
                the batch size, the last batch will be smaller. Defaults to False.
            pin_memory (bool, optional): Whether to load batches into pinned (page-locked) CPU memory. Defaults to False.
            num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
            distributed_sampler_cls (type, optional): The sampler class to use. It is constructed with ``dataset``,
                ``num_replicas``, ``rank`` and ``shuffle``. Defaults to ``torch.utils.data.DistributedSampler``.
            kwargs (dict): Optional parameters for ``torch.utils.data.DataLoader``. More details can be found in
                `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

        Returns:
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        zero_world_size = self.pg_mesh.size(ZERO_AXIS)
        extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
        zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
        extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
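        # Worked example of the sampler rank below (illustrative only; it assumes the mesh
        # assigns coordinates in row-major order, TP axis fastest). With world_size = 8,
        # tp_size = 2 and extra_dp_size = 2 (so zero_size = 2):
        #   num_replicas = 2 * 2 = 4, sampler rank = zero_rank * 2 + extra_dp_rank
        #   global ranks {0, 1} -> 0, {2, 3} -> 1, {4, 5} -> 2, {6, 7} -> 3
        # Ranks in the same TP group share a sampler rank, so they read the same data shard.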
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        sampler = distributed_sampler_cls(
            dataset,
            num_replicas=zero_world_size * extra_dp_world_size,
            rank=zero_rank * extra_dp_world_size + extra_dp_rank,
            shuffle=shuffle,
        )

        # Deterministic dataloader: every worker seeds NumPy, PyTorch and Python `random` with the same `seed`
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )

    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        params_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
            # convert model to sync bn
            # FIXME(ver217): gemini does not support sync bn
            # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
            # This dtype inconsistency causes an error.
            # We have two possible solutions:
            # 1. Keep batch norm always in fp32. This is hard for gemini, as it uses chunks.
            # 2. Patch sync bn or write a new one. This is relatively easy, but it needs to be tested.
            # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

            # wrap the model with Gemini
            if self.enable_tensor_parallelism:
                shardformer = ShardFormer(self.shard_config)
                model, _ = shardformer.optimize(model)

            model = GeminiDDP(
                model,
                **self.gemini_config,
                zero_group=self.zero_group,
                extra_dp_group=self.extra_dp_group,
                verbose=self.verbose,
            )

        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            optimizer = GeminiOptimizer(
                optimizer,
                model,
                **self.zero_optim_config,
                **self.optim_kwargs,
                tp_group=self.tp_group,
                params_info=params_info,
                verbose=self.verbose,
            )

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return GeminiCheckpointIO()

    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        raise NotImplementedError