mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 19:17:30 +00:00
[inference] overlap comm and compute in Linear1D_Row when stream_chunk_num > 1 (#1876)
This commit is contained in:
@@ -706,13 +706,22 @@ class Linear1D_Row(ParallelLayer):
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
output_parallel_list = [None for i in range(self.stream_chunk_num)]
|
||||
for i in range(self.stream_chunk_num):
|
||||
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
|
||||
output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
if self.training:
|
||||
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
|
||||
with torch.no_grad():
|
||||
output_parallel_list = [None for i in range(self.stream_chunk_num)]
|
||||
handle_list = []
|
||||
for i in range(self.stream_chunk_num):
|
||||
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
|
||||
handle = torch.distributed.all_reduce(output_parallel_list[i],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D),
|
||||
async_op=True)
|
||||
handle_list.append(handle)
|
||||
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
||||
for handle in handle_list:
|
||||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
print(input_.shape, self.weight.shape)
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
|
Reference in New Issue
Block a user