[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
This commit is contained in:
Hongxin Liu
2024-12-02 11:29:19 +08:00
committed by GitHub
parent ab856fd308
commit 6280cb18b8
9 changed files with 33 additions and 54 deletions

View File

@@ -83,7 +83,11 @@ class TensorBucket:
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
write_back_tensor = self._write_back_pairs[tensor]
write_back_tensor.data.copy_(
_flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
)
rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
if write_back_tensor.is_contiguous():
rec_tensor = rec_tensor.view_as(write_back_tensor)
else:
rec_tensor = rec_tensor.reshape_as(write_back_tensor)
write_back_tensor.data.copy_(rec_tensor)
self.empty()