[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
This commit is contained in:
Hongxin Liu
2023-08-24 09:29:25 +08:00
committed by GitHub
parent 285fe7ba71
commit 27061426f7
82 changed files with 1008 additions and 4036 deletions

View File

@@ -1,5 +1,5 @@
import gzip
import random
from contextlib import nullcontext
from functools import partial
from time import time
@@ -8,20 +8,17 @@ import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device
# constants
@@ -44,23 +41,10 @@ def parse_args():
help="The distributed plan [colossalai, pytorch].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
type=bool,
default=False,
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
"--offload_optim_frac",
type=float,
default=1.0,
help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
)
parser.add_argument('-p',
'--plugin',
@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
return total_numel
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'net.0' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_q' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_kv' in mn:
split_param_row_tp1d(param, pg) # row slice
elif 'to_out' in mn:
split_param_row_tp1d(param, pg) # row slice
elif '1.1' in mn:
split_param_col_tp1d(param, pg) # column slice
elif '1.2' in mn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"plugin: {plugin}")
booster = Booster(plugin=plugin, **booster_kwargs)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()
with ctx:
model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
tensor_parallelize(model, pg)
# optimizer
optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)