[test] polish zero related unit tests (#351)
@@ -3,8 +3,10 @@ from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.sharded_model import ShardedModelV2

LOGGER = get_dist_logger()

@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),
              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))

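For context, a hedged sketch of how a test CONFIG like this is typically consumed; the launch arguments below (host, port, backend) and the helper name run_dist are assumptions for illustration, not part of this diff:

# Illustrative only: initializing the distributed context with the test CONFIG.
# The colossalai.launch keyword arguments follow the usual ColossalAI test style
# and are assumed, not taken from this commit.
import colossalai

def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    # ... build a ShardedModelV2 here and drive it with run_fwd_bwd(...) ...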
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()

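For reference, a minimal sketch of how run_fwd_bwd might be driven; the toy model, tensor shapes, and loss function below are illustrative assumptions, not code from this commit:

# Illustrative only: exercising run_fwd_bwd on a plain nn.Module.
# (In the actual tests, a ShardedModelV2 built by the ZeRO harness would be
# passed instead, and its own backward() path would be taken.)
import torch
import torch.nn as nn

model = nn.Linear(8, 4).cuda()                 # hypothetical toy model
data = torch.randn(2, 8).cuda()                # hypothetical input batch
label = torch.randint(0, 4, (2,)).cuda()       # hypothetical targets
criterion = nn.CrossEntropyLoss()

run_fwd_bwd(model, data, label, criterion, enable_autocast=True)
# Gradients are now populated on model.parameters().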
def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module
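And a hedged sketch of how checkpoint_wrapper might be applied; the toy block below is an assumption for illustration, not code from the commit:

# Illustrative only: wrap a submodule so its activations are recomputed
# during backward instead of being stored in the forward pass.
import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
block = checkpoint_wrapper(block, enable=True)   # forward now routes through colossalai.utils.checkpoint

x = torch.randn(4, 16, requires_grad=True)
y = block(x)
y.sum().backward()   # activations inside block are recomputed here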