impl shard optim v2 and add unit test

Author: ver217
Date: 2022-03-04 11:49:02 +08:00
Committed by: Frank Lee
parent 74f77e314b
commit 001ca624dd
4 changed files with 97 additions and 6 deletions

@@ -1,12 +1,16 @@
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS
from . import BaseOpHook


@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded params before and after the FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()
@@ -17,25 +21,32 @@ class ShardParamHook(BaseOpHook):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            if dist.get_rank() == 0:
                print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            if dist.get_rank() == 0:
                print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            if dist.get_rank() == 0:
                print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            if dist.get_rank() == 0:
                print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
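
For context, here is a minimal, self-contained sketch of the gather/shard lifecycle this hook drives around each operator. ToyShardAttr is a hypothetical stand-in for the real ca_attr object (not part of this commit); the actual implementation shards across ranks and gathers via torch.distributed collectives, whereas this sketch fakes the gather in a single process.

import torch

class ToyShardAttr:
    """Holds a 1/world_size slice of a tensor; rebuilds the full tensor on demand.
    Hypothetical stand-in for the real ca_attr; the real version uses
    torch.distributed to all-gather shards from every rank."""

    def __init__(self, tensor: torch.Tensor, world_size: int = 2):
        self.world_size = world_size
        self._full_shape = tensor.shape
        # Keep only one slice (rank 0's, for illustration).
        self._payload = tensor.flatten().chunk(world_size)[0].clone()
        self._is_sharded = True

    def gather(self):
        # Fake an all-gather: tile the local shard back up to full size.
        if self._is_sharded:
            full = self._payload.repeat(self.world_size)[:self._full_shape.numel()]
            self._payload = full.view(self._full_shape)
            self._is_sharded = False

    def shard(self):
        # Drop back to the local slice, freeing the rest of the full tensor.
        if not self._is_sharded:
            self._payload = self._payload.flatten().chunk(self.world_size)[0].clone()
            self._is_sharded = True

    def payload(self, device: str):
        return self._payload.to(device)

p = torch.nn.Parameter(torch.randn(4, 4))
p.ca_attr = ToyShardAttr(p.data)
print(p.ca_attr.payload('cpu').shape)  # sharded: torch.Size([8])
p.ca_attr.gather()                     # what the hook does before FWD/BWD
print(p.ca_attr.payload('cpu').shape)  # full: torch.Size([4, 4])
p.ca_attr.shard()                      # what the hook does after FWD/BWD
print(p.ca_attr.payload('cpu').shape)  # sharded again: torch.Size([8])

The point of the lifecycle is that the full parameter exists only around the operator that needs it; the rest of the time each rank holds just its shard, which is where the memory savings come from.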