update some modules with the new API version

Author: FoolPlayer
Date: 2023-08-01 18:02:49 +08:00
Committed by: Hongxin Liu
Parent: 879301d0da
Commit: 726541afe2
7 changed files with 88 additions and 48 deletions


@@ -537,10 +537,11 @@ class FusedLinear1D_Col(ParallelModule):
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
                  n_fused: int = 3,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
 
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
@@ -554,36 +555,52 @@ class FusedLinear1D_Col(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
+
         # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
 
         def shard_fn(tensor):
             return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
 
         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
 
-        with torch.no_grad():
-            sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
-        self.weight = customized_distributed_tensor_to_param(sharded_weight)
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
 
         if bias:
-            bias = torch.empty(self.out_features, **factory_kwargs)
-            with torch.no_grad():
-                sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
-            self.bias = customized_distributed_tensor_to_param(sharded_bias)
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+            if not is_customized_distributed_tensor(self.bias):
+                with torch.no_grad():
+                    sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
+                customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
         else:
             self.bias = None
 
-        # offset the seed with randomizer index and rank
-        seed = torch.random.initial_seed()
-        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
-        # init weights
-        self.reset_parameters(weight_initializer, bias_initializer)
+        if weight is None:
+            # init weights
+            self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
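For reference, a minimal usage sketch of the two construction paths the reworked __init__ supports. Shapes are illustrative; it assumes torch.distributed is already initialized (so process_group=None falls back to the default group) and that FusedLinear1D_Col is importable from colossalai.shardformer.layer.

import torch
from torch.nn import Parameter
from colossalai.shardformer.layer import FusedLinear1D_Col  # import path assumed

hidden = 768  # illustrative hidden size; the fused qkv output is 3 * hidden

# Path 1: no weight passed. The layer allocates its own Parameter, shards it
# across the process group, and calls reset_parameters() at the end of __init__.
fresh = FusedLinear1D_Col(in_features=hidden, out_features=3 * hidden,
                          n_fused=3, process_group=None)

# Path 2: an existing weight (and bias_, required whenever bias is True, per the
# sanity check above) is handed in. The layer moves both to the requested
# device/dtype, shards the same Parameter objects in place, and skips
# reset_parameters(), so the original values survive.
w = Parameter(torch.randn(3 * hidden, hidden))
b = Parameter(torch.zeros(3 * hidden))
wrapped = FusedLinear1D_Col(in_features=hidden, out_features=3 * hidden,
                            n_fused=3, process_group=None, weight=w, bias_=b)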
@@ -613,24 +630,26 @@ class FusedLinear1D_Col(ParallelModule):
                                       bias=bias,
                                       device=device,
                                       process_group=process_group,
+                                      weight=module.weight,
+                                      bias_=module.bias,
                                       *args,
                                       **kwargs)
 
-        # TODO: copy the sharded weights
-        with torch.no_grad():
-            sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-                                                           n_fused=n_fused,
-                                                           process_group=process_group,
-                                                           is_transposed=False)
-            linear_1d.weight.data.copy_(sharded_weight.data)
-            if bias:
-                sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-                                                             n_fused=n_fused,
-                                                             process_group=process_group,
-                                                             is_transposed=False)
-                linear_1d.bias.data.copy_(sharded_bias.data)
+        # # TODO: copy the sharded weights
+        # with torch.no_grad():
+        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
+        #                                                    n_fused=n_fused,
+        #                                                    process_group=process_group,
+        #                                                    is_transposed=False)
+        #     linear_1d.weight.data.copy_(sharded_weight.data)
+        #     if bias:
+        #         sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
+        #                                                      n_fused=n_fused,
+        #                                                      process_group=process_group,
+        #                                                      is_transposed=False)
+        #         linear_1d.bias.data.copy_(sharded_bias.data)
         print(linear_1d.weight.shape)
 
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
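Similarly, a hedged sketch of calling from_native_module after this change on a GPT-2-style fused qkv projection. The same assumptions apply (initialized torch.distributed, assumed import path), and the final assert simply reflects the self.weight = weight branch shown above.

import torch.nn as nn
from colossalai.shardformer.layer import FusedLinear1D_Col  # import path assumed

hidden = 768
fused_qkv = nn.Linear(hidden, 3 * hidden)  # fused q/k/v projection, GPT-2 style

# weight=module.weight and bias_=module.bias now go straight into the constructor,
# so the parallel layer shards the original Parameters rather than allocating new
# ones and copying sharded data into them afterwards (the commented-out TODO block).
linear_1d = FusedLinear1D_Col.from_native_module(fused_qkv,
                                                 process_group=None,
                                                 n_fused=3)
assert linear_1d.weight is fused_qkv.weight  # same Parameter object, now sharded in place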