Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[booster] add low level zero plugin (#3594)
* [booster] add low level zero plugin
* [booster] fix gemini plugin test
* [booster] fix precision
* [booster] add low level zero plugin test
* [test] fix booster plugin test oom
* [test] fix booster plugin test oom
* [test] fix googlenet and inception output trans
* [test] fix diffuser clip vision model
* [test] fix torchaudio_wav2vec2_base
* [test] fix low level zero plugin test
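The new LowLevelZeroPlugin itself is not visible in the diff below, which only reworks the Gemini plugin test. For orientation, here is a minimal sketch of driving a booster plugin through the same Booster workflow the test exercises; the LowLevelZeroPlugin constructor arguments (stage, precision) are assumptions, not taken from this commit:

    # Sketch only: exercises the new plugin via the Booster API used in the
    # test below. The stage/precision kwargs are assumed, not from this diff.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.nn.optimizer import HybridAdam


    def boost_with_low_level_zero(model_fn, data):
        plugin = LowLevelZeroPlugin(stage=2, precision='fp16')    # assumed kwargs
        booster = Booster(plugin=plugin)
        model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        # boost() returns wrapped model/optimizer/criterion plus dataloader and
        # lr_scheduler slots (unused here), matching the test below
        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
        loss = criterion(model(**data))
        booster.backward(loss, optimizer)    # plugin-aware backward pass
        optimizer.step()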
@@ -1,4 +1,5 @@
 from contextlib import nullcontext
+from typing import Optional
 
 import torch
 import torch.distributed as dist
@@ -10,11 +11,53 @@ from colossalai.fx import is_compatible_with_meta
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.model.experimental import LazyInitContext
 from colossalai.zero import ColoInitContext
 from tests.kit.model_zoo import model_zoo
 
 
-@parameterize('init_method', ['lazy', 'none', 'colo'])
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+    try:
+        if init_method == 'colo':
+            ctx = ColoInitContext()
+        elif init_method == 'lazy':
+            ctx = LazyInitContext()
+        else:
+            ctx = nullcontext()
+        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
+        booster = Booster(plugin=plugin)
+        with ctx:
+            model = model_fn()
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        criterion = lambda x: x.mean()
+        data = data_gen_fn()
+
+        data = {
+            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+        }
+
+        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+        for n, p in model.named_parameters():
+            assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
+
+        output = model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+
+        booster.backward(loss, optimizer)
+        optimizer.step()
+
+    except Exception as e:
+        return repr(e)
+
+
+# TODO(ver217): CI does not support lazy now
+# @parameterize('init_method', ['lazy', 'none', 'colo'])
+@parameterize('init_method', ['none'])
 def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     """check gemini plugin over model zoo
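The run_fn added above can also be exercised on a single model-zoo entry outside the parameterized sweep. A hypothetical one-off driver; the entry name and the tuple layout of a model_zoo entry are assumptions mirroring the sweep in check_gemini_plugin:

    # Hypothetical driver, not part of the commit; entry name and slicing are assumed.
    model_fn, data_gen_fn, output_transform_fn = model_zoo['timm_resnet'][:3]
    err = run_fn('none', model_fn, data_gen_fn, output_transform_fn)
    assert err is None, f'run_fn failed: {err}'    # run_fn returns repr(e) on failure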
@@ -25,7 +68,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     if not is_support_meta and init_method == 'lazy':
         return
 
-    from colossalai.utils.model.experimental import LazyInitContext
     passed_models = []
     failed_info = {}    # (model_name, error) pair
@@ -58,48 +100,16 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
         ]:
             continue
 
-        try:
-            if init_method == 'colo':
-                ctx = ColoInitContext()
-            elif init_method == 'lazy':
-                ctx = LazyInitContext()
-            else:
-                ctx = nullcontext()
-            plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
-            booster = Booster(plugin=plugin)
-            with ctx:
-                model = model_fn()
-            optimizer = HybridAdam(model.parameters(), lr=1e-3)
-            criterion = lambda x: x.mean()
-            data = data_gen_fn()
-
-            data = {
-                k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
-                for k, v in data.items()
-            }
-
-            model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-
-            for n, p in model.named_parameters():
-                assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
-
-            output = model(**data)
-            output = output_transform_fn(output)
-            output_key = list(output.keys())[0]
-            loss = criterion(output[output_key])
-
-            booster.backward(loss, optimizer)
-            optimizer.step()
-            passed_models.append(name)
-
-            del booster, plugin, model, optimizer, criterion, data, output, loss
-        except Exception as e:
-            failed_info[name] = e
-            if early_stop:
-                raise e
-
+        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
+        torch.cuda.empty_cache()
+
+        if err is None:
+            passed_models.append(name)
+        else:
+            failed_info[name] = err
+            if early_stop:
+                break
 
     if dist.get_rank() == 0:
         print(f'Init method: {init_method}')
         print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
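The refactor above swaps the inline try/except for a run_fn call followed by torch.cuda.empty_cache(): because run_fn reports failures as strings instead of raising mid-loop, the sweep can release cached GPU memory between models, which is what the two "[test] fix booster plugin test oom" entries in the commit message address.

The final hunk touches test_gemini_plugin, whose spawn target run_dist is not shown in this diff. A minimal sketch of the conventional shape of such a target in these tests; the colossalai.launch arguments are assumptions:

    # Sketch of the spawn target (its body is outside this diff); the launch
    # arguments follow the usual colossalai test pattern and are assumed here.
    import colossalai

    def run_dist(rank, world_size, port, early_stop: bool = True):
        colossalai.launch(config={}, rank=rank, world_size=world_size,
                          host='localhost', port=port, backend='nccl')
        check_gemini_plugin(early_stop=early_stop)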
@@ -140,7 +150,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 
 @rerun_if_address_is_in_use()
 def test_gemini_plugin(early_stop: bool = True):
-    spawn(run_dist, 2, early_stop=early_stop)
+    spawn(run_dist, 4, early_stop=early_stop)
 
 
 if __name__ == '__main__':