[zero] fix error for BEiT models (#2169)

* [zero] fix error for BEiT models

* [ColoParameter] add unpack operation for tuple arguments

* fix bugs

* fix chunkv2 unit testing

* add assertion for gradient state
This commit is contained in:
HELSON
2022-12-26 15:03:54 +08:00
committed by GitHub
parent 4363ff3e41
commit 2458659919
7 changed files with 82 additions and 32 deletions

View File

@@ -90,6 +90,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
for param in param_list:
my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD)
my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4