From 5d5031e9468856ed7ae3aa058c6c800d40a81c2a Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 28 Jul 2022 09:31:42 +0800 Subject: [PATCH] fix zero ddp state dict (#1378) --- colossalai/nn/parallel/data_parallel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index da05df1cb..31a9e5627 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -314,14 +314,18 @@ class ZeroDDP(ColoDDP): module """ chunks = self.chunk_manager.get_chunks(self.fp32_params) + chunks_orig_device_type = [] for chunk in chunks: + chunks_orig_device_type.append(chunk.device_type) self.chunk_manager.access_chunk(chunk) for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu() destination[prefix + name] = rec_p if keep_vars else rec_p.detach() - for chunk in chunks: + for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks): self.chunk_manager.release_chunk(chunk) + if not chunk.is_empty and orig_dvice_type == 'cpu': + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf if keep_vars else buf.detach()