mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-21 02:40:48 +00:00
remove useless platform args and comment
This commit is contained in:
parent
a9b5ec8664
commit
3942218618
@ -12,37 +12,6 @@ class CPUAdamLoader(BaseKernelLoader):
|
|||||||
Usage:
|
Usage:
|
||||||
# init
|
# init
|
||||||
cpu_adam = CPUAdamLoader().load()
|
cpu_adam = CPUAdamLoader().load()
|
||||||
cpu_adam_op = cpu_adam.CPUAdamOptimizer(
|
|
||||||
alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
|
|
||||||
)
|
|
||||||
...
|
|
||||||
# optim step
|
|
||||||
cpu_adam_op.step(
|
|
||||||
step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
|
|
||||||
params, grads, exp_avg, exp_avg_sq, loss_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func CPUAdamOptimizer:
|
|
||||||
alpha (float): learning rate. Default to 1e-3.
|
|
||||||
beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
|
|
||||||
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
|
|
||||||
epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
|
|
||||||
weight_decay (float): weight decay (L2 penalty). Default to 0.
|
|
||||||
adamw_mode (bool): whether to use the adamw. Default to True.
|
|
||||||
func step:
|
|
||||||
step (int): current step.
|
|
||||||
lr (float): learning rate.
|
|
||||||
beta1 (float): coefficients used for computing running averages of gradient.
|
|
||||||
beta2 (float): coefficients used for computing running averages of its square.
|
|
||||||
epsilon (float): term added to the denominator to improve numerical stability.
|
|
||||||
weight_decay (float): weight decay (L2 penalty).
|
|
||||||
bias_correction (bool): whether to use bias correction.
|
|
||||||
params (torch.Tensor): parameter.
|
|
||||||
grads (torch.Tensor): gradient.
|
|
||||||
exp_avg (torch.Tensor): exp average.
|
|
||||||
exp_avg_sq (torch.Tensor): exp average square.
|
|
||||||
loss_scale (float): loss scale value.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -57,7 +26,7 @@ class CPUAdamLoader(BaseKernelLoader):
|
|||||||
def fetch_kernel(self):
|
def fetch_kernel(self):
|
||||||
if platform.machine() == "x86_64":
|
if platform.machine() == "x86_64":
|
||||||
kernel = self._extension_map["x86"]().fetch()
|
kernel = self._extension_map["x86"]().fetch()
|
||||||
elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
|
elif platform.machine() == "aarch64":
|
||||||
kernel = self._extension_map["arm"]().fetch()
|
kernel = self._extension_map["arm"]().fetch()
|
||||||
else:
|
else:
|
||||||
raise Exception("not supported")
|
raise Exception("not supported")
|
||||||
|
Loading…
Reference in New Issue
Block a user