mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-03 07:33:54 +00:00
fix sharded param hook and unit test
This commit is contained in:
@@ -3,37 +3,21 @@ from functools import partial
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import checkpoint
|
||||
|
||||
LOGGER = get_dist_logger()
|
||||
|
||||
CONFIG = dict(
|
||||
fp16=dict(
|
||||
mode=None,
|
||||
),
|
||||
zero=dict(
|
||||
level=3,
|
||||
verbose=False,
|
||||
offload_optimizer_config=dict(
|
||||
device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
fast_init=False
|
||||
),
|
||||
offload_param_config=dict(
|
||||
device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
buffer_size=1e8,
|
||||
max_in_cpu=1e9
|
||||
)
|
||||
),
|
||||
parallel=dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
)
|
||||
CONFIG = dict(fp16=dict(mode=None,),
|
||||
zero=dict(level=3,
|
||||
verbose=False,
|
||||
offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
|
||||
offload_param_config=dict(device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
buffer_size=1e8,
|
||||
max_in_cpu=1e9)),
|
||||
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
|
||||
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
@@ -43,6 +27,7 @@ def checkpoint_wrapper(module, enable=True):
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
@@ -50,13 +35,7 @@ class Net(nn.Module):
|
||||
self.fc3 = nn.Linear(5, 1)
|
||||
if checkpoint:
|
||||
self.fc1 = checkpoint_wrapper(self.fc1)
|
||||
self.layers = [
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc3
|
||||
]
|
||||
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
@@ -111,3 +90,17 @@ def check_params_padding(model, zero_model, loose=False):
|
||||
zero_p = zero_p[:p.size(0)]
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
|
||||
def check_sharded_params_padding(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_p = zero_p.ca_attr.payload(p.device)
|
||||
chunks = torch.flatten(p).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
p = chunks[rank]
|
||||
if zero_p.size(0) > p.size(0):
|
||||
zero_p = zero_p[:p.size(0)]
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
Reference in New Issue
Block a user