support linear accumulation fusion (#5199)

support linear accumulation fusion

support linear accumulation fusion

fix
This commit is contained in:
flybird11111
2023-12-29 18:22:42 +08:00
committed by GitHub
parent 64519eb830
commit 02d2328a04
2 changed files with 48 additions and 5 deletions

View File

@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim