Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 07:00:37 +00:00
Merge branch 'main' into sync/npu
@@ -1,8 +1,11 @@
import copy
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist
from torch.testing import assert_close
from torch.utils.data import Dataset

import colossalai
from colossalai.booster import Booster
@@ -10,10 +13,35 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device, set_seed
from tests.kit.model_zoo import model_zoo

class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        set_seed(42)
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def move_to_cuda(batch):
    return {k: v.cuda() for k, v in batch.items()}


@clear_cache_before_run()
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    try:
        if init_method == "lazy":
@@ -69,7 +97,6 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
        "transformers_llama_for_casual_lm"
    ).items():
        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()

        if err is None:
            passed_models.append(name)
@@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])

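# The four parameterizations below cover, in order: TP=2 with PP=2, PP=2 with ZeRO-1,
# TP=2 with ZeRO-2, and TP=2 without ZeRO. All run in fp16 for 4 steps with a gradient
# accumulation step of 2, so every optimizer update aggregates two forward/backward passes.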
@parameterize(
    "test_args",
    [
        {
            "batch_size": 8,
            "num_steps": 4,
            "tp": 2,
            "pp": 2,
            "pp_style": "1f1b",
            "num_model_chunks": 1,
            "num_microbatches": 4,
            "zero": 0,
            "precision": "fp16",
            "initial_scale": 1,
            "max_length": 512,
            "gradient_accumulation_step": 2,
        },
        {
            "batch_size": 8,
            "num_steps": 4,
            "tp": 1,
            "pp": 2,
            "pp_style": "1f1b",
            "num_model_chunks": 1,
            "num_microbatches": 4,
            "zero": 1,
            "precision": "fp16",
            "initial_scale": 1,
            "max_length": 512,
            "gradient_accumulation_step": 2,
        },
        {
            "batch_size": 1,
            "num_steps": 4,
            "tp": 2,
            "pp": 1,
            "pp_style": "1f1b",
            "num_model_chunks": 1,
            "num_microbatches": 1,
            "zero": 2,
            "precision": "fp16",
            "initial_scale": 1,
            "max_length": 512,
            "gradient_accumulation_step": 2,
        },
        {
            "batch_size": 1,
            "num_steps": 4,
            "tp": 2,
            "pp": 1,
            "pp_style": "1f1b",
            "num_model_chunks": 1,
            "num_microbatches": 1,
            "zero": 0,
            "precision": "fp16",
            "initial_scale": 1,
            "max_length": 512,
            "gradient_accumulation_step": 2,
        },
    ],
)
def run_grad_acc_test(test_args):
    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()))
    model = model_fn()
    optimizer = HybridAdam(model.parameters())
    origin_model = copy.deepcopy(model).cuda()
    origin_optimizer = HybridAdam(origin_model.parameters())

    plugin = HybridParallelPlugin(
        tp_size=test_args["tp"],
        pp_size=test_args["pp"],
        pp_style=test_args["pp_style"],
        zero_stage=test_args["zero"],
        num_model_chunks=test_args["num_model_chunks"],
        enable_fused_normalization=True,
        num_microbatches=test_args["num_microbatches"],
        precision=test_args["precision"],
    )
    booster = Booster(plugin=plugin)

    dataset = RandomDataset(
        num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size,
        max_length=test_args["max_length"],
        vocab_size=model.config.vocab_size,
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True)

    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

    grad_accu_step = test_args["gradient_accumulation_step"]
    for step, batch in enumerate(dataloader):
        batch = move_to_cuda(batch)
        # train origin model
        origin_output = origin_model(**batch)
        origin_loss = origin_output[0] / grad_accu_step
        origin_loss.backward()

        if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2:
            ctx = booster.no_sync(model, optimizer)
        else:
            ctx = nullcontext()
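        # On accumulation-only steps, booster.no_sync() suppresses gradient synchronization
        # across data-parallel ranks. ZeRO stage 2 is excluded above because this test assumes
        # the plugin cannot delay its gradient reduction under no_sync.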

        with ctx:
            if plugin.stage_manager is not None:
                batch = iter([batch])
                booster.execute_pipeline(
                    batch,
                    model,
                    criterion=lambda outputs, inputs: outputs[0] / grad_accu_step,
                    optimizer=optimizer,
                    return_loss=False,
                )
            else:
                outputs = model(**batch)
                loss = outputs[0] / grad_accu_step
                booster.backward(loss, optimizer)

        if (step + 1) % grad_accu_step == 0:
            # update origin model weight
            origin_optimizer.step()
            origin_optimizer.zero_grad()

            # update sharded model
            optimizer.step()
            optimizer.zero_grad()

    # tricky code here: shard the origin model as well, in order to compare parameters on the same stage
    origin_model, origin_optimizer, _, dataloader, _ = booster.boost(
        origin_model, origin_optimizer, dataloader=dataloader
    )
    for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)


def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    check_3d_plugin(early_stop=early_stop)
    run_grad_acc_test()


@rerun_if_address_is_in_use()
@@ -1,7 +1,7 @@
from contextlib import nullcontext
from typing import Optional
import pytest

import pytest
import torch
import torch.distributed as dist

@@ -11,13 +11,18 @@ from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from colossalai.testing import (
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    skip_if_not_enough_gpus,
    spawn,
)
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


@clear_cache_before_run()
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
    try:
        if init_method == "lazy":
@@ -26,7 +31,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
            ctx = nullcontext()
        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        enable_all_optimization = True if tp_size > 1 else False
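        # extra_dp_size is the data-parallel degree left over after ZeRO and TP are carved out
        # of the world size; for example, 8 ranks with zero_size=2 and tp_size=2 give extra_dp_size=2.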
        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
        plugin = GeminiPlugin(
            max_norm=1.0,
            initial_scale=2**5,
            tp_size=tp_size,
            extra_dp_size=extra_dp_size,
            enable_all_optimization=enable_all_optimization,
        )
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
@@ -62,11 +73,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
# @parameterize('init_method', ['lazy', 'none', 'colo'])


@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
def check_gemini_plugin(
    subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
):
    """check gemini plugin over model zoo

    Args:
@@ -105,6 +118,14 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool
            "transformers_sam",
            "transformers_vit",
            "transformers_gpt_double_heads",  # TODO check why the model fails to run using Gemini
            "transformers_falcon",  # TODO check why falcon fails to run with Gemini
            "transformers_falcon_for_causal_lm",
            "transformers_falcon_for_sequence_classification",
            "transformers_falcon_for_token_classification",
            "transformers_falcon_for_question_answering",
            "transformers_gptj_lm",  # leads to OOM when running in CI
            "transformers_gptj_for_question_answering",
            "transformers_gptj_for_sequence_classification",
        ]:
            continue

@@ -131,7 +152,6 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool
            tp_size = 1

        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
        torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
        else:
@@ -156,11 +176,13 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
    spawn(run_dist, 4, early_stop=early_stop)


@pytest.mark.largedist
@skip_if_not_enough_gpus(8)
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
    spawn(run_dist, 8, early_stop=early_stop)


if __name__ == "__main__":
    test_gemini_plugin(early_stop=False)
    test_gemini_plugin(early_stop=False)
@@ -10,8 +10,8 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo

# These models are not compatible with AMP
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]
@@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


@clear_cache_before_run()
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    device = get_accelerator().get_current_device()
    try:
@@ -62,7 +63,12 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
    ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
    skipped_models = []

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
    if IS_FAST_TEST:
        registry = model_zoo.get_sub_registry(COMMON_MODELS)
    else:
        registry = model_zoo

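    # When IS_FAST_TEST is set, only the COMMON_MODELS subset of the model zoo is exercised so
    # that the test finishes quickly; otherwise every registered model is run.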
    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
        # FIXME(ver217): fix these models
        if name in ignore_models:
            skipped_models.append(name)
@@ -10,10 +10,11 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


@clear_cache_before_run()
def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
@@ -40,7 +41,12 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_ddp_plugin():
    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
    if IS_FAST_TEST:
        registry = model_zoo.get_sub_registry(COMMON_MODELS)
    else:
        registry = model_zoo

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
        if name == "dlrm_interactionarch":
            continue
        run_fn(model_fn, data_gen_fn, output_transform_fn)
@@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
    from colossalai.booster.plugin import TorchFSDPPlugin

from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


# test basic fsdp function
@clear_cache_before_run()
def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)
@@ -40,9 +41,20 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()

    del model
    del optimizer
    del criterion
    del booster
    del plugin
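    # Dropping these references lets the FSDP-wrapped model, optimizer state, and plugin be
    # garbage collected, so the torch.cuda.empty_cache() call in the test loop below can
    # actually release their memory before the next model runs.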


def check_torch_fsdp_plugin():
    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
    if IS_FAST_TEST:
        registry = model_zoo.get_sub_registry(COMMON_MODELS)
    else:
        registry = model_zoo.get_sub_registry("transformers_gptj")

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
        if any(
            element in name
            for element in [
@@ -54,6 +66,7 @@ def check_torch_fsdp_plugin():
            ]
        ):
            continue
        print(name)
        run_fn(model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()

@@ -68,3 +81,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_torch_fsdp_plugin()