From 394221861826b8032b1bea0052f06e792467674d Mon Sep 17 00:00:00 2001 From: wangbinluo <2538539015@qq.com> Date: Thu, 11 Jan 2024 08:21:53 +0000 Subject: [PATCH] remove useless platform args and comment --- colossalai/kernel/cpu_adam_loader.py | 33 +--------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index 0df6bd49b..4763f40ab 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -12,37 +12,6 @@ class CPUAdamLoader(BaseKernelLoader): Usage: # init cpu_adam = CPUAdamLoader().load() - cpu_adam_op = cpu_adam.CPUAdamOptimizer( - alpha, beta1, beta2, epsilon, weight_decay, adamw_mode, - ) - ... - # optim step - cpu_adam_op.step( - step, lr, beta1, beta2, epsilon, weight_decay, bias_correction, - params, grads, exp_avg, exp_avg_sq, loss_scale, - ) - - Args: - func CPUAdamOptimizer: - alpha (float): learning rate. Default to 1e-3. - beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9. - beta2 (float): coefficients used for computing running averages of its square. Default to 0.99. - epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8. - weight_decay (float): weight decay (L2 penalty). Default to 0. - adamw_mode (bool): whether to use the adamw. Default to True. - func step: - step (int): current step. - lr (float): learning rate. - beta1 (float): coefficients used for computing running averages of gradient. - beta2 (float): coefficients used for computing running averages of its square. - epsilon (float): term added to the denominator to improve numerical stability. - weight_decay (float): weight decay (L2 penalty). - bias_correction (bool): whether to use bias correction. - params (torch.Tensor): parameter. - grads (torch.Tensor): gradient. - exp_avg (torch.Tensor): exp average. - exp_avg_sq (torch.Tensor): exp average square. - loss_scale (float): loss scale value. """ def __init__(self): @@ -57,7 +26,7 @@ class CPUAdamLoader(BaseKernelLoader): def fetch_kernel(self): if platform.machine() == "x86_64": kernel = self._extension_map["x86"]().fetch() - elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]: + elif platform.machine() == "aarch64": kernel = self._extension_map["arm"]().fetch() else: raise Exception("not supported")