mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
Layer integration (#83)
* integrated parallel layers for ease of building models * integrated 2.5d layers * cleaned codes and unit tests * added log metric by step hook; updated imagenet benchmark; fixed some bugs * reworked initialization; cleaned codes Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -1,33 +1,140 @@
|
||||
import math
|
||||
import warnings
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import init as init
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
|
||||
if init_method == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
elif init_method == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(tensor, -a, a)
|
||||
elif init_method == 'jax_embed':
|
||||
def zeros_():
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.zeros_(tensor)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def ones_():
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.ones_(tensor)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def uniform_(a: float = 0., b: float = 1.):
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.uniform_(tensor, a, b)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def normal_(mean: float = 0., std: float = 1.):
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.normal_(tensor, mean, std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
if 0 in tensor.shape:
|
||||
warnings.warn("Initializing zero-element tensors is a no-op")
|
||||
return tensor
|
||||
|
||||
if mode == 'fan_in':
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
fan = fan_in
|
||||
elif mode == 'fan_out':
|
||||
assert fan_out is not None, 'Fan_out is not provided.'
|
||||
fan = fan_out
|
||||
else:
|
||||
raise ValueError(f'Invalid initialization mode \'{mode}\'')
|
||||
|
||||
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
|
||||
bound = math.sqrt(3.) * std
|
||||
return nn.init.uniform_(tensor, -bound, bound)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
if 0 in tensor.shape:
|
||||
warnings.warn("Initializing zero-element tensors is a no-op")
|
||||
return tensor
|
||||
|
||||
if mode == 'fan_in':
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
fan = fan_in
|
||||
elif mode == 'fan_out':
|
||||
assert fan_out is not None, 'Fan_out is not provided.'
|
||||
fan = fan_out
|
||||
else:
|
||||
raise ValueError(f'Invalid initialization mode \'{mode}\'')
|
||||
|
||||
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
|
||||
return nn.init.normal_(tensor, 0, std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
fan = fan_in
|
||||
if fan_out is not None:
|
||||
fan += fan_out
|
||||
|
||||
std = gain * math.sqrt(scale / float(fan))
|
||||
bound = a * std
|
||||
return nn.init.uniform_(tensor, -bound, bound)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def xavier_normal_(scale: float = 2., gain: float = 1.):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
fan = fan_in
|
||||
if fan_out is not None:
|
||||
fan += fan_out
|
||||
|
||||
std = gain * math.sqrt(scale / float(fan))
|
||||
|
||||
return nn.init.normal_(tensor, 0., std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def lecun_uniform_():
|
||||
# adapted from jax.nn.initializers
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
var = 1.0 / fan_in
|
||||
bound = math.sqrt(3 * var)
|
||||
return nn.init.uniform_(tensor, -bound, bound)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def lecun_normal_():
|
||||
# adapted from jax.nn.initializers
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
init.trunc_normal_(tensor, std=std / .87962566103423978)
|
||||
elif init_method == 'zero':
|
||||
init.zeros_(tensor)
|
||||
return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
|
||||
|
||||
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
|
||||
if init_method == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
elif init_method == 'jax':
|
||||
init.normal_(tensor, std=1e-6)
|
||||
elif init_method == 'jax_embed':
|
||||
init.trunc_normal_(tensor, std=.02)
|
||||
elif init_method == 'zero':
|
||||
init.zeros_(tensor)
|
||||
return initializer
|
||||
|
Reference in New Issue
Block a user