From 63469c0f91941dcf4eec21c58247910e02067290 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 14 Mar 2022 15:48:55 +0800 Subject: [PATCH] polish code --- colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py index a2b9b0097..d5ba72a2e 100644 --- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -23,6 +23,9 @@ class BucketTensorShardStrategy(TensorShardStrategy): for i in range(self.world_size): if i == self.local_rank: buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + # Release payload here, to decrease peak memory usage + for t in tensor_list: + t.reset_payload(None) else: buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)