From 3263cdf57f6be9b61fc2f4e64b8f3e2e1dde8b3f Mon Sep 17 00:00:00 2001 From: Jiatong Han <59948448+JThh@users.noreply.github.com> Date: Thu, 8 Sep 2022 16:33:14 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/parallel/data_parallel.py code style (#1570) Co-authored-by: JThh --- colossalai/nn/parallel/data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 7420da8f4..378f186a8 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -350,7 +350,7 @@ class ZeroDDP(ColoDDP): for tensor in chunk.get_tensors(): rec_p = torch.empty([0]) if record_flag: - rec_p = tensor.cpu() # move the whole tensor to CPU mem + rec_p = tensor.cpu() # move the whole tensor to CPU mem assert tensor not in param_to_save_data param_to_save_data[tensor] = rec_p # release the actual memory of the chunk @@ -406,7 +406,7 @@ class ZeroDDP(ColoDDP): state_dict = state_dict.copy() if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] + state_dict._metadata = metadata # type: ignore[attr-defined] prefix = '' local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})