mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[inference] overlap comm and compute in Linear1D_Row when stream_chunk_num > 1 (#1876)
This commit is contained in:
@@ -514,8 +514,9 @@ def check_linear_row_stream_inference():
|
||||
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
|
||||
assert HIDDEN_SIZE % 2 == 0
|
||||
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2)
|
||||
stream_chunk_num = 4
|
||||
assert HIDDEN_SIZE % stream_chunk_num == 0
|
||||
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
@@ -537,6 +538,8 @@ def check_linear_row_stream_inference():
|
||||
layer.weight = Parameter(W)
|
||||
layer.bias = Parameter(B)
|
||||
layer.chunk_weight()
|
||||
layer.eval()
|
||||
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
|
Reference in New Issue
Block a user