mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
Merge branch 'main' into feature/shardformer
@@ -17,6 +17,13 @@ def data_gen_fn():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


+def data_gen_for_pretrain():
+    inputs = data_gen_fn()
+    inputs['labels'] = inputs['input_ids'].clone()
+    inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
+    return inputs
+
+
 output_transform_fn = lambda x: x

 config = transformers.AlbertConfig(embedding_size=128,
@@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128,
                                    intermediate_size=256)

 model_zoo.register(name='transformers_albert',
-                   model_fn=lambda: transformers.AlbertModel(config),
+                   model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
                    data_gen_fn=data_gen_fn,
                    output_transform_fn=output_transform_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_albert_for_pretraining',
                    model_fn=lambda: transformers.AlbertForPreTraining(config),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
+                   data_gen_fn=data_gen_for_pretrain,
+                   output_transform_fn=lambda x: dict(loss=x.loss),
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_albert_for_masked_lm',
                    model_fn=lambda: transformers.AlbertForMaskedLM(config),
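
For context, each model_zoo.register entry packs a model factory, a data generator, and an output transform, so generic tests can drive any registered model the same way. A toy, self-contained illustration of that contract in plain PyTorch (the registry dict here is hypothetical, not the real model_zoo):

import torch
import torch.nn as nn

# Hypothetical mini-registry mirroring the (model_fn, data_gen_fn, output_transform_fn) layout.
registry = {
    'toy_linear': (lambda: nn.Linear(4, 2),
                   lambda: dict(input=torch.randn(3, 4)),
                   lambda out: dict(loss=out.mean())),
}

model_fn, data_gen_fn, output_transform_fn = registry['toy_linear']
model = model_fn()
outputs = model(**data_gen_fn())
output_transform_fn(outputs)['loss'].backward()

The pretraining entry above follows the same shape: its data generator adds labels and sentence_order_label, which makes AlbertForPreTraining return a loss, and its transform reduces the output to dict(loss=...).
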
@@ -113,6 +113,7 @@ def data_gen_for_qa():
 output_transform_fn = lambda x: x

+# define loss function

 loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
                                                                 ))
 loss_fn = lambda x: x.loss
@@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128,

 # register the BERT variants
 model_zoo.register(name='transformers_bert',
-                   model_fn=lambda: transformers.BertModel(config),
+                   model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
                    data_gen_fn=data_gen,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_for_bert_model,
@@ -57,6 +57,12 @@ def data_gen_for_sequence_classification():
     return data


+def date_gen_for_double_heads():
+    data = data_gen_for_lm()
+    data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64)
+    return data
+
+
 # define output transform function
 output_transform_fn = lambda x: x

@@ -94,8 +100,8 @@ model_zoo.register(name='transformers_gpt_lm',
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_double_heads',
                    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-                   data_gen_fn=data_gen_for_lm,
-                   output_transform_fn=output_transform_fn,
+                   data_gen_fn=date_gen_for_double_heads,
+                   output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_question_answering',
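
The double-heads entry changes two pieces in tandem: the data generator adds mc_labels (one multiple-choice target index per sample), and the output transform folds the model's two losses into the single scalar that downstream loss_fn/backward code expects. A plain-PyTorch toy of that folding (illustrative values only):

import torch

mc_labels = torch.zeros(2, dtype=torch.int64)        # as built in date_gen_for_double_heads
lm_loss = torch.tensor(2.0, requires_grad=True)      # stands in for outputs.loss
mc_loss = torch.tensor(0.5, requires_grad=True)      # stands in for outputs.mc_loss
total = lm_loss + mc_loss                            # what dict(loss=x.loss + x.mc_loss) packs
total.backward()
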
@@ -12,19 +12,16 @@ from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero import ColoInitContext
 from tests.kit.model_zoo import model_zoo


 def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
     try:
-        if init_method == 'colo':
-            ctx = ColoInitContext()
-        elif init_method == 'lazy':
+        if init_method == 'lazy':
             ctx = LazyInitContext()
         else:
             ctx = nullcontext()
-        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
+        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
         with ctx:
             model = model_fn()
@@ -50,6 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
         optimizer.step()

     except Exception as e:
+        # raise e
         return repr(e)


@@ -57,8 +55,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
 # @parameterize('init_method', ['lazy', 'none', 'colo'])


+@parameterize('subset', ['torchvision', 'transformers', 'diffusers'])
 @parameterize('init_method', ['none'])
-def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
+def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True):
     """check gemini plugin over model zoo

     Args:
@@ -71,29 +70,23 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     passed_models = []
     failed_info = {}    # (model_name, error) pair

-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
         # These models lead to CUDA error
         if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
-                    'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
+                    'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
+                    'torchvision_convnext_base'):
             continue
         # These models are not compatible with gemini
         if name in [
-                'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet',
-                'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0',
-                'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer',
-                'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron',
-                'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch',
-                'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small',
-                'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2',
-                'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert',
-                'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining',
-                'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
-                'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model',
-                'transformers_vit', 'transformers_vit_for_masked_image_modeling',
-                'transformers_vit_for_image_classification', 'transformers_chatglm',
-                'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
-                'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
-                'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
+                'timm_convit',
+                'timm_dm_nfnet',
+                'torchvision_vit_b_16',
+                'transformers_t5',
+                'transformers_t5_for_conditional_generation',
+                'transformers_t5_encoder_model',    # does not support apex rmsnorm
+                'transformers_chatglm',
+                'transformers_sam',
+                'transformers_vit'
         ]:
             continue

@@ -105,7 +98,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
         ]:
             continue
         err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
-        torch.cuda.empty_cache()

         if err is None:
             passed_models.append(name)
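
Putting the pieces of run_fn together, here is a minimal end-to-end sketch of the Booster/GeminiPlugin flow this test exercises. Assumptions: one GPU is available, a toy nn.Linear stands in for a zoo model, and the import paths for Booster/GeminiPlugin follow the library layout of this era (they are not shown in the hunks above):

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port)
    plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
    booster = Booster(plugin=plugin)

    model = nn.Linear(8, 8).cuda()
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    criterion = lambda x: x.mean()
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    loss = criterion(model(torch.randn(4, 8, device='cuda')))
    booster.backward(loss, optimizer)
    optimizer.step()


if __name__ == '__main__':
    spawn(run_dist, 1)
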
@@ -18,12 +18,45 @@ from colossalai.testing import (
 )
 from tests.kit.model_zoo import model_zoo

+MODEL_PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0
+    },    # zero3
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.5
+    },    # zero3-half
+]
+
+OPTIM_PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 1.0
+    },    # zero2-offload
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.5
+    },    # zero2-offload-half
+]
+

 @clear_cache_before_run()
-@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS)
 @parameterize('model_name', ['transformers_bert_for_sequence_classification'])
 @parameterize('use_safetensors', [False, True])
-def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
     from transformers import BertForSequenceClassification
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
@@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
     pretrained_path = os.path.join(tempdir, 'pretrained')
     bert_model.config.save_pretrained(save_directory=pretrained_path)

-    plugin = GeminiPlugin(placement_policy=placement_policy)
+    plugin = GeminiPlugin(**placement_config)
     booster = Booster(plugin=plugin)
     bert_model, _, _, _, _ = booster.boost(bert_model)
     model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -46,19 +79,19 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
     dist.barrier()

     new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-    check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
+    check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
                            new_bert_model.state_dict(), False)


 @clear_cache_before_run()
-@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS)
 @parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 @parameterize('size_per_shard', [32])
-def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
     booster = Booster(plugin=plugin)

     model = model_fn()
@@ -87,12 +120,11 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
     dist.barrier()

     booster.load_model(new_model, model_ckpt_path)
-    check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
-                           new_model.unwrap().state_dict(only_rank_0=False), False)
+    check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
-                           new_optimizer.unwrap().state_dict(only_rank_0=False), False)
+    check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False),
+                           False)

     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
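
Reading the new configs: with placement_policy='static', shard_param_frac=0.0 keeps parameters replicated (ZeRO-2-like), 1.0 shards them fully (ZeRO-3-like), and offload_optim_frac moves that fraction of optimizer state to CPU. A hypothetical instantiation inside an already-launched process, using only the keyword arguments that appear in this diff:

# 'zero3-half with half the optimizer state offloaded' -- a combination not in the
# test matrix above, shown purely as an assumed example of the same knobs.
plugin = GeminiPlugin(placement_policy='static',
                      shard_param_frac=0.5,
                      offload_optim_frac=0.5,
                      precision="fp16",
                      initial_scale=2**14)
booster = Booster(plugin=plugin)
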
@@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
     new_booster.load_model(new_model, model_ckpt_path, strict=True)

     # Add prefix to get aligned with pytorch parameter names.
-    check_state_dict_equal(
-        model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
-        new_model.state_dict(), False)
+    check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+                           new_model.state_dict(), False)

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)

     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
@@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
     new_booster.load_model(new_model, model_ckpt_path, strict=True)

     # Add prefix to get aligned with pytorch parameter names.
-    check_state_dict_equal(
-        new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
-        model.state_dict(), False)
+    check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+                           model.state_dict(), False)

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     old_state_dict = optimizer.state_dict()
-    new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
+    new_state_dict = new_optimizer.state_dict(only_rank_0=False)

     # Comparison of param_groups needs special care here,
     # since not all hyperparameters in Adam are used by HybridAdam
@@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
     for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
         for k in hyperparameters_to_examine:
             assert k in old_group and k in new_group, \
-                f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
+                f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
             assert old_group[k] == new_group[k]
     check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)

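
On the 'module.module.' prefix: the boosted model nests the original module inside wrapper layers, so its state-dict keys would otherwise not line up with a checkpoint saved from the bare PyTorch model; the prefix argument re-aligns them. A self-contained toy showing how nesting changes key names (plain PyTorch; the wrappers here are stand-ins, not the real Gemini classes):

import torch.nn as nn

plain = nn.Linear(2, 2)
wrapped = nn.Sequential(nn.Sequential(plain))    # stand-in for two wrapper layers
print(list(plain.state_dict()))                  # ['weight', 'bias']
print(list(wrapped.state_dict()))                # ['0.0.weight', '0.0.bias']
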
@@ -16,19 +16,21 @@ from colossalai.testing import (
 )


 # stage 1 and 2 process the optimizer/model the same way
 # only test 2 is fine
 @clear_cache_before_run()
 @parameterize('stage', [2])
 @parameterize('shard', [True, False])
-def check_low_level_zero_checkpointIO(stage: int, shard: bool):
-    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
+@parameterize('offload', [False, True])
+def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
+    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
     booster = Booster(plugin=plugin)
     model = resnet18()
     criterion = lambda x: x.mean()
     optimizer = HybridAdam((model.parameters()), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

-    x = torch.randn(4, 3, 224, 224)
-    x = x.to('cuda')
+    x = torch.randn(1, 3, 224, 224, device='cuda')
     output = model(x)
     loss = criterion(output)
     booster.backward(loss, optimizer)
@@ -50,15 +52,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
     check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)


 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
     check_low_level_zero_checkpointIO()
+    torch.cuda.empty_cache()


 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_low_level_zero_checkpointIO():
     spawn(run_dist, 2)

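
One detail worth noting: after boosting, optimizer is a wrapper object, which is why the updated assertion compares the inner optimizers via .optim.state_dict(). The save side of the round trip is not visible in this hunk; a sketch of it, with the save_* signatures assumed from their load_* counterparts shown above and hypothetical checkpoint paths:

# Assumed save-side counterparts (sketch only; file names hypothetical):
booster.save_model(model, model_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
booster.load_model(new_model, model_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
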
@@ -1,104 +0,0 @@
-import os
-from pathlib import Path
-
-import pytest
-import torch
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
-
-import colossalai
-from colossalai.amp import AMP_TYPE
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn import CrossEntropyLoss
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.trainer import Trainer, hooks
-from colossalai.utils import get_dataloader
-
-disable_existing_loggers()
-BATCH_SIZE = 4
-NUM_EPOCHS = 10
-WARMUP_EPOCHS = 5
-CONFIG = dict(NUM_MICRO_BATCHES=2,
-              parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
-
-
-def run_trainer(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    disable_existing_loggers()
-    # get logger
-    logger = get_dist_logger()
-
-    pipelinable = PipelinableContext()
-    try:
-        from titans.model.vit import vit_tiny_patch4_32
-    except ImportError:
-        logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
-        logger.warning('please install titan from https://github.com/hpcaitech/Titans')
-        return
-    with pipelinable:
-        model = vit_tiny_patch4_32()
-    pipelinable.to_layer_list()
-    pipelinable.policy = "uniform"
-    model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
-
-    # create dataloaders
-    root = Path(os.environ['DATA'])
-    transform_train = transforms.Compose([
-        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
-        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
-        transforms.ToTensor(),
-        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-    ])
-    train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
-    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
-
-    # create loss function
-    criterion = CrossEntropyLoss(label_smoothing=0.1)
-
-    # create optimizer
-    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
-
-    # create lr scheduler
-    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
-
-    # initialize
-    engine, train_dataloader, *_ = colossalai.initialize(model=model,
-                                                         optimizer=optimizer,
-                                                         criterion=criterion,
-                                                         train_dataloader=train_dataloader)
-
-    engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)
-
-    logger = get_dist_logger()
-
-    trainer = Trainer(engine=engine, logger=logger)
-
-    hook_list = [
-        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
-    ]
-
-    trainer.fit(train_dataloader=train_dataloader,
-                max_steps=2,
-                epochs=NUM_EPOCHS,
-                hooks=hook_list,
-                display_progress=True)
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_hybrid_parallel():
-    spawn(run_trainer, 2)
-    disable_existing_loggers()
-
-
-if __name__ == '__main__':
-    test_hybrid_parallel()
@@ -1,92 +0,0 @@
-import os
-import random
-from typing import Callable, Type
-
-import numpy as np
-import pytest
-import torch
-import torch.distributed as dist
-
-import colossalai
-from colossalai.nn.parallel import ColoDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ColoInitContext, ZeroDDP
-from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
-from colossalai.zero.gemini.gemini_mgr import GeminiManager
-
-
-def set_seed(seed):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-
-
-def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    pg = ProcessGroup()
-    return ColoDDP(module, process_group=pg)
-
-
-def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
-    chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
-    chunk_manager = ChunkManager(chunk_config)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
-
-
-class Net(torch.nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.fc1 = torch.nn.Linear(3, 3, bias=False)
-        self.fc2 = torch.nn.Linear(3, 1, bias=False)
-
-    def forward(self, x):
-        return self.fc2(self.fc1(x))
-
-
-def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
-    with ColoInitContext(device=get_current_device()):
-        model = Net().cuda()
-    w1 = model.fc1.weight
-    w2 = model.fc2.weight
-    ddp_cls.set_params_to_ignore([w2])
-    model = init_ddp_func(model)
-    x = torch.rand(2, 3, device=get_current_device())
-    logits = model(x)
-    loss = torch.sum(logits)
-    model.backward(loss)
-
-    if ddp_cls is ZeroDDP:
-        w1s_grad = w1
-    else:
-        w1s_grad = w1.grad
-
-    w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
-    dist.all_gather(w1_grads, w1s_grad)
-    assert torch.equal(w1_grads[0], w1_grads[1])
-    w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
-    dist.all_gather(w2_grads, w2.grad)
-    assert not torch.equal(w2_grads[0], w2_grads[1])
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    set_seed(dist.get_rank())
-    run_fwd_bwd(ColoDDP, init_ddp)
-    run_fwd_bwd(ZeroDDP, init_ddpv2)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@rerun_if_address_is_in_use()
-def test_ddp_ignore_params(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_ddp_ignore_params(2)
@@ -1,67 +0,0 @@
-from collections import OrderedDict
-
-import pytest
-import torch
-
-import colossalai
-from colossalai.nn.parallel import ColoDDP
-from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ColoInitContext
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
-    for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
-        assert k1 == k2
-
-        if t1.device != t2.device:
-            temp_t2 = t2.to(t1.device)
-        else:
-            temp_t2 = t2
-
-        assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
-
-
-def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    pg = ProcessGroup()
-    return ColoDDP(module, process_group=pg)
-
-
-def run_ddp_state_dict():
-    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    torch_model = model_builder().cuda()
-    with ColoInitContext(device=get_current_device()):
-        model = model_builder()
-    model = init_ddp(model)
-    torch_state_dict = torch_model.state_dict()
-
-    for param in model.parameters():
-        if isinstance(param, ColoParameter):
-            assert param.get_process_group() is not None
-    model.load_state_dict(torch_state_dict)
-
-    for param in model.parameters():
-        if isinstance(param, ColoParameter):
-            assert param.get_process_group() is not None
-
-    state_dict = model.state_dict()
-    check_state_dict_equal(torch_state_dict, state_dict)
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_ddp_state_dict()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
-@rerun_if_address_is_in_use()
-def test_state_dict(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_state_dict(2)
@@ -1,47 +0,0 @@
-from functools import partial
-
-import pytest
-import torch
-import torch.distributed as dist
-from torch.distributed.distributed_c10d import _get_default_group
-
-import colossalai
-from colossalai.nn.parallel.reducer import Reducer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-
-REDUCE_CNT = 0
-
-
-def check_eq(grad, grad_clone):
-    global REDUCE_CNT
-    print(f'Rank{dist.get_rank()} check {REDUCE_CNT}')
-    REDUCE_CNT += 1
-    assert torch.allclose(grad, grad_clone)
-
-
-def run_reducer():
-    grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
-    grads_clone = [g.clone().detach() for g in grads]
-    for g in grads:
-        dist.all_reduce(g)
-    reducer = Reducer(bucket_size_mb=1)
-    for g, g_clone in zip(grads, grads_clone):
-        reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
-    reducer.flush()
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_reducer()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
-@rerun_if_address_is_in_use()
-def test_reducer(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_reducer(2)
@@ -1,73 +0,0 @@
-import pytest
-import torch
-import torch.nn as nn
-
-import colossalai
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
-
-
-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-    Basically works like a linear layer but the weights are transposed.
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        w = torch.empty(nx, nf)
-        nn.init.normal_(w, std=0.02)
-        self.weight = nn.Parameter(w)
-        self.bias = nn.Parameter(torch.ones(nf))
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
-def run_with_spec(spec_init_func, split_bias):
-    model = Conv1D(4, 16).cuda()
-    world_size = torch.distributed.get_world_size()
-    pg = ProcessGroup(tp_degree=world_size)
-
-    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
-    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
-
-    spec_init_func(weight, pg)
-    if split_bias:
-        spec_init_func(bias, pg)
-
-    x = torch.rand(2, 16).cuda()
-    out = model(x)
-    colo_out = torch.addmm(bias, x, weight)
-    colo_out = colo_out.to_replicate()
-    assert tensor_equal(out, colo_out)
-    grad = torch.rand_like(out)
-    out.backward(grad)
-    colo_out.backward(grad)
-    tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
-    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@rerun_if_address_is_in_use()
-def test_addmm_1d(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_addmm_1d(4)
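
The deleted test leaned on the identity its Conv1D forward encodes: the weight is stored as (nx, nf), so the layer computes addmm(bias, x, weight) rather than x @ weight.T as nn.Linear would. A self-contained check of that equivalence:

import torch

nx, nf = 16, 4
x = torch.rand(2, nx, dtype=torch.float64)
w = torch.randn(nx, nf, dtype=torch.float64)
b = torch.ones(nf, dtype=torch.float64)
assert torch.allclose(torch.addmm(b, x, w), x @ w + b)
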
@@ -1,43 +0,0 @@
-import pytest
-import torch
-from torch.nn import functional as F
-
-import colossalai
-from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
-
-
-def run_with_spec(spec_init_func):
-    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
-    model = torch.nn.EmbeddingBag(10, 4).cuda()
-    weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))
-
-    spec_init_func(weight, pg)
-
-    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
-    offsets = torch.tensor([0, 4]).cuda()
-    out = model(inputs, offsets=offsets)
-    colo_out = F.embedding_bag(inputs, weight, offsets=offsets)
-    assert tensor_equal(out, colo_out)
-    grad = torch.rand_like(out)
-    out.backward(grad)
-    colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-
-
-def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(split_param_col_tp1d)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@rerun_if_address_is_in_use()
-def test_embedding_bag_1d(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_embedding_bag_1d(4)
@@ -1,44 +0,0 @@
-import pytest
-import torch
-from torch.nn import functional as F
-
-import colossalai
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
-
-
-def run_with_spec(spec_init_func, pg: ProcessGroup):
-    model = torch.nn.Embedding(12, 32).cuda()
-    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
-
-    spec_init_func(weight, pg)
-
-    x = torch.tensor((0, 3, 6, 9)).cuda()
-    out = model(x)
-    colo_out = F.embedding(x, weight)
-    assert tensor_equal(out, colo_out)
-    grad = torch.rand_like(out)
-    out.backward(grad)
-    colo_out.backward(grad)
-    # compare grad inside a TP group
-    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-
-
-def run_dist(rank, world_size, port):
-    # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    pg = ProcessGroup(tp_degree=world_size)
-    run_with_spec(split_param_row_tp1d, pg)
-    run_with_spec(split_param_col_tp1d, pg)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@rerun_if_address_is_in_use()
-def test_embedding_1d(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_embedding_1d(4)
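
Both deleted embedding tests compare a module's forward against its functional form before splitting the weight across the tensor-parallel group. The single-process core of that comparison is just:

import torch
import torch.nn.functional as F

emb = torch.nn.Embedding(12, 32)
x = torch.tensor([0, 3, 6, 9])
assert torch.equal(emb(x), F.embedding(x, emb.weight))
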
@@ -1,48 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
-
-
-def run_with_spec(spec_init_func, split_bias):
-    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
-    model = torch.nn.Linear(4, 8).cuda()
-    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
-    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
-
-    spec_init_func(weight, pg)
-    if split_bias:
-        spec_init_func(bias, pg)
-
-    x = torch.rand(2, 4).cuda()
-    out = model(x)
-    colo_out = F.linear(x, weight, bias)
-    colo_out = colo_out.to_replicate()
-    assert tensor_equal(out, colo_out)
-    grad = torch.rand_like(out)
-    out.backward(grad)
-    colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
-
-
-def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False)
-    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@rerun_if_address_is_in_use()
-def test_linear_1d(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_linear_1d(4)
@@ -1,48 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-
-
-def check_cross_entropy():
-    input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
-    input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
-    with torch.no_grad():
-        input_ct.copy_(input_t)
-
-    target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
-
-    world_size = torch.distributed.get_world_size()
-    pg = ProcessGroup(tp_degree=world_size)
-    input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-    input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
-    input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
-
-    output = F.cross_entropy(input_t, target)
-    output_colo = F.cross_entropy(input_shard, target)
-    assert torch.allclose(output_colo, output)
-
-    output.backward()
-    output_colo.backward()
-
-    assert torch.allclose(input_t.grad, input_ct.grad)
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_cross_entropy()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
-@rerun_if_address_is_in_use()
-def test_loss_func(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_loss_func(1)
@@ -1,87 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-from torch.nn import Parameter
-
-import colossalai
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-
-
-def _run_layer_norm():
-    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
-
-    input_t = torch.randn(3, 2, device=get_current_device())
-
-    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
-    input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg))
-
-    # prepare colossalai LN
-    weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg))
-    bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg))
-
-    output = ln_op(input_t)
-    output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps)
-
-    assert torch.allclose(output_colo, output)
-
-    torch.mean(output).backward()
-    torch.mean(output_colo).backward()
-
-    assert torch.allclose(ln_op.weight.grad, weight.grad)
-
-
-def check_spec_eq(tensor, other):
-    assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
-    for k in dir(tensor.dist_spec):
-        if not k.startswith('__'):
-            assert hasattr(other.dist_spec, k), f"{k}"
-            assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)
-
-
-def check_element_wise_ops():
-    world_size = torch.distributed.get_world_size()
-    pg = ProcessGroup(tp_degree=world_size)
-    t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
-
-    check_spec_eq(x, x.cuda())
-    assert torch.equal(x.cuda(), t.cuda())
-    check_spec_eq(x, torch.abs(x))
-    assert torch.equal(torch.abs(x), torch.abs(t))
-    check_spec_eq(x, F.sigmoid(x))
-    assert torch.equal(F.sigmoid(x), F.sigmoid(t))
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_element_wise_ops()
-    _run_layer_norm()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@rerun_if_address_is_in_use()
-def test_element_wise_ops(world_size):
-    spawn(run_dist, world_size)
-
-
-def run_dist2(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    _run_layer_norm()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1])
-@rerun_if_address_is_in_use()
-def test_ln(world_size):
-    spawn(run_dist2, world_size)
-
-
-def check_all():
-    test_element_wise_ops(2)
-
-
-if __name__ == '__main__':
-    check_all()
@@ -1,97 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-
-import colossalai
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
-from colossalai.tensor.distspec import DistPlacementPattern
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
-
-
-def exam_view_core(pg):
-    # the case of replicated ColoTensors
-    x = torch.randn(4, 4).cuda()
-    x_colo = ColoTensor(x, ColoTensorSpec(pg))
-
-    y = x.view(2, -1, 2)
-    y_colo = x_colo.view(2, -1, 2)
-
-    assert torch.all(y == y_colo)
-    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
-    # the perfect case of col-sliced ColoTensors
-    split_param_col_tp1d(x_colo, pg)
-
-    z = x.view(torch.Size((2, 1, 2, -1)))
-    z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
-    if dist.get_rank() == 0:
-        z = z[:, :, :, 0:2]
-    else:
-        z = z[:, :, :, 2:]
-    assert torch.all(z == z_colo)
-    assert z_colo.dist_spec == x_colo.dist_spec
-    # the perfect case of row-sliced ColoTensors
-    split_param_row_tp1d(x_colo, pg)
-
-    z = x.view(torch.Size((-1, 2, 2)))
-    z_colo = x_colo.view(torch.Size((-1, 2, 2)))
-    if dist.get_rank() == 0:
-        z = z[0:2, :, :]
-    else:
-        z = z[2:, :, :]
-    assert torch.all(z == z_colo)
-    assert z_colo.dist_spec == x_colo.dist_spec
-    # the normal case of row-sliced ColoTensors
-    z = x.view(-1, 2, 2, 2)
-    z_colo = x_colo.view(-1, 2, 2, 2)
-    assert torch.all(z == z_colo)
-    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
-
-
-def exam_view_autograd(pg):
-    x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
-    y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
-    with torch.no_grad():
-        y.copy_(x)
-    y = ColoTensor(y, ColoTensorSpec(pg))
-    y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
-
-    xx = x.view(2, 2, -1)
-    yy_slice = y_slice.view(2, 2, -1)
-    yy = yy_slice.to_replicate()
-    grad = torch.randn(2, 2, 4, device=get_current_device())
-
-    xx.backward(grad)
-    yy.backward(grad)
-    assert torch.all(x.grad == y.grad)
-
-
-def exam_view_errors(pg):
-    x = torch.randn(8, 2, device=get_current_device())
-    x = ColoTensor(x, ColoTensorSpec(pg))
-    split_param_row_tp1d(x, pg)
-
-    x.view('a', 'b', 'c')
-    x.view(8, -1)
-    x.view([-2, -2, -2])
-    x.view((-1, -1, -1))
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
-    exam_view_core(pg)
-    exam_view_autograd(pg)
-    # exam_view_errors(pg)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@rerun_if_address_is_in_use()
-def test_view(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_view(2)
@@ -1,3 +1,4 @@
+import pytest
 import torch

 from colossalai.pipeline.pipelinable import PipelinableContext
@@ -48,6 +49,7 @@ def run_pipelinable(rank, world_size, port):
     assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count


+@pytest.mark.skip(reason="this is useless")
 @rerun_if_address_is_in_use()
 def test_pipelinable():
     spawn(run_pipelinable, 1)
@@ -219,6 +219,7 @@ def check_gpt2_3d(rank, world_size, port):
     run_gpt2_3d_test()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
@@ -1,153 +0,0 @@
-import pytest
-import torch
-from numpy import allclose
-
-import colossalai
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-
-def _run_tensor_indexing():
-    pg = ProcessGroup()
-    torch_t = torch.randn(2, 3)
-    colo_t = ColoTensor(torch_t, ColoTensorSpec(pg))
-    assert allclose(torch_t[:, 1], colo_t[:, 1])
-
-
-def _run_wrapped_tensor_func():
-    pg = ProcessGroup()
-    t_ref = torch.randn(4, 5)
-    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-
-    # non-func attr
-    assert t.is_cuda == t_ref.is_cuda
-
-    # return 1 torch.Tensor
-    t_abs = t.abs()
-    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
-
-    # return 1 non-torch.Tensor
-    assert t.dim() == t_ref.dim()
-
-    # return >1 torch.Tensor
-    assert isinstance(t, ColoTensor)
-    t_split1, t_split2 = t.split(2)
-    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
-
-
-def _run_operand(world_size):
-    pg = ProcessGroup()
-    t_ref = torch.randn(4, 5)
-    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-
-    t_ref_res = t_ref + t_ref
-    t_res = t + t
-
-    assert isinstance(t_res, ColoTensor)
-    assert torch.allclose(t_ref_res, t_res)
-
-    pg = ProcessGroup(tp_degree=world_size)
-    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(ShardSpec([0], [world_size]))
-    t_new = torch.zeros_like(t)
-    assert isinstance(t_new, ColoTensor)
-    assert t_new.is_sharded()
-
-
-#### Test Distributed init a Colotensor
-
-
-def _run_view(world_size):
-    t_ref = torch.randn(4, 5)
-    rank = gpc.get_global_rank()
-    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
-    t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
-
-    assert t.size_global()[0] == 4 * world_size
-    assert t.size_global(1) == 5
-    assert t.size_global() == torch.Size([4 * world_size, 5])
-
-    t = t.view(4 * 5 * world_size)
-    assert t.shape == torch.Size([4 * 5 * world_size])
-
-
-def _run_tensor_shard_init(world_size):
-    t_ref = torch.randn(4, 5)
-    pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
-    tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
-    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_dist_spec(ReplicaSpec())
-
-    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
-
-
-def _run_tensor_replicated_init(world_size):
-    t_ref = torch.randn(4 * world_size, 5)
-    pg = ProcessGroup()
-    spec = ColoTensorSpec(pg)
-    t = ColoTensor.from_torch_tensor(t_ref.clone(), spec)
-
-    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
-
-
-def _run_process_group(world_size):
-    pg1 = ProcessGroup()
-    pg2 = ProcessGroup()
-    assert pg1 == pg2
-
-
-def _run_redistributed(world_size):
-    if world_size != 4:
-        return
-    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
-    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
-
-    spec1 = ColoTensorSpec(pg1)
-    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
-    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
-    assert t1.is_sharded()
-    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
-    assert t1.is_sharded()
-    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
-    t1 = t1.redistribute(ReplicaSpec(), pg3)
-    assert t1.is_replicate()
-
-
-def _run_set_tensor_spec(world_size):
-    if world_size != 4:
-        return
-    pg = ProcessGroup(tp_degree=2, dp_degree=2)
-    spec1 = ColoTensorSpec(pg)
-    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
-
-    dist_spec2 = ShardSpec([-1], [pg.tp_world_size()])
-    assert t1.is_replicate()
-    t1.set_dist_spec(dist_spec2)
-    assert t1.is_shard_1dcol()
-
-
-def run_dist_tests(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    _run_tensor_shard_init(world_size)
-    _run_tensor_replicated_init(world_size)
-    _run_view(world_size)
-    _run_process_group(world_size)
-    _run_tensor_indexing()
-    _run_operand(world_size)
-    _run_wrapped_tensor_func()
-    _run_redistributed(world_size)
-    _run_set_tensor_spec(world_size)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
-@rerun_if_address_is_in_use()
-def test_dist_cases(world_size):
-    spawn(run_dist_tests, world_size)
-
-
-if __name__ == '__main__':
-    test_dist_cases(4)
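
The sharding arithmetic behind _run_view: each rank holds a (4, 5) local shard along dim 0, so size_global() reports (4 * world_size, 5). A plain-torch illustration of the same bookkeeping, runnable without ColossalAI:

import torch

world_size = 2
full = torch.randn(4 * world_size, 5)            # the global tensor
shards = torch.chunk(full, world_size, dim=0)    # per-rank local shards
assert shards[0].shape == torch.Size([4, 5])
assert torch.cat(shards, dim=0).shape == torch.Size([4 * world_size, 5])
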
@@ -1,148 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import (
|
||||
debug_print,
|
||||
set_seed,
|
||||
split_param_col_tp1d,
|
||||
split_param_row_tp1d,
|
||||
tensor_equal,
|
||||
tensor_shard_equal,
|
||||
)
|
||||
|
||||
|
||||
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
for n, p in model.named_parameters():
|
||||
p.set_process_group(pg)
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_tensor_spec(*tensor_spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
for n, p in model.named_parameters():
|
||||
p.set_process_group(pg)
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_megatron_spec(model, pg: ProcessGroup):
|
||||
for mn, module in model.named_modules():
|
||||
# debug_print([0], mn)
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
# debug_print([0], '\t', pn, param.compute_spec, param.shape)
|
||||
param.set_process_group(pg)
|
||||
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg)
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
raise RuntimeError
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg)
|
||||
else:
|
||||
assert 'bias' in pn
|
||||
        elif 'wte' in mn or 'wpe' in mn:
            assert 'weight' in pn
            split_param_col_tp1d(param, pg)
        elif 'c_attn' in mn or 'c_proj' in mn:
            split_param_col_tp1d(param, pg)
            # debug_print([0], '\t', param.compute_spec, param.shape)


def check_param_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
        assert pg.tp_world_size() is not None
        assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())


def check_grad_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_gpt(init_spec_func, use_ddp):
    world_size = torch.distributed.get_world_size()

    # build a PG with TP and DP hybrid
    pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))

    # set the seed so that processes in the same tp group use the same seed
    # set_seed(pg.tp_local_rank())
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # make sure torch_model and model have the same parameter values
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
        model = ColoDDP(model, process_group=pg)

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    init_spec_func(model, pg)

    check_param_equal(model, torch_model, pg)

    # close the dropout in eval mode
    model.eval()
    torch_model.eval()
    set_seed(pg.dp_local_rank())
    torch.distributed.barrier()
    for i, (input_ids, label) in enumerate(train_dataloader):
        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = model(colo_input)
        torch_logits = torch_model(input_ids)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model, pg)
        if i > 0:
            break
    set_seed(313)


def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The tests below are commented out for speed concerns
    # run_gpt(init_1d_row_spec, use_ddp)
    # run_gpt(init_1d_col_spec, use_ddp)
    run_gpt(init_megatron_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    spawn(run_dist, world_size, use_ddp=use_ddp)


if __name__ == '__main__':
    test_gpt(4, use_ddp=False)
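
# A minimal single-process sketch of the comparison behind check_param_equal, assuming
# tensor_shard_equal matches a local TP shard against the corresponding slice of the
# full tensor; shard_matches_full is a hypothetical helper for illustration only.
import torch

def shard_matches_full(full: torch.Tensor, shard: torch.Tensor, tp_rank: int, tp_world_size: int) -> bool:
    if full.shape == shard.shape:    # replicated parameter, compare directly
        return torch.equal(full, shard)
    # otherwise compare against this rank's slice of the last dim (1D col split)
    return torch.equal(torch.chunk(full, tp_world_size, dim=-1)[tp_rank], shard)
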
@@ -1,334 +0,0 @@
import pytest
import torch

import colossalai
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import (
    check_equal,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_shard_equal,
)


def run_1d_hybrid_tp(model_name):
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))

        # Make two models have the same init params
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)
    else:
        model_torch = None
        optimizer_torch = None

    pg = ProcessGroup(tp_degree=world_size)
    if 'bert' == model_name:
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue

            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            elif 'word_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            elif 'position_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            elif 'token_type_embeddings' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)

    elif "simple_net" == model_name:
        # A naive way to set spec for all weights in Linear
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)
            if 'proj2' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)

    model = model.cuda()
    model.eval()
    if rank == 0:
        model_torch.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        if rank == 0:
            optimizer_torch.zero_grad()
        torch.distributed.barrier()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # Test output
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch
            assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
        torch.distributed.barrier()

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            optimizer_torch.step()

            with torch.no_grad():
                # check param
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
        torch.distributed.barrier()
        if i > 5:
            break
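
# What the row/col 1D splits above do to parameter shapes, checked with plain
# torch.chunk on toy sizes (tp_world_size = 2): a row split partitions dim 0,
# a col split partitions the last dim.
import torch

w = torch.randn(8, 4)
row_shard = torch.chunk(w, 2, dim=0)[0]     # shape (4, 4)
col_shard = torch.chunk(w, 2, dim=-1)[0]    # shape (8, 2)
assert row_shard.shape == (4, 4) and col_shard.shape == (8, 2)
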


# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build a module with 2 Linear layers (4 parameters) plus one extra parameter, 5 in total.
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    param_cnt = 0
    for name, p in model.named_parameters():
        param_cnt += 1
    assert param_cnt == 5

    for name, colo_p in model.named_parameters():
        assert colo_p.is_model_data()

    param_cnt = 0
    for name, p in model.named_parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 1

    param_cnt = 0
    for p in model.fcs[0].parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 2
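
# The counts asserted above, spelled out with plain torch: Net.fcs owns two
# nn.Linear layers (a weight and a bias each, 4 parameters) and extra_param adds
# one more, so recursive traversal yields 5, while recurse=False at the top
# level only sees extra_param.
import torch

fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
assert len(list(fcs.parameters())) == 4                     # 2 weights + 2 biases
assert len(list(fcs[0].parameters(recurse=False))) == 2     # one layer's weight + bias
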


def test_colo_optimizer():
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


def run_1d_row_tp(model_name: str):
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # A naive way to set spec for all weights in Linear
    for mo_name, module in model.named_modules():
        # print(mo_name)
        for pa_name, param in module.named_parameters(recurse=False):
            # print('\t', pa_name, param.shape)
            if not isinstance(param, ColoTensor):
                continue
            if 'weight' in pa_name:
                if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
                    split_param_row_tp1d(param, pg)
                elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
                    split_param_col_tp1d(param, pg)

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch
            assert torch.allclose(loss, loss_torch, rtol=1e-2)
        torch.distributed.barrier()

        loss.backward()

        if rank == 0:
            loss_torch.backward()
        torch.distributed.barrier()

        if i > 5:
            break


def _run_pretrain_load():
    from transformers import BertForMaskedLM
    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    dict_col = {}
    c_ref = 0
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
        c_ref += 1
    c1 = 0
    c2 = 0
    for name, param in model.named_parameters():
        if isinstance(param, ColoParameter):
            c1 += 1
        else:
            c2 += 1
        dict_col[name] = param
    assert c_ref == c1
    assert c2 == 0
    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The test below is commented out for speed considerations
    # for name in ['bert', 'simple_net']:
    #     run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    spawn(run_model_dist, world_size)


def run_pretrain_load_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()


# This test case has to download HuggingFace pretrained models from the internet,
# so we trigger it manually.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    spawn(run_pretrain_load_dist, world_size)


if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    test_model(4)
    # test_pretrain_load(4)
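
# A minimal sketch of the weight-tying invariant checked in _run_pretrain_load:
# tied tensors must remain the *same object* after re-initialization, not merely
# equal in value. TiedHead is a hypothetical stand-in for BERT's prediction head.
import torch

class TiedHead(torch.nn.Module):

    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.decoder = torch.nn.Linear(hidden, vocab)
        self.bias = torch.nn.Parameter(torch.zeros(vocab))
        self.decoder.bias = self.bias    # tie the two attributes to one Parameter

head = TiedHead(8, 24)
assert head.decoder.bias is head.bias
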
@@ -1,227 +0,0 @@
from copy import deepcopy

import pytest
import torch

import colossalai
from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
from colossalai.tensor import (
    ColoTensor,
    ColoTensorSpec,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
    distspec,
)
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal


def run_model_with_spec(mode, model_name):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=False)

    if rank == 0:
        model_seq = model_builder(checkpoint=False)
        model_seq = model_seq.cuda()

        # Make two models have the same init params
        for p1, p2 in zip(model.parameters(), model_seq.parameters()):
            p2.data.copy_(p1.data)

    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # Not all layer dimensions in BERT are divisible by 4,
    # e.g. row-sharding every layer is invalid because the first dim of some layers is the classification type size, 2.
    if 'bert' == model_name:
        if 'col' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
        elif 'row' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
    elif 'simple_net' == model_name:
        init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    model = model.cuda()
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_seq = model_seq(data)
                loss_seq = criterion(output_seq, label)
            else:
                output_seq = model_seq(data, label)
                loss_seq = output_seq

        if rank == 0:
            with torch.no_grad():
                assert torch.allclose(loss, loss_seq, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_seq.backward()

            with torch.no_grad():
                # check param
                for p1, p2 in zip(model.parameters(), model_seq.parameters()):
                    if p1.size() == p2.size():
                        assert torch.allclose(p1, p2)
                    else:
                        if p1.size(-1) < p2.size(-1):    # col
                            world_size = p2.size(-1) // p1.size(-1)
                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]

                        elif p1.size(0) < p2.size(0):    # row
                            world_size = p2.size(0) // p1.size(0)
                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]

                        assert torch.allclose(p1, split_p2)

        if i > 3:
            break


def run_linear_with_spec(mode):
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    model_handy = deepcopy(model)
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    x = torch.rand(2, 4).cuda()
    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))

    out = model(x)
    colo_out = model_handy(colo_x)
    assert tensor_equal(out, colo_out)

    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)

    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_check_shared_param():
    from transformers import BertConfig, BertForMaskedLM
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 24

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    config = BertConfig(vocab_size=vocab_size,
                        hidden_size=hidden_dim,
                        intermediate_size=hidden_dim * 4,
                        num_attention_heads=num_head,
                        max_position_embeddings=sequence_length,
                        num_hidden_layers=num_layer,
                        hidden_dropout_prob=0.,
                        attention_probs_dropout_prob=0.)
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM(config)

    model = model.cuda()
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
    assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
    # They are all Linear layers, so row sharding is allowed for both. This should pass the check.
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
    # This should be detected by the check, because the weight cannot be row-sharded while the bias is col-sharded.
    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

    # TODO(jiaruifang) optimize this line
    if not model.cls.predictions.bias.has_initialized:
        model.cls.predictions.bias.pg = pg
        model.cls.predictions.bias.dist_spec = ReplicaSpec()
        model.cls.predictions.bias.has_initialized = True
    model.cls.predictions.bias.set_tensor_spec(*col_spec)
    try:
        check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
    except Exception as e:
        assert 'incorrectly sharded' in str(e)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_linear_with_spec('col')
    run_linear_with_spec('row')


def run_dist_model(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for model_name in ['simple_net', 'bert']:
        run_model_with_spec('col', model_name)
        run_model_with_spec('row', model_name)


def run_dist_check(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_check_shared_param()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_linear_1d(world_size):
    spawn(run_dist, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_model(world_size):
    spawn(run_dist_model, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_check(world_size):
    spawn(run_dist_check, world_size)


if __name__ == '__main__':
    test_module_linear_1d(4)
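
# A minimal usage sketch of the API exercised above (imports taken from this
# file; the helper itself is hypothetical and assumes colossalai.launch has
# already set up the distributed environment):
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup

def shard_linear_1d(linear, world_size: int, mode: str = 'row'):
    # 'row' shards along the input dim, 'col' along the output dim
    pg = ProcessGroup(tp_degree=world_size)
    init_colo_module(linear, ComputeSpec(ComputePattern.TP1D), pg=pg, recursive=False, mode=mode)
    return pg
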
@@ -1,41 +0,0 @@
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from tests.test_tensor.common_utils import tensor_shard_equal


def run_dist(rank, world_size, port, dp_degree, tp_degree):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
    x = torch.randn(4, 4)
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(*spec)

    gather_tensor(param)
    if dist.get_rank() == 0:
        assert torch.all(x == param)
    else:
        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    dist.barrier()

    scatter_tensor(param, spec[0])
    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    assert param.requires_grad is True
    dist.barrier()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size):
    spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2)


if __name__ == '__main__':
    test_checkpoint(world_size=4)
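
# The round trip exercised above, wrapped as a hypothetical helper: gather_tensor
# reassembles the full tensor from TP shards before saving, and scatter_tensor
# restores the sharded layout afterwards, so checkpoints are stored unsharded.
def save_unsharded(param, shard_spec, path):
    gather_tensor(param)                     # param holds the full tensor on rank 0
    if dist.get_rank() == 0:
        torch.save(param.data.cpu(), path)
    dist.barrier()
    scatter_tensor(param, shard_spec)        # back to the sharded layout
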
@@ -1,64 +0,0 @@
import pytest
import torch

import colossalai
from colossalai.tensor import (
    ColoParameter,
    ColoTensorSpec,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
)
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed


def run_colo_init_context(rank: int, world_size: int, port: int):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # make sure the seed of each process is the same, so the params are consistent across processes and exactly replicated.
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # keep parameters replicated during init
    with ColoInitContext(device=get_current_device()):
        model1 = model_builder()

    # init the second model with a default dist spec (ReplicaSpec here)
    set_seed(42)
    shard_spec = ReplicaSpec()

    # If ShardSpec is used instead, the assertions below will fail.
    # That is not a bug: the sharded init simply does not produce values consistent with the replicated one.
    # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
    default_pg = ProcessGroup(tp_degree=world_size)
    with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec):
        model2 = model_builder()

    # reshard both models
    new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        p1: ColoParameter = p1
        p1.set_process_group(ProcessGroup(tp_degree=world_size))
        p1.set_dist_spec(new_shard)
        p2.set_dist_spec(new_shard)

    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert (torch.allclose(p1, p2))


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_colo_init_context(world_size):
    spawn(run_colo_init_context, world_size)


if __name__ == '__main__':
    test_colo_init_context(2)
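
# The resharding performed above, condensed into a hypothetical helper (the
# set_process_group / set_dist_spec calls mirror the loop in run_colo_init_context):
def reshard_last_dim(param, world_size: int):
    param.set_process_group(ProcessGroup(tp_degree=world_size))
    param.set_dist_spec(ShardSpec(dims=[-1], num_partitions=[world_size]))
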
@@ -1,232 +0,0 @@
import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # create mlp vars
    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
    w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
    b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()

    # run normal forward
    out = F.linear(x, w, b)

    # create mesh meta
    # the mesh is in the following topo
    # [[0, 1],
    #  [2, 3]]
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    row_id = rank // 2
    column_id = rank % 2

    # create pg
    row_process_group = None
    col_process_group = None
    row_to_ranks = {0: [0, 1], 1: [2, 3]}
    col_to_ranks = {0: [0, 2], 1: [1, 3]}

    for idx in range(2):
        # row ranks
        row_ranks = row_to_ranks[idx]
        row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)

        # col ranks
        col_ranks = col_to_ranks[idx]
        col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)

        if rank in row_ranks:
            row_process_group = row_pg

        if rank in col_ranks:
            col_process_group = col_pg

    ########################
    #  RRR x RS0 -> RRS0   #
    ########################
    # w will be transposed in F.linear
    x_replica = x.detach().clone()
    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]

    # adding sharding spec
    x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})

    # check sharding spec
    assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_replica, w_shard, b_shard)
    assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"

    # each row group holds a slice of the output's last dim
    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # S0RR x RS1 -> S0RS1  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_shard)

    # each rank holds a batch shard and a last-dim shard
    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
    expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # S0RS1 x S1R -> S0RR  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # each row group only has a mini-batch
    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    #  RRS0 x S0R -> RRR   #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # the output is fully replicated
    expected_out_shard = out
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # RS0S1 x S1R -> RS0R  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
    b_replica = b.clone()

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_replica)

    # each row group holds a slice of the sequence dim
    expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
    assert torch.allclose(out_shard, expected_out_shard)

    ########################
    # RRS0 x S0S1 -> RRS1  #
    ########################
    # w will be transposed in F.linear
    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
    w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
    b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]

    # adding sharding spec
    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})

    # check sharding spec
    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"

    w_shard.pg_axis0 = col_process_group
    w_shard.pg_axis1 = row_process_group

    out_shard = F.linear(x_shard, w_shard, b_shard)

    # each column group holds a slice of the output's last dim
    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
    assert torch.allclose(out_shard, expected_out_shard)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_sharded_mlp(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_sharded_mlp(4)
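
# The sharding algebra above, verified with plain tensors on one process: for
# y = x @ w.T, sharding w's dim 0 (S0) shards y's last dim, while sharding the
# contraction dim means the partial products must be summed (the all-reduce case).
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)
full = x @ w.t()                                        # (4, 16)

# RS0 weight -> output sharded along the last dim
w0 = torch.chunk(w, 2, dim=0)[0]
assert torch.allclose(x @ w0.t(), torch.chunk(full, 2, dim=1)[0])

# contraction-dim shards -> partial products need a reduction
xs, ws = torch.chunk(x, 2, dim=1), torch.chunk(w, 2, dim=1)
partial_sum = xs[0] @ ws[0].t() + xs[1] @ ws[1].t()
assert torch.allclose(partial_sum, full, atol=1e-5)
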
@@ -1,143 +0,0 @@
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
from colossalai.zero.gemini import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec


def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
            "parameter '{}' has a problem.".format(key)


def run_fwd_bwd(model, criterion, optimizer, input_ids):
    optimizer.zero_grad()
    logits = model(input_ids)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


def init_1d_row_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'weight' in n and 'ln' not in n:
            p.set_tensor_spec(*spec)


def init_1d_col_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'ln' not in n and ('weight' in n or 'bias' in n):
            p.set_tensor_spec(*spec)


@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(placement_policy, tp_init_spec_func=None):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()

    # world_size == 4: dp = 2, tp = 2, constructing hybrid parallelism
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    dp_world_size = pg.dp_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[dp_world_size]['chunk_size'] = 5000
    config_dict[dp_world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None

    model = GeminiDDP(model, init_device, placement_policy, True, False)
    # Equivalent to the following 3 lines:
    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
    # model = ZeroDDP(model, gemini_manager, pin_memory=True)

    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
    # Equivalent to the following 2 lines:
    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    check_param(model, torch_model, pg)

    model.eval()
    torch_model.eval()

    set_seed(pg.dp_local_rank())
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break
        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)

        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model, pg)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        run_gpt(tp_init_spec_func=init_megatron_spec)
    else:
        run_gpt(tp_init_spec_func=init_1d_col_spec)
        run_gpt(tp_init_spec_func=init_1d_row_spec)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_gpt(4)
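
# The chunk configuration tweaked above, spelled out with toy values: the dict
# returned by search_chunk_configuration is keyed by the DP world size, and each
# entry sets the chunk size (in elements) and whether chunks stay gathered.
example_config = {2: {'chunk_size': 5000, 'keep_gathered': False}}
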
@@ -1,206 +0,0 @@
import os
import shutil
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR

import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_linear(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_embedding(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_embedding(weight, pg):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)


def check_param_equal(model, torch_model):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))


def compare_optims(optim1, optim2):
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        for n, t1 in p1.items():
            if n not in p2:
                continue
            t2 = p2[n]
            if isinstance(t1, ColoTensor):
                assert isinstance(t2, ColoTensor)
                assert torch.allclose(t1, t2, rtol=0, atol=0)


def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.set_process_group(pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)

    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()

    model_reload = model_reload.cuda()
    model_reload.eval()

    opt_class = torch.optim.Adam
    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        colo_optimizer_reload.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        colo_optimizer.step()
        colo_optimizer_reload.step()

        if i > 2:
            break

    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
    dist.barrier()

    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    if rank == 0:
        remove('./checkpoint')
    dist.barrier()


def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)

    # the BERT data loader runs in DDP mode, so the input data is not replicated in the TP context
    for model_name in ['bert']:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)


if __name__ == '__main__':
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
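
# The optimizer state walked by compare_optims, probed on one process: for Adam,
# state_dict()['state'] maps a parameter index to entries such as 'exp_avg' and
# 'exp_avg_sq', which is what the inner loop over p1.items() compares.
import torch

p = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.Adam([p], lr=0.1)
p.grad = torch.ones(2)
opt.step()
assert 'exp_avg' in opt.state_dict()['state'][0]
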
@@ -66,6 +66,7 @@ def run_dist(rank, world_size, port):
    run_grad_clip_norm(world_size=world_size)


@pytest.mark.skip("this needs to be updated")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()

@@ -1,8 +1,9 @@
import pytest
import torch
from torch.distributed.distributed_c10d import _get_default_group

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.tensor import ColoTensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.gemini.chunk import ChunkManager
from tests.test_tensor.common_utils import debug_print

@@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory):
    pg = ProcessGroup()

    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))

    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
    params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)]
    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}

    chunk_manager = ChunkManager(config)
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0

    process_group = _get_default_group()
    for p in params:
        chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
        chunk_manager.register_tensor(p, 'param', 2, process_group, pin_memory=pin_memory)
    chunk_manager.close_all_groups()
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
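
# The CPU figure asserted above, in bytes (assuming the 512 corresponds to one
# pinned chunk): a chunk of 128 fp32 elements occupies 128 * 4 = 512 bytes, and
# pinned CPU memory is only tracked when chunks are scattered
# (keep_gathered=False) with pin_memory=True, hence CPU_MEM[False][True] == 512.
assert 128 * 4 == 512
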
@@ -1,10 +1,10 @@
import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

import colossalai
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.gemini import TensorState

@@ -36,7 +36,7 @@ def check_equal(param, param_cp):
@parameterize('pin_memory', [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
    world_size = torch.distributed.get_world_size()
    pg = ColoProcessGroup()
    pg = _get_default_group()
    my_chunk = Chunk(chunk_size=1024,
                     process_group=pg,
                     dtype=torch.float32,

@@ -1,23 +1,40 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd, run_fwd_bwd
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed

PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'auto'
    }
]
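
# How the static configs above map onto the classic ZeRO stages (per the inline
# comments): shard_param_frac=0.0 keeps parameters replicated and shards only
# grads/optimizer states (ZeRO-2), 1.0 shards every parameter (ZeRO-3), 0.5
# shards half of them, and 'auto' lets Gemini place chunks at runtime.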

def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):

def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
    chunk_manager = model.chunk_manager
    param_list = [p for p in model.parameters()]
    chunk_list = chunk_manager.get_chunks(param_list)

@@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
        assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('use_grad_checkpoint', [False, True])
def exam_gpt_fwd_bwd(
    placement_policy,
    placement_config,
    keep_gather,
    model_name: str,
    use_grad_checkpoint: bool = False,

@@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd(
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(42)
    with ColoInitContext(device=init_device):
        model = model_builder(use_grad_checkpoint)
    model = model_builder(use_grad_checkpoint)

    set_seed(42)
    torch_model = model_builder(use_grad_checkpoint).cuda()

@@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd(
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

    pg = ProcessGroup()
    rank = dist.get_rank()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
    torch_model = DDP(torch_model, device_ids=[rank])

    set_seed(pg.dp_local_rank())
    set_seed(rank)
    for i, (input_ids, label) in enumerate(train_dataloader):
        # only a single fwd + bwd can be tested:
        # after bwd, the param chunk holds the grad in Gemini, due to the chunk reuse optimization.

@@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd(
    check_grad(model, torch_model)


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('scatter_after_inference', [False, True])
def exam_gpt_inference(
    placement_policy,
    keep_gather,
    model_name: str,
    scatter_after_inference: bool = False,
):
    init_device = get_current_device()
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(42)
    with ColoInitContext(device=init_device):
        model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference)

    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    set_seed(pg.dp_local_rank())
    model.eval()
    torch_model.eval()
    for i, (input_ids, label) in enumerate(train_dataloader):
        # only a single fwd + bwd can be tested:
        # after bwd, the param chunk holds the grad in Gemini, due to the chunk reuse optimization.
        if i > 0:
            break
        with torch.no_grad():
            input_ids, label = input_ids.cuda(), label.cuda()

            torch_loss = run_fwd(torch_model, input_ids, label, criterion)
            loss = run_fwd(model, input_ids, label, criterion)

            assert torch.equal(torch_loss, loss)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()
    exam_gpt_inference()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
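
Aside (editor's sketch, not part of the commit): the hunks above collapse the old three-step wrapping (ChunkManager + GeminiManager + ZeroDDP, plus ZeroOptimizer) into single GeminiDDP / GeminiOptimizer calls. A minimal restatement of both styles, assuming a process group already launched via colossalai.launch and an illustrative `net` module:

import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

net = nn.Linear(128, 128).cuda()    # stand-in for the test's model_builder() output

# Old style (removed by this commit), shown for contrast:
#   chunk_manager = ChunkManager(config_dict)
#   gemini_manager = GeminiManager(placement_policy, chunk_manager)
#   model = ZeroDDP(net, gemini_manager, pin_memory=True)
#   optim = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model)

# New style: GeminiDDP owns chunk and placement setup; placement arrives as
# keyword arguments such as placement_policy / shard_param_frac.
model = GeminiDDP(net, placement_policy='static', shard_param_frac=1.0, pin_memory=True)
optim = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=1)
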
@@ -1,12 +1,11 @@
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device='cpu'):
        model = model_builder(use_grad_checkpoint)
    model = model_builder(use_grad_checkpoint).cuda()

    print(f'model_name {model_name}')
    runtime_mem_tracer = RuntimeMemTracer(model)
@@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = GeminiDDP(model,
                      chunk_config_dict=config_dict,
                      placement_policy=placement_policy,
                      pin_memory=True,
                      memstats=memstats)

    pg = ProcessGroup()
    set_seed(pg.dp_local_rank())
    set_seed(dist.get_rank())
    for i, (input_ids, label) in enumerate(train_dataloader):
        # you can only test a single fwd + bwd.
        # after bwd param is grad for Gemini, due to the chunk reuse optimization.
@@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
        set_seed(42)
        loss = run_fwd_bwd(model, input_ids, label, criterion, model)

    gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
    gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')

    # print('gemini non model data:', gemini_non_model_data)

@@ -90,6 +89,7 @@ def run_dist(rank, world_size, port):
    run_gemini_use_rmt()


@pytest.mark.skip("this is not used")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
@@ -1,52 +0,0 @@
import pytest
import torch

import colossalai
from colossalai.tensor import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, GeminiDDP
from colossalai.zero.gemini.utils import get_static_torch_model
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
def run_convert_torch_module(model_name: str):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, _, _, _, _ = get_components_func()

    with ColoInitContext(device=torch.device("cpu")):
        model = model_builder(checkpoint=False)
    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
    pytorch_model = get_static_torch_model(model, only_rank_0=False)

    for n, p in pytorch_model.named_parameters():
        assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}"

    # get the static model should not change the original model
    for n, p in model.named_parameters():
        assert isinstance(p, ColoParameter)

    for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()):
        assert pn == cn
        assert id(pm) != id(cm)
        for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)):
            assert id(pp) != id(cp)
            assert pp.shape == cp.shape


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_convert_torch_module()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_convert_torch_module(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_convert_torch_module(2)
@@ -8,16 +8,38 @@ import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed

PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.0,
        'offload_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 1.0,
        'offload_param_frac': 0.0
    },    # zero2-offload
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.5,
        'offload_param_frac': 0.0
    },    # zero2-offload-half
    {
        'placement_policy': 'auto'
    }
]

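Aside (editor's sketch): judging by the comments above, 'shard_param_frac' and the 'offload_*_frac' values appear to interpolate between ZeRO-2-like behavior (parameters replicated, frac 0.0) and ZeRO-3-like behavior (parameters fully sharded or offloaded, frac 1.0), with 0.5 splitting the difference and 'auto' deferring placement to Gemini at runtime. A sketch of how such dicts are consumed, with `tiny` as an illustrative module and a launched process group assumed:

import torch.nn as nn
from colossalai.zero import GeminiDDP

for cfg in PLACEMENT_CONFIGS:
    tiny = nn.Linear(8, 8).cuda()    # fresh stand-in module per config
    wrapped = GeminiDDP(tiny, pin_memory=True, **cfg)    # kwargs expansion, as in the tests below
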
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

@@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2'])
def exam_grad_clipping(placement_policy, model_name: str):
def exam_grad_clipping(placement_config, model_name: str):
    set_seed(1912)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    init_dev = get_current_device()
    with ColoInitContext(device=init_dev):
        model = model_builder()
    model = model_builder()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)
@@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str):
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
    if placement_config['placement_policy'] != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    model = GeminiDDP(model,
                      chunk_config_dict=config_dict,
                      chunk_init_device=init_device,
                      pin_memory=True,
                      **placement_config)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)

    model.train()
    torch_model.train()
@@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed

PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'auto'
    }
]


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

@@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
    world_size = dist.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config)
    return model


def single_chunk_init(model: torch.nn.Module, placement_policy: str):
    gemini_config = dict(
        device=get_current_device(),
        placement_policy=placement_policy,
        pin_memory=True,
    )
    model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
def single_chunk_init(model: torch.nn.Module, placement_config: dict):
    model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
    return model


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2'])
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
    set_seed(19360226)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    init_dev = get_current_device()
    with ColoInitContext(device=init_dev):
        model = model_builder()
    model = model_builder().to(init_dev)

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    model = model_init_func(model, placement_policy)
    model = model_init_func(model, placement_config)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)

    model.eval()
    torch_model.eval()
@@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
        torch_optim.zero_grad()
        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)
        assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5)
        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model)
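
Aside (editor's sketch): the two init paths parameterized above differ only in who picks the chunk layout. A condensed restatement under the same assumptions as the test, with `net` and `placement_config` as stand-ins:

import torch.nn as nn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration


def init_with_searched_chunks(net: nn.Module, placement_config: dict) -> GeminiDDP:
    # multi-chunk path: search a chunk layout up front and hand it to GeminiDDP
    config_dict, *_ = search_chunk_configuration(net, search_range_m=1, search_interval=100)
    return GeminiDDP(net, config_dict, pin_memory=True, **placement_config)


def init_with_internal_search(net: nn.Module, placement_config: dict) -> GeminiDDP:
    # single-chunk path: GeminiDDP derives the layout itself from chunk_init_device
    return GeminiDDP(net, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
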
@@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed

PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 1.0
    },    # zero2-offload
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.5
    },    # zero2-offload-half
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0,
        'offload_optim_frac': 1.0,
        'offload_param_frac': 1.0
    },    # zero3-offload-all
    {
        'placement_policy': 'auto'
    }
]

# this model is large enough to slice to chunks
TEST_MODELS = ['gpt2']
@@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [
]


def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
    torch_dict = torch_model.state_dict()

@@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype
                     msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', TEST_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    init_dev = get_current_device()
    with ColoInitContext(device=init_dev):
        model = model_builder()
    model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)
@@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)

    model.eval()
    torch_model.eval()
@@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
    check_param(model, torch_model, mixed_precision)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', EXAMPLE_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
    set_seed(2008)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    init_dev = get_current_device()
    with ColoInitContext(device=init_dev):
        model = model_builder()
    model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
    model = GeminiDDP(model,
                      chunk_init_device=get_current_device(),
                      search_range_m=1,
                      pin_memory=True,
                      mixed_precision=mixed_precision,
                      **placement_config)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)

    model.eval()
    torch_model.eval()
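
Aside (editor's sketch): with the new API the working dtype is passed straight to GeminiDDP, so the same wrapping covers fp16 and bf16, and GeminiOptimizer takes over loss scaling via initial_scale. A sketch with an illustrative `net` and a launched process group assumed:

import torch
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

net = nn.Linear(32, 32).cuda()    # stand-in module
model = GeminiDDP(net, placement_policy='auto', mixed_precision=torch.bfloat16)
optim = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=128)
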
@@ -1,15 +1,16 @@
from copy import deepcopy

import numpy as np
import pytest
import torch

from colossalai.testing import clear_cache_before_run
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs


@pytest.mark.skip("this is not used")
@clear_cache_before_run()
def test_runtime_mem_tracer():
    test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert']
@@ -18,8 +19,7 @@ def test_runtime_mem_tracer():
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()

        with ColoInitContext(device='cpu'):
            model = model_builder(checkpoint=False)
        model = model_builder(checkpoint=False).cuda()

        model_bk = deepcopy(model)
        runtime_mem_tracer = RuntimeMemTracer(model)
@@ -2,33 +2,20 @@ import pytest
import torch

import colossalai
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_spec(model, pg: ProcessGroup):
    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        if 'weight' in n and 'ln' not in n:
            p.set_process_group(pg)
            p.set_tensor_spec(*tensor_spec)


def exam_search_chunk_size():
    world_size = torch.distributed.get_world_size()
    pg_tp = ProcessGroup(tp_degree=world_size)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # make sure torch_model and model has the same parameter values
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    init_1d_row_spec(model, pg_tp)
    model = model_builder()
    config_dict, *_ = search_chunk_configuration(model,
                                                 search_range_m=1,
                                                 search_interval=16,
@@ -37,57 +24,19 @@ def exam_search_chunk_size():

    for key in config_dict:
        chunk_size = config_dict[key]['chunk_size']
        if world_size == 1:
        if world_size == 1 or True:
            assert chunk_size == 31616
        else:
            assert chunk_size == 1024


def exam_search_strict_ddp():
    world_size = torch.distributed.get_world_size()
    default_shard_pg = ProcessGroup(tp_degree=world_size)
    default_shard_spec = ShardSpec([-1], [world_size])

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    # get the chunk configuration over replicated models
    with ColoInitContext(device=get_current_device()):
        ddp_model = model_builder()
    re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
                                                              search_range_m=1,
                                                              search_interval=16,
                                                              min_chunk_size_m=0,
                                                              filter_exlarge_params=True,
                                                              strict_ddp_flag=False)
    # get the chunk configuration over sharded ddp models
    with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
                         default_dist_spec=default_shard_spec):
        sharded_ddp_model = model_builder()
    sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
                                                              search_range_m=1,
                                                              search_interval=16,
                                                              min_chunk_size_m=0,
                                                              filter_exlarge_params=True,
                                                              strict_ddp_flag=True)
    assert re_dict == sh_dict
    for key in re_dict:
        assert re_dict[key] == sh_dict[key]

    assert re_total == sh_total
    assert re_wasted == sh_wasted


def exam_chunk_manager():
    world_size = torch.distributed.get_world_size()
    default_shard_pg = ProcessGroup(tp_degree=world_size)
    default_shard_spec = ShardSpec([-1], [world_size])

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
                         default_dist_spec=default_shard_spec):
        sharded_ddp_model = model_builder()
    sharded_ddp_model = model_builder()
    chunk_manager = init_chunk_manager(sharded_ddp_model,
                                       get_current_device(),
                                       hidden_dim=16,
@@ -103,7 +52,6 @@ def exam_chunk_manager():
def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_search_chunk_size()
    exam_search_strict_ddp()
    exam_chunk_manager()

@@ -4,31 +4,46 @@ from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed

PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'auto'
    }
]


def ignore_the_first_parameter(model: torch.nn.Module):
    for name, param in model.named_parameters():
        print(f"parameter `{name}` is set ignored")
        ZeroDDP.set_params_to_ignore([param])
        GeminiDDP.set_params_to_ignore([param])
        return


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, keep_gathered, model_name: str):
def exam_state_dict(placement_config, keep_gathered, model_name: str):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model_builder()

    torch_model = model_builder()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
@@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
@@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model
@@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)

    torch_dict = torch_model.state_dict()
    model.load_state_dict(torch_dict, strict=False)
@@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)


@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict_shard(placement_config, model_name: str):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    model = model_builder()

    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    model = GeminiDDP(model, config_dict, **placement_config)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    accumulated_keys = set()
    # ensure number of shards > 1
    for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
        for key, value in shard.items():
            assert key not in accumulated_keys, f"key `{key}` is duplicated."
            accumulated_keys.add(key)
            assert key in zero_dict, f"{key} not in ZeRO dictionary."
            assert torch.equal(value, zero_dict[key]), f"{key} not equal."


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    exam_load_state_dict()
    exam_state_dict_shard()


@pytest.mark.dist
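
Aside (editor's sketch): the new exam_state_dict_shard above exercises sharded checkpointing: state_dict_shard() yields partial state dicts, each kept under max_shard_size (in megabytes, judging by the model_size arithmetic in the test), whose union must equal the full state dict with no duplicated keys. A hypothetical save loop over a wrapped `model`; the file names are illustrative:

import torch

for idx, (shard, _) in enumerate(model.state_dict_shard(max_shard_size=1024, only_rank_0=True)):
    torch.save(shard, f'gemini_shard_{idx}.bin')    # hypothetical shard file names
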
@@ -1,56 +0,0 @@
import pytest
import torch
from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, model_name: str):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    accumulated_keys = set()
    # ensure number of shards > 1
    for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
        for key, value in shard.items():
            assert key not in accumulated_keys, f"key `{key}` is duplicated."
            accumulated_keys.add(key)
            assert key in zero_dict, f"{key} not in ZeRO dictionary."
            assert torch.equal(value, zero_dict[key]), f"{key} not equal."


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp_state_dict_shard(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_ddp_state_dict_shard(1)
@@ -5,42 +5,53 @@ import torch.distributed as dist
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
from tests.test_tensor.common_utils import set_seed

PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 1.0
    },    # zero2-offload
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.5
    },    # zero2-offload-half
    {
        'placement_policy': 'auto'
    }
]


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
def exam_zero_optim_state_dict(placement_config, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)

    optimizer = HybridAdam(model.parameters())
    optim = ZeroOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32
    optim = GeminiOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32

    set_seed(dist.get_rank() * 3 + 128)
    model.train()
@@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc():
        assert torch.equal(zero1_output, zero2_output)

        # zero-dp backward
        no_sync = number == 0
        with conditional_context(zero1_optimizer.no_sync(), no_sync):
            zero1_optimizer.backward(zero1_output.sum().float())
        with conditional_context(zero2_optimizer.no_sync(), no_sync):
            zero2_optimizer.backward(zero2_output.sum().float())

        if check_flag:
            for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
                if z2p.grad is not None:
                    # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                    assert torch.equal(z1p.grad, z2p.grad)
        zero1_optimizer.backward(zero1_output.sum().float())
        zero2_optimizer.backward(zero2_output.sum().float())

    fwd_bwd_func(0, input_data1, True)
    fwd_bwd_func(1, input_data2, False)
@@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
        assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_grad_acc():
def exam_zero_1_grad_acc(sync):
    local_rank = torch.distributed.get_rank()
    seed_all(2008)

@@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
    input_data1 = torch.randn(32, 128).cuda()
    input_data2 = torch.randn(32, 128).cuda()

    def fwd_bwd_func(number, cur_data, check_flag):
    def fwd_bwd_func(no_sync, cur_data, check_flag):

        no_sync = number == 0
        # zero1 fwd and bwd
        with conditional_context(zero_optimizer.no_sync(), no_sync):
            zero_output = zero_model(cur_data)
@@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            assert torch.equal(p.grad, z1p.grad)

    fwd_bwd_func(0, input_data1, True)
    fwd_bwd_func(1, input_data2, False)
    fwd_bwd_func(sync, input_data1, sync)
    fwd_bwd_func(False, input_data2, False)

    zero_optimizer.step()
    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
@@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_grad_acc()
    # gradient accumulation is not compatible with ZeRO-2
    # exam_zero_1_2_grad_acc()
    exam_zero_1_grad_acc(sync=True)
    exam_zero_1_grad_acc(sync=False)
    exam_zero_1_2_grad_acc()


@pytest.mark.dist
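
Aside (editor's sketch): the reworked tests drive gradient accumulation through the optimizer's no_sync() context: every micro-batch except the last runs its backward inside no_sync() so gradients accumulate locally, and reduction happens once on the final backward. Restated with the test's zero_model, zero_optimizer, and input data as stand-ins:

micro_batches = [input_data1, input_data2]
for step, data in enumerate(micro_batches):
    is_last = step == len(micro_batches) - 1
    if is_last:
        # final micro-batch: backward outside no_sync() triggers grad reduction
        zero_optimizer.backward(zero_model(data).sum().float())
    else:
        # earlier micro-batches: accumulate locally, skip communication
        with zero_optimizer.no_sync():
            zero_optimizer.backward(zero_model(data).sum().float())
zero_optimizer.step()
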
@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)

    assert_close(a, b, rtol=rtol, atol=atol)

@@ -1,55 +0,0 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer


class MlpModel(nn.Module):

    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def exam_zero_init():
    dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
    model1 = MlpModel().cuda()
    with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
        model2 = MlpModel()
    optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
    optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))

    assert optimizer1._local_rank == optimizer2._local_rank
    assert optimizer1._world_size == optimizer2._world_size

    mp_group1 = optimizer1.tp_pg
    mp_group2 = optimizer2.tp_pg
    assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
    assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)


def run_dist(rank, world_size, port):
    config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d')))
    colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost')
    exam_zero_init()


@pytest.mark.dist
def test_zero_init():
    spawn(run_dist, 4)


if __name__ == '__main__':
    test_zero_init()
@@ -85,6 +85,7 @@ def run_dist(rank, world_size, port):
    exam_zero_with_tp()


@pytest.mark.skip('this will be rewritten by shardformer')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_with_tp():