[tensor] a shorter shard and replicate spec (#1245)
@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):


 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n:
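The change swaps the longer distspec.shard(...) factory call for the shorter ShardSpec(...) constructor when building a tensor-parallel spec; per the commit title, an analogous shorthand exists for the replicate spec. Below is a minimal sketch of how the shortened spec might be built and applied to a model's linear weights. The p.set_process_group / p.set_tensor_spec calls on each parameter are assumptions about the ColoTensor API and are not shown in this diff.

from colossalai.tensor import (ComputePattern, ComputeSpec, DistSpecManager,
                               ProcessGroup, ShardSpec)


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    # Shard the last dimension of every 'weight' parameter across the
    # tensor-parallel ranks of `pg`, and mark it for 1D tensor-parallel compute.
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n:
                # Assumed ColoParameter API: bind the process group, then apply
                # the (dist_spec, compute_spec) pair built above.
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)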