Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-05-03 22:18:23 +00:00
[checkpointio] fix zero optimizer async save memory (#6151)

* [checkpointio] fix zero optimizer async save memory
* [checkpointio] fit new tensornvme api
* [checkpointio] fit new tensornvme api

This commit is contained in:
parent 8ecff0cb7f
commit ab856fd308
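In short, the commit does two things. First, `LowLevelZeroOptimizer.state_dict` and `state_dict_shard` gain an `only_on_master` flag: every rank still joins the all-gather of optimizer states (otherwise the collective communication would not match across ranks), but only the master rank deep-copies the gathered tensors and allocates pinned CPU buffers, so non-master ranks no longer pay the host-memory cost of an async save. Second, every `AsyncFileWriter` call site is updated to the new tensornvme constructor, which takes a file path instead of an open file object. A minimal sketch of the rank-gating pattern, with `all_gather` and `rank` standing in for ColossalAI's `all_gather_into_flat_tensor_nd` and `get_nd_rank`:

    import torch

    def gather_param_state(v: torch.Tensor, rank: int, world_size: int,
                           all_gather, only_on_master: bool = True):
        # Every rank must take part in the collective, or the communication
        # on each rank would not match (see the comments in the first hunk).
        gathered = torch.empty(v.numel() * world_size, dtype=v.dtype)
        all_gather(gathered, v.flatten())
        if not only_on_master or rank == 0:
            return gathered  # only the master rank materializes the result
        return None          # other ranks discard it, saving host memory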
@@ -128,22 +128,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        state_dict = optimizer.state_dict(pinned_state_dicts)
+        state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
                 from tensornvme.async_file_io import AsyncFileWriter

                 from colossalai.utils.safetensors import save_nested

-                f_writer = AsyncFileWriter(
-                    fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
-                )
+                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
                 save_nested(f_writer, state_dict)
                 self.async_writers.append(f_writer)
             else:
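The same hunk also adopts the new tensornvme constructor: the caller now passes a path and the writer opens and owns the file, instead of receiving `fp=open(checkpoint, "wb", buffering=0)`. A hedged before/after sketch (argument names as they appear in the diff; check your installed tensornvme version for the exact signature):

    from tensornvme.async_file_io import AsyncFileWriter

    checkpoint = "optimizer.safetensors"  # example path
    N_WRITE_ENTRIES = 191                 # the value ColossalAI's tests use

    # old API: the caller opened (and later had to close) the file object
    # f_writer = AsyncFileWriter(fp=open(checkpoint, "wb", buffering=0),
    #                            n_entries=N_WRITE_ENTRIES, backend="pthread")

    # new API: pass the path and let the writer manage the file itself
    f_writer = AsyncFileWriter(checkpoint, n_entries=N_WRITE_ENTRIES, backend="pthread")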
@@ -192,13 +190,15 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
+        sharded_state = optimizer.state_dict_shard(
+            max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
+        )

         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
@@ -227,7 +227,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                 from colossalai.utils.safetensors import save_nested

                 f_writer = AsyncFileWriter(
-                    fp=open(checkpoint_file_path, "wb", buffering=0),
+                    checkpoint_file_path,
                     n_entries=self.N_WRITE_ENTRIES,
                     backend="pthread",
                 )
@@ -72,7 +72,6 @@ class CheckpointIO(ABC):
     def _sync_io(self):
         for writer in self.async_writers:
             writer.synchronize()
-            writer.fp.close()
         self.async_writers.clear()

     def _sync_d2h(self):
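With the path-based constructor the writer owns its file handle, so `_sync_io` no longer reaches into `writer.fp` to close it; synchronizing and dropping the reference is enough. The test changes further down make the same switch, replacing `f_writer.fp.close()` with `del f_writer`.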
@@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if use_async:
             from tensornvme.async_file_io import AsyncFileWriter

-            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
+            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)
@@ -690,9 +690,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

             from colossalai.utils.safetensors import move_and_save

-            writer = AsyncFileWriter(
-                open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
-            )
+            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)
@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
         index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)

-        writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
+        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
         writers.append(writer)

         if pinned_state_dict is not None:
@@ -776,7 +776,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         return {"state": packed_state, "param_groups": param_groups}

-    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
+    def state_dict(
+        self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
+    ) -> Dict:
         """Return a state_dict same with DDP

         Returns:
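A hedged sketch of how the extended signature is meant to be called, mirroring the plugin code above (`optimizer` is a `LowLevelZeroOptimizer`, `coordinator` a `DistCoordinator`; `save_asynchronously` is a hypothetical stand-in for the writer logic):

    # every rank executes this block: the all-gather inside is collective
    pinned = {} if coordinator.is_master() else None  # pinned buffers only on master
    state_dict = optimizer.state_dict(pinned, only_on_master=True)
    if coordinator.is_master():
        # only the master rank holds real tensors; the other ranks
        # got empty per-parameter dicts back
        save_asynchronously(state_dict)  # hypothetical save step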
@@ -785,16 +787,22 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
-            if pinned_state_dicts is not None and param not in pinned_state_dicts:
-                pinned_state_dicts[param] = {}
-            zero_state[param] = copy.deepcopy(state)
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
             working_param = self.master_to_working_param[id(param)]
             pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                zero_state[param] = copy.deepcopy(state)
+            else:
+                zero_state[param] = {}
+
+            if pinned_state_dicts is not None and param not in pinned_state_dicts:
+                pinned_state_dicts[param] = {}
+
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and k != "step":
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                    if not only_on_master or get_nd_rank(pg) == 0:
                         if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
                             pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
                         if pinned_state_dicts is not None:
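The memory effect is easiest to see with Adam-style state (`exp_avg` and `exp_avg_sq` in fp32): previously every rank deep-copied the gathered states and pinned roughly 8 bytes of host memory per model parameter, so a world size of W multiplied the aggregate pinned memory by W. After the change only the rank-0 member of each parameter's process group allocates those buffers, while the other ranks keep an empty placeholder per parameter; the all-gather itself still runs everywhere, so the collectives stay matched. (The 8 bytes/parameter figure is a back-of-the-envelope estimate for two fp32 state tensors, not a number from the commit.)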
@@ -837,7 +845,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.optim.load_state_dict(zero_state_dict)

     def state_dict_shard(
-        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+        self,
+        max_shard_size: int = 1024,
+        pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        only_on_master: bool = False,
     ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
         Only include the 'state' in state_dict.
@@ -862,20 +873,26 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             cnt += 1
         for param_idx, states in local_states.items():
             current_block_size = 0
-            current_block = copy.deepcopy(states)
             if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                current_block = copy.deepcopy(states)
+            else:
+                current_block = {}

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if not only_on_master or get_nd_rank(pg) == 0:
                         if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                            pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
+                            pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                state_tensor, pin_memory=True, device="cpu"
+                            )
                         if pinned_state_dicts is not None:
                             pinned_state_dicts[param_idx][k].copy_(state_tensor)
                             current_block[k] = pinned_state_dicts[param_idx][k]
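`state_dict_shard` receives the identical gating, applied per shard. A hedged consumption sketch (same collective caveat: every rank must iterate; `write_shard` is a hypothetical sink):

    for shard, shard_size in optimizer.state_dict_shard(
        max_shard_size=size_per_shard, pinned_state_dicts=pinned, only_on_master=True
    ):
        if coordinator.is_master():
            write_shard(shard)  # non-master ranks see empty per-parameter blocks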
@@ -10,6 +10,7 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

+
 from colossalai.testing import check_state_dict_equal
 from colossalai.utils import get_current_device

@@ -110,20 +111,20 @@ def test_save_load():
     }

     optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
-    f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
+    f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
     save_nested(f_writer, optimizer_state_dict)
     f_writer.sync_before_step()
     f_writer.synchronize()
-    f_writer.fp.close()
+    del f_writer
     load_state_dict = load_flat(optimizer_saved_path)
     check_state_dict_equal(load_state_dict, optimizer_state_dict)

     optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
-    f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+    f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
     save_nested(f_writer, optimizer_state_dict["state"])
     f_writer.sync_before_step()
     f_writer.synchronize()
-    f_writer.fp.close()
+    del f_writer
     load_state_dict_shard = load_flat(optimizer_shard_saved_path)
     check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

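The test updates also spell out the writer lifecycle under the new API: enqueue the writes, fence with `sync_before_step()`, block on `synchronize()`, then simply drop the writer. There is no `fp` left to close; the file handle is presumably released by the writer's own teardown. Distilled from the test (names as used there):

    f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
    save_nested(f_writer, optimizer_state_dict)  # enqueue asynchronous writes
    f_writer.sync_before_step()
    f_writer.synchronize()  # block until everything is flushed
    del f_writer            # the writer releases its own file handle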
@@ -133,21 +134,21 @@ def test_save_load():
         "module.weight2": torch.rand((1024, 1024)),
     }
     model_saved_path = f"{tempdir}/save_model.safetensors"
-    f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+    f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
     save(f_writer, model_state_dict)
     f_writer.sync_before_step()
     f_writer.synchronize()
-    f_writer.fp.close()
+    del f_writer
     load_state_dict = load_file(model_saved_path)
     check_state_dict_equal(model_state_dict, load_state_dict)

     model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
     model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
     model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
-    f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+    f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
     move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
     f_writer.sync_before_step()
     f_writer.synchronize()
-    f_writer.fp.close()
+    del f_writer
     load_state_dict = load_file(model_saved_path)
     check_state_dict_equal(model_state_dict, load_state_dict)