[gemini] fix tensor storage cleaning in state dict collection (#4396)

This commit is contained in:
Baizhou Zhang
2023-08-10 15:36:46 +08:00
committed by GitHub
parent 458ae331ad
commit 6ccecc0c69

View File

@@ -1,6 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import gc
import math
import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
@@ -468,11 +467,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)
# Clean gathered states
for state_shard in gathered_state_shards:
del state_shard[0]
gc.collect()
# Reshape tensors
if is_collector:
for state_name, state_tensor in collected_states.items():