mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[gemini] support gradient accumulation (#4869)
* add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case
This commit is contained in:
@@ -434,6 +434,21 @@ class Chunk:
|
||||
if update_ptr:
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
|
||||
|
||||
def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
"""
|
||||
Add data slice to the memory space indexed by the input tensor in the chunk.
|
||||
Only used when accumulating gradient chunks.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrieve meta information
|
||||
data_slice (torch.Tensor): the tensor to be added to the chunk
|
||||
"""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten())
|
||||
|
||||
def get_valid_length(self) -> int:
|
||||
"""Get the valid length of the chunk's payload."""
|
||||
if self.keep_gathered:
|
||||
|
Reference in New Issue
Block a user