[inference] overlap comm and compute in Linear1D_Row when stream_chunk_num > 1 (#1876)

This commit is contained in:
Jiarui Fang
2022-11-10 17:36:42 +08:00
committed by GitHub
parent 1b494ad73c
commit 986f8cbaa7
2 changed files with 20 additions and 8 deletions

View File

@@ -514,8 +514,9 @@ def check_linear_row_stream_inference():
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
assert HIDDEN_SIZE % 2 == 0
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2)
stream_chunk_num = 4
assert HIDDEN_SIZE % stream_chunk_num == 0
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num)
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
@@ -537,6 +538,8 @@ def check_linear_row_stream_inference():
layer.weight = Parameter(W)
layer.bias = Parameter(B)
layer.chunk_weight()
layer.eval()
out = layer(A)
A_master = A_master.clone()