From dcc44aab8d3ab640bfd9d740a147ad39cb21f18c Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 20 Aug 2024 10:32:41 +0800 Subject: [PATCH] [misc] Use dist logger in plugins (#6011) * use dist logger in plugins * remove trash * print on rank 0 --------- Co-authored-by: Edenzzzz --- colossalai/booster/booster.py | 18 ++++++--- colossalai/booster/plugin/gemini_plugin.py | 26 ++++++++----- .../booster/plugin/hybrid_parallel_plugin.py | 35 ++++++++++------- .../booster/plugin/low_level_zero_plugin.py | 39 +++++++++---------- .../plugin/moe_hybrid_parallel_plugin.py | 24 +++++++----- colossalai/booster/plugin/torch_ddp_plugin.py | 2 + .../booster/plugin/torch_fsdp_plugin.py | 15 +++---- colossalai/zero/gemini/gemini_optimizer.py | 12 +++--- 8 files changed, 101 insertions(+), 70 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 56d8a0935..8047d90f7 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -1,4 +1,3 @@ -import warnings from contextlib import contextmanager from typing import Any, Callable, Dict, Iterator, List, Optional, Union @@ -8,6 +7,8 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from colossalai.logging import get_dist_logger + SUPPORT_PEFT = False try: import peft @@ -81,12 +82,15 @@ class Booster: plugin, Plugin ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}." self.plugin = plugin + self.logger = get_dist_logger() # set accelerator if self.plugin and self.plugin.control_device(): self.accelerator = None if device is not None: - warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.") + self.logger.warning( + "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0] + ) else: device = device or "cuda" self.accelerator = Accelerator(device) @@ -94,7 +98,10 @@ class Booster: # set precision if self.plugin and self.plugin.control_precision(): if mixed_precision is not None: - warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.") + self.logger.warning( + "The plugin will control the precision," "so the mixed_precision argument will be ignored.", + ranks=[0], + ) self.mixed_precision = None elif mixed_precision is None: self.mixed_precision = None @@ -267,8 +274,9 @@ class Booster: ), "Please provide pretrained directory path if not passing in lora configuration." if quantize is True: if bnb_quantization_config is not None: - warnings.warn( - "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk." + self.logger.warning( + "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.", + ranks=[0], ) else: bnb_quantization_config = BnbQuantizationConfig( diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ad131fbe7..3754cfe60 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,5 +1,4 @@ import gc -import logging import os import random from pathlib import Path @@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import ( ) from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ @@ -118,7 +119,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ assert isinstance(model, GeminiDDP), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): - logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0]) return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -143,10 +144,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model.unwrap(), checkpoint_path) - logging.info( + self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_model( @@ -168,7 +170,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -201,10 +203,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): @@ -214,7 +217,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" if not os.path.isfile(checkpoint_index_file): - logging.error(f"Provided path ({checkpoint_index_file}) should be a file") + self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0]) assert isinstance(optimizer, GeminiOptimizer) @@ -369,9 +372,12 @@ class GeminiPlugin(DPPluginBase): assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" + + self.logger = get_dist_logger() if enable_async_reduce and not pin_memory: - logging.warning( - f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." + self.logger.warning( + f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.", + ranks=[0], ) pin_memory = True self.gemini_config = dict( diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 63427192f..b4b40020f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,6 +1,5 @@ import ctypes import random -import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy @@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager @@ -1023,6 +1023,7 @@ class HybridParallelPlugin(PipelinePluginBase): inner_ring_size: int = None, ) -> None: super().__init__() + self.logger = get_dist_logger() assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -1040,8 +1041,9 @@ class HybridParallelPlugin(PipelinePluginBase): tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -1126,7 +1128,12 @@ class HybridParallelPlugin(PipelinePluginBase): else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": - assert parallel_output, "Ring Attention doesn't support gathering output yet." + if not parallel_output: + self.logger.warning( + "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.", + ranks=[0], + ) + parallel_output = True self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) @@ -1231,7 +1238,10 @@ class HybridParallelPlugin(PipelinePluginBase): optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], + ) zero_config["partition_grad"] = False zero_stage = 0 @@ -1287,9 +1297,10 @@ class HybridParallelPlugin(PipelinePluginBase): else: is_zero = self.dp_size > 1 if self.dp_size == 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." @@ -1332,7 +1343,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" if return_outputs: - warnings.warn("return_outputs may lead to significant extra memory consumption.") + self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0]) # Create a context for gradient synchronization based on the optimizer type. # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). @@ -1346,10 +1357,8 @@ class HybridParallelPlugin(PipelinePluginBase): ) # run with gradients accumulation - if ( - model.require_grad_sync == False - or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) - or not torch.is_grad_enabled() + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False ): return outputs @@ -1449,7 +1458,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." assert self.pp_size == 1 and self.tp_size == 1 self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index e4c386a22..cc3755e6f 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,7 +1,5 @@ import enum -import logging import os -import warnings from contextlib import nullcontext from functools import partial from pathlib import Path @@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import ( ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.tensor.colo_parameter import ColoParameter @@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__( - self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True - ) -> None: + def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -76,7 +73,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None and cast_inputs: + if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -140,7 +137,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): """ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -177,10 +174,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): index_file.append_meta_data("total_size", total_size) if self.coordinator.is_master(): index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): @@ -267,7 +265,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return from peft import PeftModel @@ -336,7 +334,6 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, - cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -363,8 +360,7 @@ class LowLevelZeroPlugin(DPPluginBase): ) self.lora_enabled = False self.verbose = verbose - self.cast_inputs = cast_inputs - + self.logger = get_dist_logger() # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -400,7 +396,7 @@ class LowLevelZeroPlugin(DPPluginBase): assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) @@ -449,8 +445,9 @@ class LowLevelZeroPlugin(DPPluginBase): origin_param = name2param[origin_key] group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: - warnings.warn( - f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + self.logger.warning( + f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.", + ranks=[0], ) elif ( check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED @@ -478,10 +475,7 @@ class LowLevelZeroPlugin(DPPluginBase): if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, - self.precision, - overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], - cast_inputs=self.cast_inputs, + model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] ) # TODO: Support Galore + ZeRO @@ -493,7 +487,10 @@ class LowLevelZeroPlugin(DPPluginBase): optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], + ) zero_optim_kwargs["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b3415af0e..36973b240 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,4 +1,3 @@ -import warnings from collections import defaultdict from types import MethodType from typing import Callable, List, Optional, OrderedDict, Tuple @@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import cast_to_distributed from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule @@ -215,12 +215,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_p2p: bool = True, overlap_allgather: bool = False, ) -> None: + self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: overlap_communication = False zero_stage = 1 - warnings.warn( + self.logger.warning( f"overlap_communication and zero_stage are set to False and 1 because " - f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " + f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.", + ranks=[0], ) assert ( @@ -238,8 +240,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}," + "will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -400,8 +404,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): and self.sequence_parallelism_mode == "all_to_all" ) if use_ddp: - warnings.warn( - f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + self.logger.warning( + f"Will have to check all params are used in pytorch DDP since not all experts are always activated", + ranks=[0], ) self.ddp_config["find_unused_parameters"] = True @@ -457,9 +462,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) else: if self.dp_size <= 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = MoeHybridParallelZeroOptimizer( diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 5116446a4..8a807970c 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.utils import get_current_device @@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): """ diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index cd2f9e840..7b67da032 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,6 +1,4 @@ -import logging import os -import warnings from pathlib import Path from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple @@ -30,6 +28,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from .dp_plugin_base import DPPluginBase @@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" @@ -88,7 +88,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): - logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -117,7 +117,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) utils.save_config_file(model.unwrap(), checkpoint_path) - logging.info( + self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." @@ -162,7 +162,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -200,7 +200,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." @@ -311,6 +311,7 @@ class TorchFSDPPlugin(DPPluginBase): param_init_fn=param_init_fn, sync_module_states=sync_module_states, ) + self.logger = get_dist_logger() else: raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") @@ -349,7 +350,7 @@ class TorchFSDPPlugin(DPPluginBase): if optimizer is not None: if len(optimizer.param_groups) > 1: - warnings.warn( + self.logger.warning( "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." ) optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 1d755c417..fdf2a4976 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,7 +1,6 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import copy import math -import warnings from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union import torch @@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper): self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose self.param_groups_backup = list() - + self.logger = get_dist_logger() # Mapping from integer id to real/fake param tensor, used for checkpointing. self.id_to_real_params: Dict[int, Parameter] = dict() self.id_to_fake_params: Dict[int, Parameter] = dict() @@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper): for name, param in module.named_parameters(): if is_ddp_ignored(param): if param.requires_grad: - warnings.warn( + self.logger.warning( f"Parameter `{name}` is ignored by DDP but requires gradient! " - "You should handle its optimizer update by yourself!" + "You should handle its optimizer update by yourself!", + ranks=[0], ) else: ddp_param_list.append(param) @@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper): *args, **kwargs, ) -> torch.Tensor: - warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm") + self.logger.warning( + f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0] + ) class GeminiAdamOptimizer(GeminiOptimizer):