mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refractor and cleanup using common helper funcs. Tests passed * Update comments * Fix relative import * Fix param fetching bug
This commit is contained in:
@@ -384,7 +384,7 @@ class Linear1D_Row(ParallelModule):
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
||||
parallel_input (bool): If set to ``True``, it's assumed that the input is already split/copied across each rank, defaults to False.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
|
||||
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
|
||||
@@ -544,14 +544,14 @@ class Linear1D_Row(ParallelModule):
|
||||
if self.parallel_input:
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
input_ = input_
|
||||
else:
|
||||
assert (
|
||||
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(
|
||||
|
Reference in New Issue
Block a user