[utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)
* refactor parallel layer
* broadcast rank0 model after load ckpt
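Note: the change has two parts. Every ParallelLayer now emits a merged ("global") state dict by default, with ParallelLayer.use_local_state_dict() as the opt-out, and load_checkpoint re-broadcasts the replicated parameters from the first tensor-parallel rank after loading. A rough usage sketch under those assumptions; the import path, file name and model object are hypothetical, and colossalai is assumed to be launched with tensor parallelism:

    import torch
    from colossalai.core import global_context as gpc
    from colossalai.nn.layer.base_layer import ParallelLayer   # import path assumed

    # Default after this change: state_dict() gathers the full tensors, so a
    # tensor-parallel model checkpoints like a single-GPU model.
    torch.save(model.state_dict(), 'model_global.pt')

    # Opt out and keep each rank's local shard instead of gathering.
    with ParallelLayer.use_local_state_dict():
        torch.save(model.state_dict(), f'model_rank{gpc.get_global_rank()}.pt')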
@@ -5,9 +5,11 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from contextlib import contextmanager


 class ParallelLayer(nn.Module):
+    global_state_dict: bool = True

     def __init__(self):
         super().__init__()

@@ -26,10 +28,35 @@ class ParallelLayer(nn.Module):
         self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
             ParallelMode.PIPELINE)

+    def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                     error_msgs):
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                             error_msgs)
+
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
+        return super()._save_to_state_dict(destination, prefix, keep_vars)
+
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                               error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
-        if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
-            missing_keys.clear()
-            unexpected_keys.clear()
+        if self.global_state_dict:
+            if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
+                missing_keys.clear()
+                unexpected_keys.clear()
+            return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
+                                                     unexpected_keys, error_msgs)
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                             error_msgs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if self.global_state_dict:
+            return self._save_to_global_state_dict(destination, prefix, keep_vars)
+        return super()._save_to_state_dict(destination, prefix, keep_vars)
+
+    @classmethod
+    @contextmanager
+    def use_local_state_dict(cls):
+        try:
+            cls.global_state_dict = False
+            yield
+        finally:
+            cls.global_state_dict = True
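Note on the load path above: in global mode only rank 0 of the tensor-parallel group actually holds the checkpoint tensors; the other ranks receive their shards through the broadcast/partition helpers inside each layer's _load_from_global_state_dict, so missing_keys/unexpected_keys are cleared on those ranks to keep strict loading from failing. A minimal sketch of that split, assuming an initialized tensor-parallel group, a model built from these layers, and a hypothetical file name:

    import torch
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        global_state = torch.load('model_global.pt')   # full tensors live on rank 0 only
    else:
        global_state = {}                               # other ranks receive their shards via broadcast
    model.load_state_dict(global_state, strict=True)    # strict is safe: ParallelLayer cleared the keys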
@@ -189,7 +189,7 @@ class Classifier1D(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -215,9 +215,9 @@ class Classifier1D(ParallelLayer):
                 weight_key: True,
                 bias_key: False
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()

@@ -242,12 +242,12 @@ class Classifier1D(ParallelLayer):
         # Set up backprop all-reduce.
         if self.parallel_input:
             assert input_.shape[-1] == self.weight.shape[-1], \
                 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
                     input_.shape, self.weight.shape, self.weight.shape[-1])
             input_ = input_
         else:
             assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
                 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
                     input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
             input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

@@ -326,7 +326,7 @@ class VocabParallelClassifier1D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -352,9 +352,9 @@ class VocabParallelClassifier1D(ParallelLayer):
                 weight_key: True,
                 bias_key: True
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()

@@ -461,7 +461,7 @@ class Linear1D_Col(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -486,9 +486,9 @@ class Linear1D_Col(ParallelLayer):
                 weight_key: True,
                 bias_key: True
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -598,7 +598,7 @@ class Linear1D_Row(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -623,9 +623,9 @@ class Linear1D_Row(ParallelLayer):
                 weight_key: True,
                 bias_key: False
             })
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -648,12 +648,12 @@ class Linear1D_Row(ParallelLayer):
         # Set up backprop all-reduce.
         if self.parallel_input:
             assert input_.shape[-1] == self.weight.shape[-1], \
                 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                     input_.shape, self.weight.shape, self.weight.shape[-1])
             input_ = input_
         else:
             assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
                 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                     input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
             input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

@@ -738,7 +738,7 @@ class Embedding1D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -751,9 +751,9 @@ class Embedding1D(ParallelLayer):
                 ParallelMode.PARALLEL_1D,
                 dims={weight_key: -1},
                 partition_states={weight_key: True})
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
         local_state = gather_tensor_parallel_state_dict(local_state,

@@ -773,7 +773,7 @@ class Embedding1D(ParallelLayer):


 @LAYERS.register_module
-class VocabParallelEmbedding1D(torch.nn.Module):
+class VocabParallelEmbedding1D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.

     Args:

@@ -847,7 +847,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -860,9 +860,9 @@ class VocabParallelEmbedding1D(torch.nn.Module):
                 ParallelMode.PARALLEL_1D,
                 dims={weight_key: 0},
                 partition_states={weight_key: True})
-        super()._load_from_state_dict(local_state, prefix, *args)
+        super()._load_from_global_state_dict(local_state, prefix, *args)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
         local_state = gather_tensor_parallel_state_dict(local_state,
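Note: in the 1D layers above, loading routes the rank-0 checkpoint through partition_tensor_parallel_state_dict (with dims naming the split axis) and saving goes back through gather_tensor_parallel_state_dict. A self-contained sketch of what that partitioning amounts to for one column-split weight; the helper itself is not called here and the sizes are made up:

    import torch

    full_weight = torch.randn(8, 4)       # weight as stored in the global checkpoint
    world_size, rank = 4, 1               # tensor-parallel group size and this rank's index
    # dims={weight_key: 0} means each rank keeps one chunk along dim 0
    local_weight = torch.chunk(full_weight, world_size, dim=0)[rank]
    assert local_weight.shape == (2, 4)
    # gather_tensor_parallel_state_dict does the inverse (an all-gather along the same dim)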
@@ -94,7 +94,7 @@ class Linear2D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -137,9 +137,9 @@ class Linear2D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -252,7 +252,7 @@ class LayerNorm2D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -294,9 +294,9 @@ class LayerNorm2D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -443,7 +443,7 @@ class PatchEmbedding2D(ParallelLayer):
             bias_initializer(self.bias, fan_in=fan_in)
             position_embed_initializer(self.pos_embed)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -503,9 +503,9 @@ class PatchEmbedding2D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         cls_token_key = prefix + 'cls_token'

@@ -651,7 +651,7 @@ class Embedding2D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -676,9 +676,9 @@ class Embedding2D(ParallelLayer):
             partition_states={weight_key: True},
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})

@@ -712,7 +712,7 @@ class Embedding2D(ParallelLayer):


 @LAYERS.register_module
-class VocabParallelEmbedding2D(torch.nn.Module):
+class VocabParallelEmbedding2D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.

     Args:

@@ -789,7 +789,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -814,9 +814,9 @@ class VocabParallelEmbedding2D(torch.nn.Module):
             partition_states={weight_key: True},
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})

@@ -924,7 +924,7 @@ class Classifier2D(ParallelLayer):
             broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
             broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -968,9 +968,9 @@ class Classifier2D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()

@@ -1095,7 +1095,7 @@ class VocabParallelClassifier2D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -1139,9 +1139,9 @@ class VocabParallelClassifier2D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()
@@ -96,7 +96,7 @@ class Linear2p5D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -143,9 +143,9 @@ class Linear2p5D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0:
             weight_key = prefix + 'weight'
             bias_key = prefix + 'bias'

@@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -314,9 +314,9 @@ class LayerNorm2p5D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -463,7 +463,7 @@ class PatchEmbedding2p5D(ParallelLayer):
             bias_initializer(self.bias, fan_in=fan_in)
             position_embed_initializer(self.pos_embed)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -523,9 +523,9 @@ class PatchEmbedding2p5D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         cls_token_key = prefix + 'cls_token'

@@ -671,7 +671,7 @@ class Embedding2p5D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -696,9 +696,9 @@ class Embedding2p5D(ParallelLayer):
             partition_states={weight_key: True},
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})

@@ -733,7 +733,7 @@ class Embedding2p5D(ParallelLayer):


 @LAYERS.register_module
-class VocabParallelEmbedding2p5D(torch.nn.Module):
+class VocabParallelEmbedding2p5D(ParallelLayer):
     """Embedding parallelized in the vocabulary dimension.

     Args:

@@ -810,7 +810,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -835,9 +835,9 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
             partition_states={weight_key: True},
        )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})

@@ -950,7 +950,7 @@ class Classifier2p5D(ParallelLayer):
             broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
             broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -994,9 +994,9 @@ class Classifier2p5D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()

@@ -1123,7 +1123,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -1167,7 +1167,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def forward(self, x: Tensor) -> Tensor:
         # input: [m/dq, n/q, k/q]
@@ -70,7 +70,7 @@ class LayerNorm3D(ParallelLayer):
         if self.bias is not None:
             init.zeros_()(self.bias)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -105,9 +105,9 @@ class LayerNorm3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -207,7 +207,7 @@ class Linear3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -265,9 +265,9 @@ class Linear3D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -400,7 +400,7 @@ class Classifier3D(ParallelLayer):
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
             broadcast(self.bias, input_src_rank, self.input_parallel_mode)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -437,9 +437,9 @@ class Classifier3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict()

@@ -551,7 +551,7 @@ class VocabParallelClassifier3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -610,9 +610,9 @@ class VocabParallelClassifier3D(ParallelLayer):
             },
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         local_state = OrderedDict({weight_key: self.weight})

@@ -763,7 +763,7 @@ class PatchEmbedding3D(ParallelLayer):
             self.cls_token.register_hook(self._sync_grad_hook)
             self.pos_embed.register_hook(self._sync_grad_hook)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'

@@ -812,9 +812,9 @@ class PatchEmbedding3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
         cls_token_key = prefix + 'cls_token'

@@ -937,7 +937,7 @@ class Embedding3D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -961,9 +961,9 @@ class Embedding3D(ParallelLayer):
         # broadcast in weight groups
         local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})

@@ -991,7 +991,7 @@ class Embedding3D(ParallelLayer):


 @LAYERS.register_module
-class VocabParallelEmbedding3D(torch.nn.Module):
+class VocabParallelEmbedding3D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.

     Args:

@@ -1070,7 +1070,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:

@@ -1104,9 +1104,9 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             partition_states={weight_key: True},
         )

-        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+        super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         local_state = OrderedDict({weight_key: self.weight})
@@ -3,9 +3,9 @@ from itertools import chain

 import torch
 import torch.distributed as dist
 from colossalai.communication.collective import scatter_object_list
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.constants import IS_TENSOR_PARALLEL
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:

@@ -190,6 +190,15 @@ def save_checkpoint(file,
     torch.save(checkpoint, file, **kwargs)


+def broadcast_model(model: torch.nn.Module):
+    src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
+    for p in model.parameters():
+        if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
+            group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
+                ParallelMode.TENSOR)
+            dist.broadcast(p, src_rank, group=group)
+
+
 def load_checkpoint(
     file,
     model: torch.nn.Module,
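Note: broadcast_model only touches parameters that are not tagged with IS_TENSOR_PARALLEL. The sharded weights were already partitioned correctly on every rank by the layer-level loading above, while replicated parameters received their real values only on the first rank of the tensor-parallel group. A hedged sketch of that distinction, assuming an initialized group and a model built from the parallel layers:

    from colossalai.constants import IS_TENSOR_PARALLEL

    for name, p in model.named_parameters():
        if getattr(p, IS_TENSOR_PARALLEL, False):
            print(name, '-> sharded, kept as loaded on this rank')
        else:
            print(name, '-> replicated, overwritten by the rank-0 copy in broadcast_model')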
@@ -225,6 +234,7 @@
     model_state = partition_pipeline_parallel_state_dict(model, model_state)
     try:
         model.load_state_dict(model_state, strict=strict)
+        broadcast_model(model)
     except RuntimeError as e:
         error_msgs = str(e)
         if error_msgs.startswith("Error(s) in loading state_dict for "):
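Putting the two halves together, the intended loading flow looks roughly like this; the file name and model object are hypothetical, the module path is assumed, and only the load_checkpoint arguments visible in this diff are used:

    from colossalai.utils.checkpointing import load_checkpoint   # module path assumed

    load_checkpoint('model_global.pt', model)
    # 1. rank 0 of the tensor-parallel group reads the full checkpoint
    # 2. model.load_state_dict(...) partitions/broadcasts the sharded tensors layer by layer
    # 3. broadcast_model(model) then syncs the remaining replicated parameters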