Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 13:30:19 +00:00)
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
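With this change, `HybridParallelCheckpointIO` asserts that it receives the boosted wrappers (`ModelWrapper` / `OptimizerWrapper`) and unwraps them itself, instead of relying on externally linked master/working maps. A minimal usage sketch under the Booster API (the plugin arguments, toy model, and paths are illustrative, and the script must run under a distributed launch):

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # started via torchrun; config is illustrative

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16", zero_stage=1)
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)  # placeholder model
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# model is now a ModelWrapper and optimizer an OptimizerWrapper,
# which is what the new asserts in this diff require.
booster.save_model(model, "./ckpt", shard=True)
booster.save_optimizer(optimizer, "./optim_ckpt", shard=True)
```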
@@ -3,7 +3,7 @@ import logging
 import os
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
@@ -13,7 +13,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.tp_size = dist.get_world_size(tp_group)
         self.use_zero = zero_stage > 0
         self.verbose = verbose
-        self.working_to_master_map = None
-        self.master_to_working_map = None
         self.coordinator = DistCoordinator()
 
     @staticmethod
@@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
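The commit message's "rewrite unwrap for ddp/fsdp" and "zero/gemini" refers to pushing this logic behind `ModelWrapper.unwrap()`, so checkpoint IO can stay wrapper-agnostic. A minimal sketch of the idea only, not ColossalAI's actual implementation:

```python
import torch.nn as nn

class ModelWrapper(nn.Module):
    """Sketch: a wrapper whose unwrap() peels back to the user's module."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        # Recurse through nested wrappers (e.g. a plugin wrapper around a
        # DDP/FSDP module) until a plain nn.Module is reached.
        if isinstance(self.module, ModelWrapper):
            return self.module.unwrap()
        return self.module
```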
@@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 f"index located at {final_index_file_path}."
             )
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
         """
         Load sharded model with the given path to index file of checkpoint folder.
 
@@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                 This argument should be manually set to False since params on same device might be stored in different files.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model_before_wrapping = model  # backup for model before wrapping
+        model = model.unwrap()
 
         # Check whether the checkpoint uses safetensors.
         use_safetensors = False
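Note the backup reference: the wrapper is kept as `model_before_wrapping` before unwrapping, because after the sharded state dict has been loaded into the inner module, the wrapper's `update_master_params()` (see the next hunk) is what re-syncs the fp32 master copies under mixed precision.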
@@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(extra_state_key)
 
         # Update master params if mixed-precision training is enabled.
-        with torch.no_grad():
-            if self.working_to_master_map is not None:
-                for param in model.parameters():
-                    if (param is None) or (id(param) not in self.working_to_master_map):
-                        continue
-                    master_param = self.working_to_master_map[id(param)]
-                    if self.use_zero:
-                        # master_param is sharded under Zero setting
-                        padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
-                        if padding_size > 0:
-                            padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                        else:
-                            padded_param = param.data.view(-1)
-                        sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
-                        master_param.data.copy_(sharded_param.data)
-                    else:
-                        master_param.data.copy_(param.data)
+        model_before_wrapping.update_master_params()
 
         if self.verbose:
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
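The deleted branch documents the ZeRO layout that `update_master_params()` must now reproduce internally: each working parameter is flattened, zero-padded to a multiple of `dp_size`, and the `dp_rank`-th slice becomes the local master shard. A standalone illustration of that arithmetic (`dp_size`, `dp_rank`, and the tensor are toy values):

```python
import torch

dp_size, dp_rank = 4, 1           # toy values for illustration
param = torch.arange(10.0)        # numel=10 is not divisible by dp_size

# Same padding formula as the removed code above.
padding_size = (dp_size - param.numel() % dp_size) % dp_size   # -> 2
padded = torch.nn.functional.pad(param.view(-1), [0, padding_size])

# Each rank keeps one contiguous slice of the padded flat tensor.
shard = padded.split(padded.numel() // dp_size)[dp_rank]
print(shard)  # tensor([3., 4., 5.])
```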
@@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_zero=self.use_zero,
             dp_group=self.dp_group,
             tp_group=self.tp_group,
-            master_to_working_map=self.master_to_working_map,
+            master_to_working_map=optimizer.get_master_to_working_map(),
             size_per_shard=size_per_shard,
         )
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
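Both the save and load paths now ask the optimizer for this mapping instead of keeping a copy in checkpoint IO. Judging from the lookups elsewhere in this diff (`master_to_working_map[id(param)]`), the returned dict is keyed by `id()` of the fp32 master parameter. A sketch of how such a map could be built; the real one is maintained by ColossalAI's mixed-precision/ZeRO optimizer wrappers:

```python
from typing import Dict, List

import torch

def build_master_to_working_map(
    master_params: List[torch.Tensor], working_params: List[torch.Tensor]
) -> Dict[int, torch.Tensor]:
    # Key by id() of the fp32 master copy; value is the working (e.g. fp16) tensor.
    return {id(m): w for m, w in zip(master_params, working_params)}
```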
@@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             checkpoint_index_file (str): Path to the index file of checkpointing folder.
             prefix (str): Not used.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         def _get_param_id_from_optimizer_param(
             param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
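The body of this nested helper is elided from the hunk. Based on the comments in the next hunk (IDs come from the param2id mapping saved in `optimizer.param_info`) and the `id()`-keyed map above, a hedged reconstruction would be:

```python
# Hedged reconstruction of the nested helper; `optimizer` comes from the
# enclosing method's scope, and the exact body is not shown in this diff.
def _get_param_id_from_optimizer_param(
    param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
    # Resolve a master param back to its working param when a map is given.
    if master_to_working_map is not None:
        working_param = master_to_working_map[id(param)]
    else:
        working_param = param
    # param2id was recorded in optimizer.param_info when the optimizer was boosted.
    return optimizer.param_info["param2id"][id(working_param)]
```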
@@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # When Zero is used, the mapped parameter objects should be fp32 master parameters.
         # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
         id_map = {}
+        master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
 
         # Read checkpoint index file.
@@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Then shard the loaded optimizer states if using tp/zero.
         for param, state in optimizer.optim.state.items():
             device = param.device
-            if self.master_to_working_map is not None:
-                working_param = self.master_to_working_map[id(param)]
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
@@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def link_master_and_working_param(
-        self,
-        working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-        master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-    ):
-        """
-        Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
-        This mapping can only be created when mixied precision is used.
-        The created mappings should be mappings from integer parameter addresses to parameter objects.
-
-        Args:
-            working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
-            master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
-        """
-        self.working_to_master_map = dict()
-        for k, v in working_to_master_map.items():
-            if isinstance(k, torch.Tensor):
-                self.working_to_master_map[id(k)] = v
-            elif isinstance(k, int):
-                self.working_to_master_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
-        self.master_to_working_map = dict()
-        for k, v in master_to_working_map.items():
-            if isinstance(k, torch.Tensor):
-                self.master_to_working_map[id(k)] = v
-            elif isinstance(k, int):
-                self.master_to_working_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
     @staticmethod
     def gather_from_sharded_optimizer_state(
         state: OrderedDict,
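With `link_master_and_working_param` removed, checkpoint IO no longer keeps its own copy of the working/master mappings at all: the `OptimizerWrapper` that actually owns the master params exposes `get_master_to_working_map()`, so the mapping used for saving and loading can never drift out of sync with the optimizer state.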