Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 22:52:25 +00:00
Update some modules to the new API version: FusedLinear1D_Col now accepts pre-existing weight/bias_ parameters and shards them in place, and gather_fn uses self.n_fused instead of a hard-coded 3.
@@ -537,10 +537,11 @@ class FusedLinear1D_Col(ParallelModule):
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
                  n_fused: int = 3,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
 
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
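The two new keyword arguments let callers hand the layer an existing fused weight and bias instead of having it allocate fresh ones. A minimal usage sketch, assuming torch.distributed is already initialized (e.g. via colossalai.launch) so a default process group exists; the shapes and the import path are illustrative assumptions, not part of this diff:

import torch
from torch.nn import Parameter
from colossalai.shardformer.layer import FusedLinear1D_Col  # assumed import path

hidden = 64
fused_weight = Parameter(torch.randn(3 * hidden, hidden))  # pre-trained fused QKV weight
fused_bias = Parameter(torch.randn(3 * hidden))

layer = FusedLinear1D_Col(
    in_features=hidden,
    out_features=3 * hidden,
    bias=True,
    n_fused=3,
    weight=fused_weight,  # sharded in place, not copied
    bias_=fused_bias,     # required because bias=True and weight is given
)
# Since weight is not None, the constructor skips reset_parameters()
# (see the `if weight is None` guard later in this diff), so the supplied
# values survive construction.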
@@ -554,36 +555,52 @@ class FusedLinear1D_Col(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
-        # offset the seed with randomizer index and rank
-        seed = torch.random.initial_seed()
-        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
 
         # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
 
         def shard_fn(tensor):
             return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
 
         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
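The change above fixes a latent bug: gather_fn hard-coded 3 where shard_fn used self.n_fused, so any fusion factor other than QKV would gather into a scrambled layout. A self-contained toy sketch of the GPT-2-style layout — not the library's split_fused_qkv_in_gpt2_style / gather_fused_qkv_in_gpt2_style, which also handle transposition and real process groups — showing why both sides must agree on n_fused:

import torch

def toy_split(tensor: torch.Tensor, n_fused: int, world_size: int, rank: int) -> torch.Tensor:
    # Shard each of the n_fused segments separately, keep this rank's slice of each.
    segments = tensor.chunk(n_fused, dim=0)
    return torch.cat([seg.chunk(world_size, dim=0)[rank] for seg in segments], dim=0)

def toy_gather(shards: list, n_fused: int) -> torch.Tensor:
    # Re-interleave: each shard holds n_fused partial segments in order.
    parts = [shard.chunk(n_fused, dim=0) for shard in shards]
    return torch.cat([torch.cat([p[i] for p in parts], dim=0) for i in range(n_fused)], dim=0)

x = torch.arange(12.0)  # stands in for a fused QKV tensor: 3 segments of 4 rows
shards = [toy_split(x, n_fused=3, world_size=2, rank=r) for r in range(2)]
assert torch.equal(toy_gather(shards, n_fused=3), x)      # matching n_fused round-trips
assert not torch.equal(toy_gather(shards, n_fused=2), x)  # a mismatched value scrambles rows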
 
-        with torch.no_grad():
-            sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
-        self.weight = customized_distributed_tensor_to_param(sharded_weight)
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
 
         if bias:
-            bias = torch.empty(self.out_features, **factory_kwargs)
-
-            with torch.no_grad():
-                sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
-            self.bias = customized_distributed_tensor_to_param(sharded_bias)
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+            if not is_customized_distributed_tensor(self.bias):
+                with torch.no_grad():
+                    sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
+                customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
         else:
             self.bias = None
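Both the weight and bias paths now wrap sharding in an is_customized_distributed_tensor check, so a parameter that is already a customized distributed tensor — for example one passed in pre-sharded — is not sharded a second time. A toy stand-in for that guard, using an ad-hoc flag in place of the real predicate:

import torch

def shard_once(param: torch.nn.Parameter, shard_fn) -> torch.nn.Parameter:
    # Stand-in for the is_customized_distributed_tensor check above:
    # only shard parameters that have not been sharded yet.
    if getattr(param, "_is_sharded", False):
        return param
    with torch.no_grad():
        param.data = shard_fn(param.data)
    param._is_sharded = True
    return param

p = torch.nn.Parameter(torch.arange(8.0))
shard_once(p, lambda t: t[:4])  # first call shards: 8 elements -> 4
shard_once(p, lambda t: t[:4])  # second call is a no-op
assert p.data.numel() == 4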
 
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
-        # init weights
-        self.reset_parameters(weight_initializer, bias_initializer)
+        if weight is None:
+            # init weights
+            self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
@@ -613,24 +630,26 @@ class FusedLinear1D_Col(ParallelModule):
                                        bias=bias,
                                        device=device,
                                        process_group=process_group,
+                                       weight=module.weight,
+                                       bias_=module.bias,
                                        *args,
                                        **kwargs)
 
-        # TODO: copy the sharded weights
-        with torch.no_grad():
-            sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-                                                           n_fused=n_fused,
-                                                           process_group=process_group,
-                                                           is_transposed=False)
-            linear_1d.weight.data.copy_(sharded_weight.data)
-
-            if bias:
-                sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-                                                             n_fused=n_fused,
-                                                             process_group=process_group,
-                                                             is_transposed=False)
-                linear_1d.bias.data.copy_(sharded_bias.data)
+        # # TODO: copy the sharded weights
+        # with torch.no_grad():
+        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
+        #                                                    n_fused=n_fused,
+        #                                                    process_group=process_group,
+        #                                                    is_transposed=False)
+        #     linear_1d.weight.data.copy_(sharded_weight.data)
 
+        #     if bias:
+        #         sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
+        #                                                      n_fused=n_fused,
+        #                                                      process_group=process_group,
+        #                                                      is_transposed=False)
+        #         linear_1d.bias.data.copy_(sharded_bias.data)
+        print(linear_1d.weight.shape)
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
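With the constructor now accepting weight and bias_, from_native_module passes module.weight and module.bias straight through and lets __init__ shard them in place; the old post-construction copy is retired (left commented out above, along with a leftover debug print). A conversion sketch under the same assumptions as before, with a plain nn.Linear standing in for a fused QKV projection:

import torch.nn as nn
from colossalai.shardformer.layer import FusedLinear1D_Col  # assumed import path

hidden = 64
fused_proj = nn.Linear(hidden, 3 * hidden, bias=True)  # stand-in for a fused QKV layer

linear_1d = FusedLinear1D_Col.from_native_module(
    fused_proj,
    process_group=None,  # assumed to fall back to the default group, as in the constructor
    n_fused=3,
)
# fused_proj.weight and fused_proj.bias are now sharded in place and shared
# with linear_1d.weight and linear_1d.bias.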