[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)

* Refactor and clean up using common helper funcs. Tests passed

* Update comments

* Fix relative import

* Fix param fetching bug
Wenxuan Tan, committed by GitHub, 2025-02-11 23:42:34 -06:00
parent 5c09d726a6
commit ec73f1b5e2
8 changed files with 142 additions and 298 deletions


@@ -384,7 +384,7 @@ class Linear1D_Row(ParallelModule):
         out_features (int): size of each output sample.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
-        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        parallel_input (bool): If set to ``True``, it's assumed that the input is already split/copied across each rank, defaults to False.
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
         seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
@@ -544,14 +544,14 @@ class Linear1D_Row(ParallelModule):
         if self.parallel_input:
             assert (
                 input_.shape[-1] == self.weight.shape[-1]
-            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
                 input_.shape, self.weight.shape, self.weight.shape[-1]
             )
             input_ = input_
         else:
             assert (
                 divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
-            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
                 input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
             )
             input_ = split_forward_gather_backward(
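
For reference, the shape invariant those two assertions encode can be spelled out with plain tensors. The following is a minimal sketch under assumed sizes (a tensor-parallel world of 4 ranks, hypothetical feature dims and variable names), not the library's implementation:

import torch

# Illustrative sketch of the Linear1D_Row shape checks; all sizes and
# names below are hypothetical stand-ins, not taken from ColossalAI.
num_partitions = 4               # assumed tensor-parallel world size
full_in_features = 1024          # feature (last) dim of the unsplit input
local_in_features = full_in_features // num_partitions
out_features = 512

# Linear1D_Row shards the weight along the input-feature dim, so each
# rank holds a (out_features, in_features / num_partitions) slice.
local_weight = torch.empty(out_features, local_in_features)

# parallel_input=True: the caller already split the feature dim, so the
# input's feature dim must equal the local weight shard's feature dim.
x_split = torch.randn(2, 8, local_in_features)
assert x_split.shape[-1] == local_weight.shape[-1]

# parallel_input=False: the input is still full-size and the layer splits
# it itself, so feature_dim / num_partitions must match the local shard.
x_full = torch.randn(2, 8, full_in_features)
assert x_full.shape[-1] // num_partitions == local_weight.shape[-1]

The reworded messages report the full input and weight shapes plus the expected feature dim, which makes a mismatch (e.g. passing an unsplit input with parallel_input=True) immediately readable from the assertion error.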