mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[gemini] fixes for benchmarking (#5847)
* [gemini] fix missing return * [gemini] fix missing arg pass * [gemini] use gather tensor instead of list * [test] enable flash attention for benchmark by default * [test] enable flash attention for benchmark by default --------- Co-authored-by: genghaozhe <939857490@qq.com>
This commit is contained in:
@@ -133,12 +133,12 @@ class ChunkManager:
|
||||
self.__sub_accessed_chunk(chunk)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None:
|
||||
"""Move the shard of the chunk to the target device."""
|
||||
if not chunk.can_move or chunk.device_type == device.type:
|
||||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
chunk.shard_move(device, force_copy)
|
||||
chunk.shard_move(device, force_copy, non_blocking=async_move)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
|
Reference in New Issue
Block a user