[fix] rename the linear (no TP) op: Linear1D -> LinearWithGradAccum

This commit is contained in:
duanjunwen
2024-10-31 08:18:28 +00:00
parent d2e05a99b3
commit 5f0924361d
6 changed files with 19 additions and 41 deletions

View File

@@ -9,7 +9,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = Linear1D.from_native_module(
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
)
assert linear_base.weight.shape == torch.Size([128, 32])
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = Linear1D.from_native_module(
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
)
assert linear_base.weight.shape == torch.Size([128, 32])