mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[zero] fix error for BEiT models (#2169)
* [zero] fix error for BEiT models
* [ColoParameter] add unpack operation for tuple arguments
* fix bugs
* fix chunkv2 unit testing
* add assertion for gradient state
This commit is contained in:
@@ -283,7 +283,9 @@ class ZeroDDP(ColoDDP):
|
||||
p.grad = None
|
||||
|
||||
def _post_backward(self):
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
if self.chunk_manager.accessed_mem != 0:
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.")
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
@@ -304,8 +306,9 @@ class ZeroDDP(ColoDDP):
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
if reduced:
|
||||
|
Reference in New Issue
Block a user