mirror of https://github.com/hpcaitech/ColossalAI.git
impl shard optim v2 and add unit test
@@ -1,12 +1,16 @@
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded parameters before and after FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

@@ -17,25 +21,32 @@ class ShardParamHook(BaseOpHook):
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            # gather the full parameter before the forward operator runs
            param.ca_attr.gather()
            if dist.get_rank() == 0:
                print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            # re-shard the parameter once the forward operator has finished
            param.ca_attr.shard()
            if dist.get_rank() == 0:
                print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            # gather the full parameter before the backward operator runs
            param.ca_attr.gather()
            if dist.get_rank() == 0:
                print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            # re-shard the parameter once the backward operator has finished
            param.ca_attr.shard()
            if dist.get_rank() == 0:
                print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
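
The hook above only assumes that every parameter carries a `ca_attr` object exposing `gather()`, `shard()` and `payload(device)`; the actual sharded-parameter implementation lives elsewhere in this commit. The following is a minimal single-process sketch of that contract, with an illustrative `NaiveShardedAttr` class (the name and the flatten/chunk scheme are hypothetical, not ColossalAI code), just to show how the payload shape changes across gather and shard:

# Illustrative stand-in for the `ca_attr` contract assumed by ShardParamHook.
# This is NOT the ColossalAI implementation; names and the flatten/chunk
# sharding scheme are hypothetical and only demonstrate gather()/shard()/payload().
import torch


class NaiveShardedAttr:

    def __init__(self, tensor: torch.Tensor, rank: int, world_size: int):
        self._full = tensor
        self._rank = rank
        self._world_size = world_size
        self._is_sharded = False
        self._payload = tensor

    def shard(self):
        # keep only this rank's slice of the flattened parameter
        if not self._is_sharded:
            chunks = self._full.flatten().chunk(self._world_size)
            self._payload = chunks[self._rank].clone()
            self._is_sharded = True

    def gather(self):
        # restore the full tensor (a real implementation would all-gather the shards here)
        if self._is_sharded:
            self._payload = self._full
            self._is_sharded = False

    def payload(self, device: str) -> torch.Tensor:
        return self._payload.to(device)


if __name__ == '__main__':
    attr = NaiveShardedAttr(torch.randn(4, 4), rank=0, world_size=2)
    print(attr.payload('cpu').shape)    # torch.Size([4, 4]) -> full parameter
    attr.shard()
    print(attr.payload('cpu').shape)    # torch.Size([8])    -> this rank's shard
    attr.gather()
    print(attr.payload('cpu').shape)    # torch.Size([4, 4]) -> full again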
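
For context, the four callbacks are meant to bracket each operator: gather before FWD/BWD, shard after. The sketch below only illustrates that ordering with plain PyTorch forward hooks on a single process; it is not ColossalAI's op-hook dispatch (the library wires these callbacks through its own machinery), and the backward-side callbacks are omitted because plain module hooks offer no pre-backward counterpart here.

# Standalone ordering demo using vanilla PyTorch forward hooks (hypothetical
# wiring, not ColossalAI's dispatcher). It only shows when the pre/post forward
# callbacks fire around a module's forward().
import torch
import torch.nn as nn


class PrintOrderHook:
    """Stand-in hook that just records when each callback fires."""

    def pre_fwd_exec(self, module: nn.Module, *args):
        print(f'pre fwd  on {module.__class__.__name__}')

    def post_fwd_exec(self, module: nn.Module, *args):
        print(f'post fwd on {module.__class__.__name__}')


def attach(module: nn.Module, hook: PrintOrderHook):
    # register_forward_pre_hook fires before forward(), register_forward_hook after
    module.register_forward_pre_hook(lambda m, inp: hook.pre_fwd_exec(m, *inp))
    module.register_forward_hook(lambda m, inp, out: hook.post_fwd_exec(m, *inp))


if __name__ == '__main__':
    layer = nn.Linear(4, 4)
    attach(layer, PrintOrderHook())
    layer(torch.randn(2, 4))
    # prints:
    # pre fwd  on Linear
    # post fwd on Linear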