[tensor] a shorter shard and replicate spec (#1245)

Author: Jiarui Fang
Date: 2022-07-11 15:51:48 +08:00
Committed by: GitHub
Parent: 2699dfbbfd
Commit: 9bcd2fd4af
25 changed files with 91 additions and 98 deletions


@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):
 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n:
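
For context, a minimal sketch of how the shortened spec API reads after this change. The final p.set_tensor_spec(*spec) call is an assumption based on the surrounding test code (it is cut off in the hunk above), and ReplicaSpec and p.set_dist_spec are assumed to be the "replicate" counterpart named in the commit title, not something shown in this diff.

# Usage sketch, not part of the diff. Assumptions: set_tensor_spec() accepts
# (dist_spec, compute_spec) unpacked from the tuple, and ReplicaSpec() is the
# shorter replacement for distspec.replicate().
from colossalai.tensor import (ComputePattern, ComputeSpec, DistSpecManager,
                               ProcessGroup, ReplicaSpec, ShardSpec)


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    # Old: distspec.shard([-1], [pg.tp_world_size()])
    # New: ShardSpec([-1], [pg.tp_world_size()])
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n:
                p.set_tensor_spec(*spec)  # assumed setter, cut off in the hunk above


def replicate_params(model):
    # The "replicate" half of the commit title: ReplicaSpec() in place of
    # distspec.replicate(); the setter name is likewise an assumption.
    with DistSpecManager.no_grad():
        for _, p in model.named_parameters():
            p.set_dist_spec(ReplicaSpec())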