diff --git a/colossalai/kernel/extensions/cpu_adam/arm_extension.py b/colossalai/kernel/extensions/cpu_adam/arm_extension.py index 9868059bf..0b552f436 100644 --- a/colossalai/kernel/extensions/cpu_adam/arm_extension.py +++ b/colossalai/kernel/extensions/cpu_adam/arm_extension.py @@ -6,7 +6,7 @@ class ArmCPUAdamExtension(BaseExtension): def __init__(self) -> None: super().__init__() self.kernel_builder = ArmCPUAdamBuilder() - self._requires_build = False + self._requires_build = True @property def requires_build(self) -> bool: @@ -14,7 +14,7 @@ class ArmCPUAdamExtension(BaseExtension): def build(self): self.kernel_builder.build() - self._requires_build = True + self._requires_build = False def load(self): return self.kernel_builder.load() diff --git a/colossalai/kernel/extensions/cpu_adam/x86_extension.py b/colossalai/kernel/extensions/cpu_adam/x86_extension.py index 687c91f35..a5b64bed4 100644 --- a/colossalai/kernel/extensions/cpu_adam/x86_extension.py +++ b/colossalai/kernel/extensions/cpu_adam/x86_extension.py @@ -7,7 +7,7 @@ class X86CPUAdamExtension(BaseExtension): def __init__(self) -> None: super().__init__() self.kernel_builder = X86CPUAdamBuilder() - self._requires_build = False + self._requires_build = True @property def requires_build(self) -> bool: @@ -15,7 +15,7 @@ class X86CPUAdamExtension(BaseExtension): def build(self): self.kernel_builder.build() - self._requires_build = True + self._requires_build = False def load(self): return self.kernel_builder.load()