[test] polish zero-related unit tests (#351)

Jiarui Fang
2022-03-10 09:57:26 +08:00
committed by Frank Lee
parent 534e0bb118
commit cb34cd384d
5 changed files with 75 additions and 123 deletions


@@ -3,8 +3,10 @@ from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.sharded_model import ShardedModelV2
LOGGER = get_dist_logger()
@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
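
For context, a minimal sketch (not part of the diff) of how a config dict like CONFIG is typically consumed in these tests: each spawned worker initializes the distributed context through colossalai.launch. The run_dist name, world size, and port below are hypothetical placeholders.

```python
import colossalai
import torch.multiprocessing as mp

def run_dist(rank, world_size=2, port=29500):
    # Each worker joins the process group with the shared test CONFIG.
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')

if __name__ == '__main__':
    # mp.spawn passes the worker index as the first argument (rank).
    mp.spawn(run_dist, nprocs=2)
```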
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()
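
Note that run_fwd_bwd dispatches the backward pass on the model type: a ShardedModelV2 must own the backward call so its gradient-sharding hooks run, while a plain nn.Module falls back to loss.backward(). A minimal sketch of driving it with a placeholder model and data (the names below are hypothetical, not from the commit):

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 10).cuda()
criterion = nn.CrossEntropyLoss()
data = torch.randn(4, 32, device='cuda')
label = torch.randint(0, 10, (4,), device='cuda')

# Runs one forward/backward step under autocast; since model is a plain
# nn.Module here, the helper calls loss.backward() internally.
run_fwd_bwd(model, data, label, criterion, enable_autocast=True)
```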
def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module
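
checkpoint_wrapper rebinds module.forward through colossalai.utils.checkpoint, so the wrapped module's activations are recomputed during backward instead of being stored. A hedged usage sketch, assuming the utility follows torch.utils.checkpoint semantics:

```python
import torch
import torch.nn as nn

# Wrap a submodule so its forward runs under activation checkpointing.
block = checkpoint_wrapper(nn.Sequential(nn.Linear(64, 64), nn.ReLU()), enable=True)

x = torch.randn(2, 64, requires_grad=True)
out = block(x)         # forward pass; intermediate activations are dropped
out.sum().backward()   # block's forward is re-executed here to compute grads
```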