[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
Author: Hongxin Liu
Date: 2023-08-24 09:29:25 +08:00
Committed by: GitHub
Parent: 285fe7ba71
Commit: 27061426f7

82 changed files with 1008 additions and 4036 deletions
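The headline feature is the static placement policy plus a much smaller `GeminiPlugin` surface; the diffs below mostly migrate the examples onto it. As orientation, a minimal training setup against the reworked plugin might look like the sketch below. `offload_optim_frac`, `pin_memory`, and `initial_scale=2**5` appear verbatim in the diffs; `placement_policy='static'` and `offload_param_frac` are assumptions inferred from the PR title and the docstring commit.

```python
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# With a static policy, tensors stay where the fractions put them instead of
# being migrated at runtime by the 'auto' policy.
# 'static' and offload_param_frac are assumed names (PR title / docstring
# commit); offload_optim_frac, pin_memory and initial_scale come from the diffs.
plugin = GeminiPlugin(
    placement_policy='static',    # assumed new policy name
    offload_optim_frac=1.0,       # keep all optimizer states on CPU (seen in the diffs)
    offload_param_frac=0.0,       # assumed: keep parameters on GPU
    pin_memory=True,
    initial_scale=2**5,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```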

---

@@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
 bash test_ci.sh
 ```
+### Results on 2-GPU
+
+| Plugin | Accuracy | F1-score |
+| -------------- | -------- | -------- |
+| torch_ddp | 84.4% | 88.6% |
+| torch_ddp_fp16 | 84.7% | 88.8% |
+| gemini | 84.0% | 88.4% |
+
 ## Benchmark
 ```
 bash benchmark.sh
@@ -14,9 +22,9 @@ bash benchmark.sh
 Now include these metrics in benchmark: CUDA mem occupy, throughput and the number of model parameters. If you have custom metrics, you can add them to benchmark_util.
 
-## Results
+### Results
 
-### Bert
+#### Bert
 
 | | max cuda mem | throughput(sample/s) | params |
 | :-----| -----------: | :--------: | :----: |
@@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb
 | gemini | 11.0 GB | 12.9 | 82M |
 | low_level_zero | 11.29 G | 14.7 | 82M |
 
-### AlBert
+#### AlBert
 
 | | max cuda mem | throughput(sample/s) | params |
 | :-----| -----------: | :--------: | :----: |
 | ddp | OOM | | |
 | ddp_fp16 | OOM | | |
 | gemini | 69.39 G | 1.3 | 208M |
-| low_level_zero | 56.89 G | 1.4 | 208M |
+| low_level_zero | 56.89 G | 1.4 | 208M |

---

@@ -38,8 +38,8 @@ def move_to_cuda(batch):
 @torch.no_grad()
-def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
-                   eval_splits: List[str], coordinator: DistCoordinator):
+def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
+                   task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
     metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
     model.eval()
@@ -142,7 +142,7 @@ def main():
     if args.plugin.startswith('torch_ddp'):
         plugin = TorchDDPPlugin()
     elif args.plugin == 'gemini':
-        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
+        plugin = GeminiPlugin(initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
@@ -208,7 +208,7 @@ def main():
         train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
     results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
-                             coordinator)
+                             coordinator)
     if coordinator.is_master():
         print(results)
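Side note on the `GeminiPlugin(initial_scale=2**5)` call in the hunk above: the dropped arguments (`placement_policy='cuda'`, `strict_ddp_mode=True`) appear to have become default behavior of the refactored plugin, so only the fp16 loss-scaling seed survives. A minimal sketch of what that remaining knob does, assuming standard dynamic loss scaling:

```python
from colossalai.booster.plugin import GeminiPlugin

# initial_scale=2**5 seeds the dynamic fp16 loss scaler at 32; the scaler then
# grows or shrinks at runtime when gradients overflow, so no placement or DDP
# arguments are needed for the common single-process-group case (assumed default).
plugin = GeminiPlugin(initial_scale=2**5)
```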

---

@@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
 # The following options only valid when DISTPLAN="colossalai"
 export GPUNUM=${GPUNUM:-1}
-export TPDEGREE=${TPDEGREE:-1}
-export PLACEMENT=${PLACEMENT:-"cpu"}
-export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
 export BATCH_SIZE=${BATCH_SIZE:-16}
 export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
 export TRAIN_STEP=${TRAIN_STEP:-10}
@@ -21,11 +18,8 @@ fi
 mkdir -p gemini_logs
 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
---tp_degree=${TPDEGREE} \
 --model_type=${MODEL_TYPE} \
 --batch_size=${BATCH_SIZE} \
---placement=${PLACEMENT} \
-${USE_SHARD_INIT} \
 --distplan=${DISTPLAN} \
 --train_step=${TRAIN_STEP} \
 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

---

@@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
   for DISTPLAN in "CAI_Gemini"; do
     for BATCH_SIZE in 2; do
       for GPUNUM in 1 4; do
-        for TPDEGREE in 1 2; do
-          if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
-            continue
-          fi
-          for PLACEMENT in "cpu" "auto"; do
-            MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
-            bash ./run_gemini.sh
-          done
-        done
+        MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
+        bash ./run_gemini.sh
       done
     done
   done
 
-  for DISTPLAN in "zero1" "zero2"; do
+  for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
     for BATCH_SIZE in 2; do
       for GPUNUM in 1 4; do
-        for TPDEGREE in 1; do
-          if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
-            continue
-          fi
-          MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
-          bash ./run_gemini.sh
-        done
+        MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
+        bash ./run_gemini.sh
       done
     done
   done

---

@@ -1,4 +1,5 @@
 import os
+from contextlib import nullcontext
 from functools import partial
 from time import time
@@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
@@ -30,24 +30,6 @@ def parse_args():
         default='CAI_Gemini',
         help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
-    parser.add_argument(
-        "--tp_degree",
-        type=int,
-        default=1,
-        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
-    )
-    parser.add_argument(
-        "--placement",
-        type=str,
-        default='cpu',
-        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
-    )
-    parser.add_argument(
-        "--shardinit",
-        action='store_true',
-        help=
-        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
-    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -71,20 +53,6 @@ def parse_args():
     return args
 
-# Parameter Sharding Strategies for Tensor Parallelism
-def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
-    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    param.set_tensor_spec(*spec)
-
-def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
-    split_param_single_dim_tp1d(0, param, pg)
-
-def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
-    split_param_single_dim_tp1d(-1, param, pg)
 
 class GPTLMLoss(nn.Module):
 
     def __init__(self):
@@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
     print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
 
-# Tensor Parallel
-def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
-    """tensor_parallelize
-    Sharding the Model Parameters.
-
-    Args:
-        model (torch.nn.Module): a torch module to be sharded
-    """
-    for mn, module in model.named_modules():
-        for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by two modules
-            if hasattr(param, 'visited'):
-                continue
-
-            # if shard init, then convert param to replica and use the dp-only ProcessGroup
-            param: ColoParameter = param
-            param.set_dist_spec(ReplicaSpec())
-            param.set_process_group(pg)
-
-            # shard it w.r.t tp pattern
-            if 'mlp.c_fc' in mn:
-                if 'weight' in pn or 'bias' in pn:
-                    split_param_col_tp1d(param, pg)    # column slice
-                    # keep the shape of the output from c_fc
-                    param.compute_spec.set_output_replicate(False)
-                else:
-                    param.set_dist_spec(ReplicaSpec())
-            elif 'mlp.c_proj' in mn:
-                if 'weight' in pn:
-                    split_param_row_tp1d(param, pg)    # row slice
-                else:
-                    param.set_dist_spec(ReplicaSpec())
-            elif 'wte' in mn or 'wpe' in mn:
-                split_param_col_tp1d(param, pg)    # column slice
-            elif 'c_attn' in mn or 'c_proj' in mn:
-                split_param_col_tp1d(param, pg)    # column slice
-            else:
-                param.set_dist_spec(ReplicaSpec())
-            param.visited = True
 
 def main():
     # version check
     # this example is supposed to work for versions greater than 0.2.0
@@ -213,30 +140,13 @@ def main():
     # build criterion
     criterion = GPTLMLoss()
 
     torch.manual_seed(123)
     if args.distplan.startswith("CAI"):
-        # all param must use the same process group.
-        world_size = torch.distributed.get_world_size()
-        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
-        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
-
-        if args.shardinit and args.distplan != "CAI_Gemini":
-            raise RuntimeError("You can only use shardinit with CAI_Gemini")
-
+        ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
         # build GPT model
-        with ColoInitContext(device=get_current_device(),
-                             dtype=torch.half,
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
+        with ctx:
             model = model_builder(args.model_type)(checkpoint=True)
 
-        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
-        # Tensor Parallelism (TP)
-        # You should notice that v0.1.10 is not compatible with TP degree > 1
-        if args.tp_degree > 1:
-            tensor_parallelize(model, tp_pg)
-
     # assign running configurations
     if args.distplan == "CAI_ZeRO1":
         zero_stage = 1
@@ -254,13 +164,7 @@ def main():
                                     overlap_communication=True,
                                     verbose=True)
     elif args.distplan == "CAI_Gemini":
-        plugin = GeminiPlugin(device=get_current_device(),
-                              placement_policy=args.placement,
-                              pin_memory=True,
-                              strict_ddp_mode=args.tp_degree == 1,
-                              search_range_m=128,
-                              hidden_dim=model.config.n_embd,
-                              gpu_margin_mem_ratio=0.)
+        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
     else:
         raise RuntimeError
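The pattern above recurs in the remaining files: `ColoInitContext` (which built ColoTensors with distribution specs) gives way to `LazyInitContext`, which defers parameter materialization until Gemini shards the model. A condensed, self-contained sketch of the pattern; the GPT-2 stand-in model is illustrative, and it assumes `colossalai.launch_from_torch` has already set up distributed state:

```python
from contextlib import nullcontext

from transformers import GPT2Config, GPT2LMHeadModel  # illustrative model

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

use_gemini = True

# Under LazyInitContext, parameters are materialized lazily, so the full fp32
# model never has to fit on a single device before Gemini shards it.
ctx = LazyInitContext(default_device=get_current_device()) if use_gemini else nullcontext()
with ctx:
    model = GPT2LMHeadModel(GPT2Config())

plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
booster = Booster(plugin=plugin)
optimizer = HybridAdam(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```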

---

@@ -1,22 +1,18 @@
 import time
 
 import torch
+import tqdm
 import transformers
+from args import parse_benchmark_args
 from transformers import AutoConfig, OPTForCausalLM
 from transformers.utils.versions import require_version
-import tqdm
 
 import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
-from args import parse_benchmark_args
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 
 require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@@ -61,11 +57,11 @@
         transformers.utils.logging.set_verbosity_info()
     else:
         transformers.utils.logging.set_verbosity_error()
 
     # Whether to set limit of memory capacity
     if args.mem_cap > 0:
         colo_memory_cap(args.mem_cap)
 
     # Build OPT model
     config = AutoConfig.from_pretrained(args.model_name_or_path)
     model = OPTForCausalLM(config=config)
@@ -81,11 +77,7 @@ def main():
     if args.plugin.startswith('torch_ddp'):
         plugin = TorchDDPPlugin()
     elif args.plugin == 'gemini':
-        plugin = GeminiPlugin(device=get_current_device(),
-                              placement_policy='cpu',
-                              pin_memory=True,
-                              strict_ddp_mode=True,
-                              initial_scale=2**5)
+        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
@@ -96,18 +88,18 @@ def main():
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
     model, optimizer, _, _, _ = booster.boost(model, optimizer)
 
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
 
     # Start training.
     logger.info(f"Start testing", ranks=[0])
     progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
 
     torch.cuda.synchronize()
     model.train()
     start_time = time.time()
 
     for _ in range(args.max_train_steps):
         input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
@@ -119,18 +111,19 @@ def main():
         torch.cuda.synchronize()
         progress_bar.update(1)
 
-    # Compute Statistics
+    # Compute Statistics
     end_time = time.time()
     throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
     max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
 
-    logger.info(f"Testing finished, "
-                f"batch size per gpu: {args.batch_size}, "
-                f"plugin: {args.plugin}, "
-                f"throughput: {throughput}, "
-                f"maximum memory usage per gpu: {max_mem}.",
-                ranks=[0])
+    logger.info(
+        f"Testing finished, "
+        f"batch size per gpu: {args.batch_size}, "
+        f"plugin: {args.plugin}, "
+        f"throughput: {throughput}, "
+        f"maximum memory usage per gpu: {max_mem}.",
+        ranks=[0])
 
 if __name__ == "__main__":
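The swap from `placement_policy='cpu'` to `offload_optim_frac=1.0` in this file is worth spelling out: a coarse policy string becomes a continuous knob for how much of the optimizer state lives on the host. A sketch of the intent; the fractional value other than 1.0 is an assumption about the new API:

```python
from colossalai.booster.plugin import GeminiPlugin

# offload_optim_frac=1.0 keeps every optimizer-state shard in (pinned) host
# memory, reproducing the optimizer-offload part of the old 'cpu' policy.
plugin_full_offload = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)

# Assumed: intermediate fractions trade host memory for GPU memory, e.g. keep
# half of the optimizer states on GPU when there is headroom.
plugin_half_offload = GeminiPlugin(offload_optim_frac=0.5, pin_memory=True, initial_scale=2**5)
```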

---

@@ -1,25 +1,20 @@
 import time
 
-import torch
 import datasets
+import torch
 import transformers
-from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
-from transformers import get_linear_schedule_with_warmup
-from transformers.utils.versions import require_version
+from args import parse_demo_args
+from data import NetflixDataset, netflix_collator
 from tqdm import tqdm
+from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
+from transformers.utils.versions import require_version
 
 import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
-from args import parse_demo_args
-from data import NetflixDataset, netflix_collator
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 
 require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
 require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@@ -30,18 +25,18 @@ def move_to_cuda(batch, device):
 def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
     torch.cuda.synchronize()
     model.train()
 
     with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
         for batch in pbar:
             # Forward
             optimizer.zero_grad()
             batch = move_to_cuda(batch, torch.cuda.current_device())
 
             outputs = model(use_cache=False, **batch)
             loss = outputs['loss']
@@ -72,7 +67,7 @@ def main():
     else:
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
 
     # Build OPT model
     config = AutoConfig.from_pretrained(args.model_name_or_path)
     model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
@@ -88,43 +83,35 @@ def main():
     if args.plugin.startswith('torch_ddp'):
         plugin = TorchDDPPlugin()
     elif args.plugin == 'gemini':
-        plugin = GeminiPlugin(device=get_current_device(),
-                              placement_policy='cpu',
-                              pin_memory=True,
-                              strict_ddp_mode=True,
-                              initial_scale=2**5)
+        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
     # Prepare tokenizer and dataloader
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     dataset = NetflixDataset(tokenizer)
     dataloader = plugin.prepare_dataloader(dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True,
                                            collate_fn=netflix_collator)
 
     # Set optimizer
-    optimizer = HybridAdam(model.parameters(),
-                           lr=(args.learning_rate * world_size),
-                           weight_decay=args.weight_decay)
+    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
 
     # Set lr scheduler
     total_steps = len(dataloader) * args.num_epoch
     num_warmup_steps = int(args.warmup_ratio * total_steps)
-    lr_scheduler = get_linear_schedule_with_warmup(
-        optimizer,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=len(dataloader) * args.num_epoch
-    )
+    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                   num_warmup_steps=num_warmup_steps,
+                                                   num_training_steps=len(dataloader) * args.num_epoch)
 
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
-                                                                  optimizer=optimizer,
-                                                                  dataloader=dataloader,
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
+                                                                  optimizer=optimizer,
+                                                                  dataloader=dataloader,
                                                                   lr_scheduler=lr_scheduler)
 
     # Start finetuning
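One detail in the scheduler block above: `total_steps` and the inline `len(dataloader) * args.num_epoch` compute the same quantity, so `num_training_steps` could equally be `total_steps`. The warmup arithmetic itself is simple; a worked example with assumed values:

```python
# Assumed values, for illustration only.
steps_per_epoch = 250    # stands in for len(dataloader)
num_epoch = 4
warmup_ratio = 0.1

total_steps = steps_per_epoch * num_epoch             # 1000 optimizer steps
num_warmup_steps = int(warmup_ratio * total_steps)    # 100 warmup steps

# get_linear_schedule_with_warmup then raises the lr linearly from 0 over the
# first 100 steps and decays it linearly back to 0 over the remaining 900.
print(total_steps, num_warmup_steps)    # -> 1000 100
```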

---

@@ -1,5 +1,5 @@
 import gzip
 import random
+from contextlib import nullcontext
 from functools import partial
 from time import time
@@ -8,20 +8,17 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 import tqdm
-from packaging import version
-from colossalai.nn import HybridAdam
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.utils.data import DataLoader, Dataset
 
 import colossalai
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
-from colossalai.utils import MultiTimer, get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.lazy import LazyInitContext
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn import HybridAdam
+from colossalai.utils import get_current_device
# constants
@@ -44,23 +41,10 @@
         help="The distributed plan [colossalai, pytorch].",
     )
     parser.add_argument(
-        "--tp_degree",
-        type=int,
-        default=1,
-        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
-    )
-    parser.add_argument(
-        "--placement",
-        type=str,
-        default='cpu',
-        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
-    )
-    parser.add_argument(
-        "--shardinit",
-        type=bool,
-        default=False,
-        help=
-        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
+        "--offload_optim_frac",
+        type=float,
+        default=1.0,
+        help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
     )
     parser.add_argument('-p',
                         '--plugin',
@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
     return total_numel
 
-# Parameter Sharding Strategies for Tensor Parallelism
-def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
-    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    param.set_tensor_spec(*spec)
-
-def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
-    split_param_single_dim_tp1d(0, param, pg)
-
-def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
-    split_param_single_dim_tp1d(-1, param, pg)
-
-# Tensor Parallel
-def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
-    """tensor_parallelize
-    Sharding the Model Parameters.
-
-    Args:
-        model (torch.nn.Module): a torch module to be sharded
-    """
-    for mn, module in model.named_modules():
-        for pn, param in module.named_parameters(recurse=False):
-            if hasattr(param, 'visited'):
-                continue
-            param.set_dist_spec(ReplicaSpec())
-            if 'net.0' in mn:
-                split_param_col_tp1d(param, pg)    # column slice
-            elif 'to_q' in mn:
-                split_param_col_tp1d(param, pg)    # column slice
-            elif 'to_kv' in mn:
-                split_param_row_tp1d(param, pg)    # row slice
-            elif 'to_out' in mn:
-                split_param_row_tp1d(param, pg)    # row slice
-            elif '1.1' in mn:
-                split_param_col_tp1d(param, pg)    # column slice
-            elif '1.2' in mn:
-                split_param_row_tp1d(param, pg)    # row slice
-            else:
-                param.set_dist_spec(ReplicaSpec())
-            param.visited = True
 
 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
     raise TypeError(f"{args.distplan} is error")
@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
     if args.plugin.startswith('torch_ddp'):
         plugin = TorchDDPPlugin()
     elif args.plugin == 'gemini':
-        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
-        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+        plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"plugin: {plugin}")
 
     booster = Booster(plugin=plugin, **booster_kwargs)
 
-    default_pg = ProcessGroup(tp_degree=args.tp_degree)
-    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
-    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
+    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()
 
     with ctx:
         model = PaLM(num_tokens=50304, dim=4096, depth=64)
         model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
 
-    pg = default_pg
-    tensor_parallelize(model, pg)
-
     # optimizer
     optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
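With `--placement` and `--shardinit` gone, the PaLM script exposes the new fraction directly as `--offload_optim_frac`. A reconstructed slice of that wiring, runnable on its own; the 0.5 value is an illustrative choice, not a recommendation from the diff:

```python
import argparse

from colossalai.booster.plugin import GeminiPlugin

# Argument definition copied from the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--offload_optim_frac",
    type=float,
    default=1.0,
    help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
)
args = parser.parse_args(["--offload_optim_frac", "0.5"])    # illustrative value

# The flag is forwarded to the plugin unchanged, as in the script above.
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
print(f"plugin: {plugin}")
```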