mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[fix] fix linear (no tp) ops func name;
This commit is contained in:
@@ -9,7 +9,7 @@ from torch.testing import assert_close
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||
from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_base = Linear1D.from_native_module(
|
||||
linear_base = LinearWithGradAccum.from_native_module(
|
||||
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
|
||||
)
|
||||
assert linear_base.weight.shape == torch.Size([128, 32])
|
||||
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_base = Linear1D.from_native_module(
|
||||
linear_base = LinearWithGradAccum.from_native_module(
|
||||
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
|
||||
)
|
||||
assert linear_base.weight.shape == torch.Size([128, 32])
|
||||
|
Reference in New Issue
Block a user