mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-04 02:57:20 +00:00
[gemini] fix tensor storage cleaning in state dict collection (#4396)
This commit is contained in:
parent
458ae331ad
commit
6ccecc0c69
@@ -1,6 +1,5 @@
|
|||||||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||||
import copy
|
import copy
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
|
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
|
||||||
@@ -468,11 +467,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||||||
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
|
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
|
||||||
# Clean gathered states
|
|
||||||
for state_shard in gathered_state_shards:
|
|
||||||
del state_shard[0]
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
# Reshape tensors
|
# Reshape tensors
|
||||||
if is_collector:
|
if is_collector:
|
||||||
for state_name, state_tensor in collected_states.items():
|
for state_name, state_tensor in collected_states.items():
|
||||||
|
Loading…
Reference in New Issue
Block a user