diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index da05df1cb..31a9e5627 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -314,14 +314,18 @@ class ZeroDDP(ColoDDP): module """ chunks = self.chunk_manager.get_chunks(self.fp32_params) + chunks_orig_device_type = [] for chunk in chunks: + chunks_orig_device_type.append(chunk.device_type) self.chunk_manager.access_chunk(chunk) for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu() destination[prefix + name] = rec_p if keep_vars else rec_p.detach() - for chunk in chunks: + for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks): self.chunk_manager.release_chunk(chunk) + if not chunk.is_empty and orig_dvice_type == 'cpu': + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf if keep_vars else buf.detach()