diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
index a2b9b0097..d5ba72a2e 100644
--- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -23,6 +23,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         for i in range(self.world_size):
             if i == self.local_rank:
                 buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
+                # Release payload here, to decrease peak memory usage
+                for t in tensor_list:
+                    t.reset_payload(None)
             else:
                 buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
         dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)