From 2dd4d556fb296059a9bed19061fd9e0027bfc3ea Mon Sep 17 00:00:00 2001
From: Ofey Chan
Date: Wed, 13 Jul 2022 10:51:55 +0800
Subject: [PATCH] [NFC] polish colossalai/nn/init.py code style (#1292)

---
 colossalai/nn/init.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py
index 1a1164290..559b7038f 100644
--- a/colossalai/nn/init.py
+++ b/colossalai/nn/init.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 
 def zeros_():
     """Return the initializer filling the input Tensor with the scalar zeros"""
+
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         return nn.init.zeros_(tensor)
 
@@ -15,6 +16,7 @@ def zeros_():
 
 def ones_():
     """Return the initializer filling the input Tensor with the scalar ones"""
+
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         return nn.init.ones_(tensor)
 
@@ -46,6 +48,7 @@ def normal_(mean: float = 0., std: float = 1.):
         mean (float): the mean of the normal distribution. Defaults 0.0.
         std (float): the standard deviation of the normal distribution. Defaults 1.0.
     """
+
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         return nn.init.normal_(tensor, mean, std)
 
@@ -66,6 +69,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float =
         a (float): the minimum cutoff value. Defaults -2.0.
         b (float): the maximum cutoff value. Defaults 2.0.
     """
+
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         return nn.init.trunc_normal_(tensor, mean, std, a, b)
 
@@ -93,6 +97,7 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
         nonlinearity (str, optional): the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
     """
+
     # adapted from torch.nn.init
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         if 0 in tensor.shape:
@@ -136,6 +141,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
         nonlinearity (str, optional): the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
     """
+
     # adapted from torch.nn.init
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         if 0 in tensor.shape:
@@ -175,6 +181,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1
         scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
         gain (float, optional): an optional scaling factor. Defaults 1.0.
     """
+
     # adapted from torch.nn.init
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         assert fan_in is not None, 'Fan_in is not provided.'
@@ -206,6 +213,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.):
         scale (float, optional): an optional scaling factor used to calculate standard deviation. Defaults 2.0.
         gain (float, optional): an optional scaling factor. Defaults 1.0.
     """
+
     # adapted from torch.nn.init
     def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
         assert fan_in is not None, 'Fan_in is not provided.'
@@ -241,4 +249,4 @@ def lecun_normal_():
         std = math.sqrt(1.0 / fan_in)
         return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
 
-    return initializer
\ No newline at end of file
+    return initializer
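
The functions touched by this patch are initializer factories: each call such as kaiming_uniform_() returns a closure initializer(tensor, fan_in=None, fan_out=None) that fills the given tensor in place. Below is a minimal usage sketch, assuming the module is importable as colossalai.nn.init (matching the file path in the diff) and that the returned closure is called directly with the tensor and its fan dimensions; the layer sizes are illustrative only.

    import torch
    from colossalai.nn.init import kaiming_uniform_, lecun_normal_

    # Assumed shape convention for this sketch: (out_features, in_features).
    weight = torch.empty(256, 128)

    # The factory captures the hyper-parameters; the returned closure performs the in-place fill.
    init_fn = kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu')
    init_fn(weight, fan_in=128, fan_out=256)

    # lecun_normal_ needs only fan_in: std = sqrt(1.0 / fan_in), drawn from a truncated normal.
    lecun_normal_()(weight, fan_in=128)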