mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 12:14:02 +00:00
refactored grad scaler (#338)
This commit is contained in:
16
colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
Normal file
16
colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
from .base_grad_scaler import BaseGradScaler
|
||||
|
||||
__all__ = ['ConstantGradScaler']
|
||||
|
||||
|
||||
class ConstantGradScaler(BaseGradScaler):
|
||||
|
||||
def __init__(self, initial_scale: int, verbose: bool):
|
||||
super().__init__(initial_scale, verbose)
|
||||
self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])
|
||||
|
||||
def update(self, overflow: bool) -> None:
|
||||
# do nothing to maintain the current scale value
|
||||
pass
|
Reference in New Issue
Block a user