[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
Baizhou Zhang
2023-09-20 18:29:37 +08:00
committed by GitHub
parent 7b9b86441f
commit c0a033700c
14 changed files with 141 additions and 171 deletions
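The thread running through the changed files is that per-plugin bookkeeping moves out of the checkpoint IO and onto the wrapper objects returned by boosting. Below is a minimal sketch of that contract; only the method names are confirmed by the diff, the bodies are illustrative:

    import torch.nn as nn

    class ModelWrapper(nn.Module):
        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module

        def unwrap(self) -> nn.Module:
            # Peel back to the module holding the working parameters.
            return self.module

        def update_master_params(self) -> None:
            # Sync freshly loaded working params into fp32 master copies.
            # The base wrapper keeps no master params, so this is a no-op;
            # mixed-precision / ZeRO / Gemini wrappers override it.
            pass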

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

@@ -3,7 +3,7 @@ import logging
 import os
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
@@ -13,7 +13,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.cluster import DistCoordinator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper

 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.tp_size = dist.get_world_size(tp_group)
         self.use_zero = zero_stage > 0
         self.verbose = verbose
-        self.working_to_master_map = None
-        self.master_to_working_map = None
         self.coordinator = DistCoordinator()

     @staticmethod
@@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
def save_sharded_model(
self,
model: nn.Module,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
@@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
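From the caller's side, the boosted wrapper is now passed straight through to the checkpoint IO. A usage sketch, assuming an initialized distributed environment; the plugin arguments below are placeholder assumptions, not part of this diff:

    import torch.nn as nn
    from torch.optim import Adam

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    model = nn.Linear(16, 16)
    optimizer = Adam(model.parameters())

    plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # save_sharded_model asserts it receives the boosted ModelWrapper and
    # unwraps it internally, rather than accepting a bare nn.Module.
    booster.save_model(model, checkpoint="./ckpt", shard=True)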
@@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 f"index located at {final_index_file_path}."
             )

-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
         """
         Load sharded model with the given path to index file of checkpoint folder.
@@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                 This argument should be manually set to False since params on same device might be stored in different files.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model_before_wrapping = model  # backup for model before wrapping
+        model = model.unwrap()

         # Check whether the checkpoint uses safetensors.
         use_safetensors = False
@@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 _load(extra_state_key)

         # Update master params if mixed-precision training is enabled.
-        with torch.no_grad():
-            if self.working_to_master_map is not None:
-                for param in model.parameters():
-                    if (param is None) or (id(param) not in self.working_to_master_map):
-                        continue
-                    master_param = self.working_to_master_map[id(param)]
-                    if self.use_zero:
-                        # master_param is sharded under Zero setting
-                        padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
-                        if padding_size > 0:
-                            padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                        else:
-                            padded_param = param.data.view(-1)
-                        sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
-                        master_param.data.copy_(sharded_param.data)
-                    else:
-                        master_param.data.copy_(param.data)
+        model_before_wrapping.update_master_params()

         if self.verbose:
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             prefix (str): Prefix of files to save
             size_per_shard (int): Max file size of each file shard that stores state tensors
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"

         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_zero=self.use_zero,
             dp_group=self.dp_group,
             tp_group=self.tp_group,
-            master_to_working_map=self.master_to_working_map,
+            master_to_working_map=optimizer.get_master_to_working_map(),
             size_per_shard=size_per_shard,
         )
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
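Here the optimizer wrapper, not the checkpoint IO, is the source of truth for the master-to-working mapping. The diff confirms only the getter's name and that the dict is keyed by id(master_param); a sketch of how a mixed-precision wrapper might maintain it (the class itself is hypothetical, for illustration only):

    from typing import Dict

    import torch
    from torch.optim import Optimizer

    class MixedPrecisionOptimizerWrapper:  # hypothetical, for illustration
        def __init__(self, optim: Optimizer, working_to_master: Dict[torch.Tensor, torch.Tensor]):
            self.optim = optim
            # Keyed by id(master_param): integer addresses avoid relying on
            # tensor hashing and are stable across dtype/device differences
            # between the working and master copies.
            self._master_to_working = {id(m): w for w, m in working_to_master.items()}

        def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
            return self._master_to_working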
@@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             checkpoint_index_file (str): Path to the index file of checkpointing folder.
             prefix (str): Not used.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

         def _get_param_id_from_optimizer_param(
             param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
@@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
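The helper's body is elided by this hunk, but the surrounding comments pin down what it must do: resolve a (possibly master) parameter to its working twin, then look the address up in the param2id table saved in optimizer.param_info. A sketch consistent with that, as it would appear nested inside the load method where `optimizer` is in scope:

    def _get_param_id_from_optimizer_param(
        param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
    ):
        # Master params are keyed by address in the map; fall through to
        # the param itself when no mixed-precision mapping exists.
        if master_to_working_map is not None:
            working_param = master_to_working_map[id(param)]
        else:
            working_param = param
        return optimizer.param_info["param2id"][id(working_param)]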
@@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Then shard the loaded optimizer states if using tp/zero.
         for param, state in optimizer.optim.state.items():
             device = param.device
-            if self.master_to_working_map is not None:
-                working_param = self.master_to_working_map[id(param)]
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
@@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)

-    def link_master_and_working_param(
-        self,
-        working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-        master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-    ):
-        """
-        Create mappings between working params (for forward/backward) and master params (for optimizer update) with the passed-in mappings.
-        This mapping can only be created when mixed precision is used.
-        The created mappings should be mappings from integer parameter addresses to parameter objects.
-
-        Args:
-            working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameter objects/addresses to master parameter objects.
-            master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameter objects/addresses to working parameter objects.
-        """
-        self.working_to_master_map = dict()
-        for k, v in working_to_master_map.items():
-            if isinstance(k, torch.Tensor):
-                self.working_to_master_map[id(k)] = v
-            elif isinstance(k, int):
-                self.working_to_master_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed-in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-        self.master_to_working_map = dict()
-        for k, v in master_to_working_map.items():
-            if isinstance(k, torch.Tensor):
-                self.master_to_working_map[id(k)] = v
-            elif isinstance(k, int):
-                self.master_to_working_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed-in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
     @staticmethod
     def gather_from_sharded_optimizer_state(
         state: OrderedDict,
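With link_master_and_working_param gone, nothing pushes mappings into the checkpoint IO anymore; the IO pulls them from the boosted optimizer on demand via get_master_to_working_map(). The key normalization the deleted method performed, tensor keys to integer addresses, collapses to a small helper wherever a wrapper builds its map; a sketch grounded in the deleted code:

    from typing import Dict, Union

    import torch

    def normalize_param_keys(
        mapping: Dict[Union[int, torch.Tensor], torch.Tensor]
    ) -> Dict[int, torch.Tensor]:
        # Same normalization the deleted method applied in both directions:
        # tensor keys become their id(); int keys pass through unchanged.
        out: Dict[int, torch.Tensor] = {}
        for k, v in mapping.items():
            if isinstance(k, torch.Tensor):
                out[id(k)] = v
            elif isinstance(k, int):
                out[k] = v
            else:
                raise ValueError(f"Keys must be 'int' or 'torch.Tensor', but got {type(k)}!")
        return out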