Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[gemini] fixes for benchmarking (#5847)
* [gemini] fix missing return
* [gemini] fix missing arg pass
* [gemini] use gather tensor instead of list
* [test] enable flash attention for benchmark by default

Co-authored-by: genghaozhe <939857490@qq.com>
@@ -403,9 +403,9 @@ class Chunk:
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
 
-            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
-            )
+            assert self.cuda_global_chunk.is_contiguous()
+            self.grad_reduce_work = dist.reduce_scatter_tensor(
+                self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
+            )
 
             if self.extra_dp_group is not None:
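In the hunk above, the list-based reduce-scatter is replaced by reduce_scatter_tensor over the contiguous flat chunk buffer, so no Python list of per-rank views has to be built. A minimal standalone sketch of the two patterns (not ColossalAI code; assumes a NCCL process group launched with torchrun and one GPU per rank):

# Sketch of the swap: reduce_scatter over a list of chunk views vs.
# reduce_scatter_tensor over one contiguous buffer. Hypothetical script,
# assumed to be launched with `torchrun --nproc_per_node=N` on GPUs.
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

shard_size = 4
global_chunk = torch.ones(world * shard_size, device="cuda") * (rank + 1)
shard = torch.empty(shard_size, device="cuda")

# Old pattern: chunk the flat buffer into a Python list of views.
input_list = list(torch.chunk(global_chunk, chunks=world, dim=0))
dist.reduce_scatter(shard, input_list)

# New pattern: pass the contiguous buffer directly, no list of views.
assert global_chunk.is_contiguous()
dist.reduce_scatter_tensor(shard, global_chunk)

dist.destroy_process_group()

Both calls leave each rank holding its reduced shard; the tensor variant simply operates on the whole contiguous buffer in one call, which is presumably why the commit prefers it for benchmarking.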
@@ -520,8 +520,10 @@ class Chunk:
             assert self.cuda_shard is not None
 
             alloc_storage(self.cuda_global_chunk)
-            gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
+            assert self.cuda_global_chunk.is_contiguous()
+            work = dist.all_gather_into_tensor(
+                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+            )
 
             self.cuda_shard = None
             self.is_gathered = True
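The gather path gets the same treatment: instead of all_gather into a list of views over the flat buffer, all_gather_into_tensor writes straight into the contiguous buffer. A minimal standalone sketch (again not ColossalAI code; assumes a NCCL group launched with torchrun and one GPU per rank):

# Sketch of the gather-side swap: all_gather into a list of chunk views vs.
# all_gather_into_tensor into the contiguous buffer. Hypothetical script,
# assumed to be launched with `torchrun --nproc_per_node=N` on GPUs.
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

shard = torch.full((4,), float(rank), device="cuda")
global_chunk = torch.empty(world * 4, device="cuda")

# Old pattern: gather into a list of views over the flat buffer.
gather_list = list(torch.chunk(input=global_chunk, chunks=world, dim=0))
dist.all_gather(gather_list, shard)

# New pattern: gather straight into the contiguous buffer.
assert global_chunk.is_contiguous()
dist.all_gather_into_tensor(global_chunk, shard)

dist.destroy_process_group()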