diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index ddb03f947..d872dbbaf 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -43,11 +43,11 @@ optim2DistOptim = { CAME: DistributedCAME, Adafactor: DistributedAdaFactor, } -_logger = get_dist_logger() def cast_to_distributed(optim): if optim.__class__ in optim2DistOptim: + _logger = get_dist_logger() _logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0]) if isinstance(optim, GaLoreAdamW8bit):