mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)
This commit is contained in:
@@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
|
||||
from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
|
||||
|
||||
@@ -45,7 +46,11 @@ def run_linear_tp1d_row_test():
|
||||
|
||||
# replace the torch nn.Parameters with ColoTensor
|
||||
sharded_weight = ColoTensor.init_from_torch_tensor(W)
|
||||
sharded_weight.set_spec(spec="1Drow") # reshard
|
||||
parallel_action_list = [
|
||||
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
|
||||
]
|
||||
spec = TensorSpec(parallel_action_list)
|
||||
sharded_weight.set_spec(spec=spec) # reshard
|
||||
sharded_bias = ColoTensor.init_from_torch_tensor(B)
|
||||
replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
|
||||
out = layer(A)
|
||||
|
Reference in New Issue
Block a user