Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 01:28:31 +00:00)
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
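For orientation, the sketch below (not part of the diff) shows how the reworked GeminiPlugin is driven through the Booster API after this change. Only the argument names that appear in the diffs further down (`offload_optim_frac`, `pin_memory`, `initial_scale`) are taken from this PR; the toy model, hyperparameters, and launch details are assumptions.

```python
# Minimal sketch (not from the repo) of the post-refactor GeminiPlugin usage.
# Run under torchrun; the toy model and learning rate are placeholders.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# offload_optim_frac=1.0 keeps all optimizer states on CPU, replacing the old
# placement_policy='cpu' + strict_ddp_mode=True combination removed in this PR.
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# Gemini runs the model in fp16 by default, so feed half-precision inputs.
data = torch.randn(8, 1024, dtype=torch.half, device=torch.cuda.current_device())
loss = model(data).sum()
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
```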
@@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```

### Results on 2-GPU

| Plugin | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% |
| torch_ddp_fp16 | 84.7% | 88.8% |
| gemini | 84.0% | 88.4% |

## Benchmark
```
bash benchmark.sh

@@ -14,9 +22,9 @@ bash benchmark.sh

The benchmark currently reports these metrics: CUDA memory usage, throughput, and the number of model parameters. If you have custom metrics, you can add them to benchmark_util.

## Results
### Results

### Bert
#### Bert

| | max cuda mem | throughput(sample/s) | params |
| :-----| -----------: | :--------: | :----: |
@@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb
| gemini | 11.0 GB | 12.9 | 82M |
| low_level_zero | 11.29 G | 14.7 | 82M |

### AlBert
#### AlBert
| | max cuda mem | throughput(sample/s) | params |
| :-----| -----------: | :--------: | :----: |
| ddp | OOM | | |
| ddp_fp16 | OOM | | |
| gemini | 69.39 G | 1.3 | 208M |
| low_level_zero | 56.89 G | 1.4 | 208M |
| low_level_zero | 56.89 G | 1.4 | 208M |

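As a side note on the metrics mentioned in the README hunk above (CUDA memory, throughput, parameter count), here is an illustrative, self-contained sketch of how such numbers are typically gathered. It mirrors the throughput and peak-memory computation that appears in the OPT benchmark diff further down; the helper names and interface are assumptions, not part of the repository.

```python
# Illustrative sketch (not from the repo): gathering the benchmark metrics the
# README mentions. `run_step` stands in for one forward/backward/update iteration.
import time
from typing import Callable, Tuple

import torch


def measure(run_step: Callable[[], None], num_steps: int, batch_size: int, world_size: int) -> Tuple[float, float]:
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_steps):
        run_step()
    torch.cuda.synchronize()
    elapsed = time.time() - start
    samples_per_sec = world_size * num_steps * batch_size / elapsed   # global throughput
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1024**3         # per-GPU peak CUDA memory
    return samples_per_sec, peak_mem_gb


def count_params(model: torch.nn.Module) -> int:
    # number of model parameters, the third metric in the README tables
    return sum(p.numel() for p in model.parameters())
```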
@@ -38,8 +38,8 @@ def move_to_cuda(batch):


@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
eval_splits: List[str], coordinator: DistCoordinator):
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()

@@ -142,7 +142,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)

@@ -208,7 +208,7 @@ def main():
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)

results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
coordinator)

if coordinator.is_master():
print(results)

@@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}

# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
@@ -21,11 +18,8 @@ fi
mkdir -p gemini_logs

torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--tp_degree=${TPDEGREE} \
--model_type=${MODEL_TYPE} \
--batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
${USE_SHARD_INIT} \
--distplan=${DISTPLAN} \
--train_step=${TRAIN_STEP} \
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

@@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
for DISTPLAN in "CAI_Gemini"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1 2; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
done
done
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
bash ./run_gemini.sh
done
done
done

for DISTPLAN in "zero1" "zero2"; do
for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
bash ./run_gemini.sh
done
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
bash ./run_gemini.sh
done
done
done

@@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from functools import partial
from time import time

@@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext

CAI_VERSION = colossalai.__version__

@@ -30,24 +30,6 @@ def parse_args():
default='CAI_Gemini',
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
action='store_true',
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--batch_size",
type=int,
@@ -71,20 +53,6 @@ def parse_args():
return args


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)


class GPTLMLoss(nn.Module):

def __init__(self):

@@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.

Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
# NOTE() a param maybe shared by two modules
if hasattr(param, 'visited'):
continue

# if shard init, then convert param to replica and use the dp-only ProcessGroup
param: ColoParameter = param
param.set_dist_spec(ReplicaSpec())
param.set_process_group(pg)

# shard it w.r.t tp pattern
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
else:
param.set_dist_spec(ReplicaSpec())
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # column slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True


def main():
# version check
# this example is supposed to work for versions greater than 0.2.0
@@ -213,30 +140,13 @@ def main():

# build criterion
criterion = GPTLMLoss()

torch.manual_seed(123)
if args.distplan.startswith("CAI"):
# all param must use the same process group.
world_size = torch.distributed.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

if args.shardinit and args.distplan != "CAI_Gemini":
raise RuntimeError("You can only use shardinit with CAI_Gemini")

ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
# build GPT model
with ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
with ctx:
model = model_builder(args.model_type)(checkpoint=True)

tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP)
# You should notice that v0.1.10 is not compatible with TP degree > 1
if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)

# assign running configurations
if args.distplan == "CAI_ZeRO1":
zero_stage = 1
@@ -254,13 +164,7 @@ def main():

overlap_communication=True,
verbose=True)
elif args.distplan == "CAI_Gemini":
plugin = GeminiPlugin(device=get_current_device(),
placement_policy=args.placement,
pin_memory=True,
strict_ddp_mode=args.tp_degree == 1,
search_range_m=128,
hidden_dim=model.config.n_embd,
gpu_margin_mem_ratio=0.)
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
else:
raise RuntimeError

@@ -1,22 +1,18 @@
import time

import torch
import tqdm
import transformers
from args import parse_benchmark_args
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_benchmark_args
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")

@@ -61,11 +57,11 @@ def main():
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()


# Whether to set limit of memory capacity
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)


# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM(config=config)
@@ -81,11 +77,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
@@ -96,18 +88,18 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer)


SEQ_LEN = 1024
VOCAB_SIZE = 50257

# Start training.
logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())


torch.cuda.synchronize()
model.train()
start_time = time.time()


for _ in range(args.max_train_steps):

input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
@@ -119,18 +111,19 @@ def main():
torch.cuda.synchronize()
progress_bar.update(1)

# Compute Statistics

# Compute Statistics
end_time = time.time()
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

logger.info(f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])

logger.info(
f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])


if __name__ == "__main__":

@@ -1,25 +1,20 @@
import time

import torch
import datasets
import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@@ -30,18 +25,18 @@ def move_to_cuda(batch, device):


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):


torch.cuda.synchronize()
model.train()

with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:


for batch in pbar:

# Forward
optimizer.zero_grad()
batch = move_to_cuda(batch, torch.cuda.current_device())


outputs = model(use_cache=False, **batch)
loss = outputs['loss']

@@ -72,7 +67,7 @@ def main():
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()


# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
@@ -88,43 +83,35 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])

# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer)
dataloader = plugin.prepare_dataloader(dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=netflix_collator)


# Set optimizer
optimizer = HybridAdam(model.parameters(),
lr=(args.learning_rate * world_size),
weight_decay=args.weight_decay)
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

# Set lr scheduler
total_steps = len(dataloader) * args.num_epoch
num_warmup_steps = int(args.warmup_ratio * total_steps)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=len(dataloader) * args.num_epoch
)
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=len(dataloader) * args.num_epoch)

# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=dataloader,
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=dataloader,
lr_scheduler=lr_scheduler)

# Start finetuning

@@ -1,5 +1,5 @@
import gzip
import random
from contextlib import nullcontext
from functools import partial
from time import time

@@ -8,20 +8,17 @@ import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version

from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

# constants

@@ -44,23 +41,10 @@ def parse_args():
help="The distributed plan [colossalai, pytorch].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
type=bool,
default=False,
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
"--offload_optim_frac",
type=float,
default=1.0,
help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
)
parser.add_argument('-p',
'--plugin',
@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
return total_numel


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'net.0' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_q' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_kv' in mn:
split_param_row_tp1d(param, pg) # row slice
elif 'to_out' in mn:
split_param_row_tp1d(param, pg) # row slice
elif '1.1' in mn:
split_param_col_tp1d(param, pg) # column slice
elif '1.2' in mn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True


args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"plugin: {plugin}")
booster = Booster(plugin=plugin, **booster_kwargs)

default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()

with ctx:
model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

pg = default_pg
tensor_parallelize(model, pg)

# optimizer

optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)