mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 09:51:18 +00:00
remove useless platform args and comment
This commit is contained in:
parent
a9b5ec8664
commit
3942218618
@ -12,37 +12,6 @@ class CPUAdamLoader(BaseKernelLoader):
|
||||
Usage:
|
||||
# init
|
||||
cpu_adam = CPUAdamLoader().load()
|
||||
cpu_adam_op = cpu_adam.CPUAdamOptimizer(
|
||||
alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
|
||||
)
|
||||
...
|
||||
# optim step
|
||||
cpu_adam_op.step(
|
||||
step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
|
||||
params, grads, exp_avg, exp_avg_sq, loss_scale,
|
||||
)
|
||||
|
||||
Args:
|
||||
func CPUAdamOptimizer:
|
||||
alpha (float): learning rate. Default to 1e-3.
|
||||
beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
|
||||
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
|
||||
epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
|
||||
weight_decay (float): weight decay (L2 penalty). Default to 0.
|
||||
adamw_mode (bool): whether to use the adamw. Default to True.
|
||||
func step:
|
||||
step (int): current step.
|
||||
lr (float): learning rate.
|
||||
beta1 (float): coefficients used for computing running averages of gradient.
|
||||
beta2 (float): coefficients used for computing running averages of its square.
|
||||
epsilon (float): term added to the denominator to improve numerical stability.
|
||||
weight_decay (float): weight decay (L2 penalty).
|
||||
bias_correction (bool): whether to use bias correction.
|
||||
params (torch.Tensor): parameter.
|
||||
grads (torch.Tensor): gradient.
|
||||
exp_avg (torch.Tensor): exp average.
|
||||
exp_avg_sq (torch.Tensor): exp average square.
|
||||
loss_scale (float): loss scale value.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@ -57,7 +26,7 @@ class CPUAdamLoader(BaseKernelLoader):
|
||||
def fetch_kernel(self):
|
||||
if platform.machine() == "x86_64":
|
||||
kernel = self._extension_map["x86"]().fetch()
|
||||
elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
|
||||
elif platform.machine() == "aarch64":
|
||||
kernel = self._extension_map["arm"]().fetch()
|
||||
else:
|
||||
raise Exception("not supported")
|
||||
|
Loading…
Reference in New Issue
Block a user