diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 56d8a0935..8047d90f7 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
@@ -8,6 +7,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.logging import get_dist_logger
+
 SUPPORT_PEFT = False
 try:
     import peft
@@ -81,12 +82,15 @@ class Booster:
             plugin, Plugin
         ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
         self.plugin = plugin
+        self.logger = get_dist_logger()
 
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
             if device is not None:
-                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the accelerator, " "so the device argument will be ignored.", ranks=[0]
+                )
         else:
             device = device or "cuda"
             self.accelerator = Accelerator(device)
@@ -94,7 +98,10 @@ class Booster:
         # set precision
         if self.plugin and self.plugin.control_precision():
             if mixed_precision is not None:
-                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the precision, " "so the mixed_precision argument will be ignored.",
+                    ranks=[0],
+                )
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -267,8 +274,9 @@ class Booster:
         ), "Please provide pretrained directory path if not passing in lora configuration."
         if quantize is True:
             if bnb_quantization_config is not None:
-                warnings.warn(
-                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                self.logger.warning(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+                    ranks=[0],
                 )
             else:
                 bnb_quantization_config = BnbQuantizationConfig(
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index d0d3275cf..6a5d0c161 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import random
 from pathlib import Path
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
@@ -118,7 +119,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_model(
@@ -168,7 +170,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            logging.info(
+            self.logger.info(
                 f"The optimizer is going to be split to checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
-            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+            self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])
 
         assert isinstance(optimizer, GeminiOptimizer)
@@ -371,9 +374,12 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+
+        self.logger = get_dist_logger()
         if enable_async_reduce and not pin_memory:
-            logging.warning(
-                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
+            self.logger.warning(
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
+                ranks=[0],
             )
             pin_memory = True
         self.gemini_config = dict(
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index a92371485..ae1fbc771 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,6 +1,5 @@
 import ctypes
 import random
-import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1036,6 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()
+        self.logger = get_dist_logger()
 
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1053,8 +1054,9 @@ class HybridParallelPlugin(PipelinePluginBase):
                 tp_size > 1
             ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
             if sp_size != 1:
-                warnings.warn(
-                    f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                self.logger.warning(
+                    f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
+                    ranks=[0],
                 )
             self.sp_size = 1
             self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -1143,7 +1145,12 @@ class HybridParallelPlugin(PipelinePluginBase):
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
-            assert parallel_output, "Ring Attention doesn't support gathering output yet."
+            if not parallel_output:
+                self.logger.warning(
+                    "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
+                    ranks=[0],
+                )
+                parallel_output = True
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1249,7 +1256,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_config["partition_grad"] = False
             zero_stage = 0
@@ -1306,9 +1316,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
-                    warnings.warn(
+                    self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+ "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." @@ -1351,7 +1362,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" if return_outputs: - warnings.warn("return_outputs may lead to significant extra memory consumption.") + self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0]) # Create a context for gradient synchronization based on the optimizer type. # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). @@ -1365,10 +1376,8 @@ class HybridParallelPlugin(PipelinePluginBase): ) # run with gradients accumulation - if ( - model.require_grad_sync == False - or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) - or not torch.is_grad_enabled() + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False ): return outputs @@ -1468,7 +1477,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." assert self.pp_size == 1 and self.tp_size == 1 self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 088fa1daa..4188491c2 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,7 +1,5 @@ import enum -import logging import os -import warnings from contextlib import nullcontext from functools import partial from pathlib import Path @@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import ( ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.quantization.fp8_hook import FP8Hook @@ -64,12 +63,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): def __init__( - self, - module: nn.Module, - precision: str, - overlap_allgather: bool = False, - cast_inputs: bool = True, - use_fp8: bool = False, + self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False ) -> None: super().__init__(module) self.dtype = None @@ -87,8 +81,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather self.op_hooks = [] - if self.dtype is not None and cast_inputs: - self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) if overlap_allgather: self.op_hooks.append(ZeroOpHook()) if use_fp8: @@ -153,7 +145,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): """ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before 
saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -190,10 +182,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): index_file.append_meta_data("total_size", total_size) if self.coordinator.is_master(): index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): @@ -280,7 +273,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return from peft import PeftModel @@ -349,7 +342,6 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, - cast_inputs: bool = True, fp8_communication: bool = False, use_fp8: bool = False, ) -> None: @@ -379,9 +371,8 @@ class LowLevelZeroPlugin(DPPluginBase): ) self.lora_enabled = False self.verbose = verbose + self.logger = get_dist_logger() self.use_fp8 = use_fp8 - self.cast_inputs = cast_inputs - # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -417,7 +408,7 @@ class LowLevelZeroPlugin(DPPluginBase): assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) @@ -466,8 +457,9 @@ class LowLevelZeroPlugin(DPPluginBase): origin_param = name2param[origin_key] group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: - warnings.warn( - f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + self.logger.warning( + f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.", + ranks=[0], ) elif ( check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED @@ -498,7 +490,6 @@ class LowLevelZeroPlugin(DPPluginBase): model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], - cast_inputs=self.cast_inputs, use_fp8=self.use_fp8, ) @@ -511,7 +502,10 @@ class LowLevelZeroPlugin(DPPluginBase): optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.", + ranks=[0], + ) zero_optim_kwargs["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 724a39fd2..74d35f5c5 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,4 +1,3 @@ -import warnings from collections import defaultdict from types import MethodType from typing import Callable, List, Optional, OrderedDict, Tuple @@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import cast_to_distributed from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule @@ -217,12 +217,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): fp8_communication: bool = False, use_fp8: bool = False, ) -> None: + self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: overlap_communication = False zero_stage = 1 - warnings.warn( + self.logger.warning( f"overlap_communication and zero_stage are set to False and 1 because " - f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " + f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.", + ranks=[0], ) assert ( @@ -240,8 +242,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}," + "will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -404,8 +408,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): and self.sequence_parallelism_mode == "all_to_all" ) if use_ddp: - warnings.warn( - f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + self.logger.warning( + f"Will have to check all params are used in pytorch DDP since not all experts are always activated", + ranks=[0], ) self.ddp_config["find_unused_parameters"] = True @@ -462,9 +467,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) else: if self.dp_size <= 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
             optimizer = MoeHybridParallelZeroOptimizer(
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index 34caa2f68..61d785a4c 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
 
@@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index e3f81928d..23a35bbcb 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,6 +1,4 @@
-import logging
 import os
-import warnings
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
@@ -30,6 +28,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 
 from .dp_plugin_base import DPPluginBase
 
@@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
@@ -88,7 +88,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -117,7 +117,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             utils.save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
                 f"index located at {save_index_file}."
@@ -162,7 +162,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -200,7 +200,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
@@ -313,6 +313,7 @@ class TorchFSDPPlugin(DPPluginBase):
                 sync_module_states=sync_module_states,
             )
             self.fp8_communication = fp8_communication
+            self.logger = get_dist_logger()
         else:
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
 
@@ -364,7 +365,7 @@ class TorchFSDPPlugin(DPPluginBase):
 
         if optimizer is not None:
             if len(optimizer.param_groups) > 1:
-                warnings.warn(
+                self.logger.warning(
                     "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
                 )
             optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 505729219..b102387a0 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -694,6 +694,13 @@ class RingAttention(torch.autograd.Function):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
@@ -702,12 +709,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
@@ -738,6 +741,9 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i],
                         ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+                        # Pipeline the next KV comm with output correction instead of the next flash attn
+                        # to minimize idle time when comm takes longer than attn.
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
@@ -765,15 +771,13 @@ class RingAttention(torch.autograd.Function):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
+                    if i == 0:
+                        _kv_comm(i)
+
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -782,6 +786,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -796,6 +802,8 @@
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 1d755c417..fdf2a4976 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-import warnings
 from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
 
 import torch
@@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
-
+        self.logger = get_dist_logger()
         # Mapping from integer id to real/fake param tensor, used for checkpointing.
         self.id_to_real_params: Dict[int, Parameter] = dict()
         self.id_to_fake_params: Dict[int, Parameter] = dict()
@@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper):
         for name, param in module.named_parameters():
             if is_ddp_ignored(param):
                 if param.requires_grad:
-                    warnings.warn(
+                    self.logger.warning(
                         f"Parameter `{name}` is ignored by DDP but requires gradient! "
-                        "You should handle its optimizer update by yourself!"
+                        "You should handle its optimizer update by yourself!",
+                        ranks=[0],
                     )
                 else:
                     ddp_param_list.append(param)
@@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper):
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
+        self.logger.warning(
+            f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
+        )
 
 
 class GeminiAdamOptimizer(GeminiOptimizer):
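Note on the logging pattern introduced throughout this patch: `colossalai.logging.get_dist_logger()` returns a distributed logger whose `info`/`warning`/`error` methods take a `ranks` argument, so a message is emitted only by the listed ranks (rank 0 here) instead of once per process as `warnings.warn`/`logging` did. A minimal usage sketch is below; it assumes a recent ColossalAI where `launch_from_torch()` can be called without arguments under `torchrun`, and the message text is illustrative, not taken from the library.

```python
# Minimal sketch of the rank-filtered logging used above (run with torchrun).
import colossalai
from colossalai.logging import get_dist_logger

colossalai.launch_from_torch()  # assumes torchrun env vars; older releases needed launch_from_torch(config={})
logger = get_dist_logger()

# Printed once, by rank 0 only, regardless of world size.
logger.warning("illustrative warning: the plugin will control the precision", ranks=[0])

# Without `ranks`, every process logs its own copy.
logger.info("per-rank message")
```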
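The `_kv_comm` refactor in `attn.py` issues the send/recv for the next ring step right after the current flash-attention kernel records completion, so the transfer overlaps with the LSE-based output correction rather than with the next attention call. Below is a single-GPU sketch of that double-buffered overlap idea; the two-stream ping-pong and the copy standing in for communication are illustrative assumptions, not ColossalAI APIs.

```python
# Double-buffered overlap sketch: launch the "comm" for step i+1 on a side stream
# right after step i's main kernel, so it runs concurrently with the correction math.
import torch

assert torch.cuda.is_available(), "this sketch requires a CUDA device"
comm_stream = torch.cuda.Stream()
kv_buffers = [torch.randn(1024, 1024, device="cuda") for _ in range(2)]  # ping-pong buffers
next_kv = torch.randn(1024, 1024, device="cuda")  # stand-in for the KV block arriving from a peer
out = torch.zeros(1024, 1024, device="cuda")

n_steps = 4
for i in range(n_steps):
    if i > 0:
        # The buffer for this step was filled on the side stream; wait before reading it.
        torch.cuda.current_stream().wait_stream(comm_stream)
    block = kv_buffers[i % 2]
    partial = block @ block.T  # stand-in for the flash-attention call
    if i < n_steps - 1:
        # Earlier reads of the target buffer were issued before this point, so they finish first.
        comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(comm_stream):
            kv_buffers[(i + 1) % 2].copy_(next_kv)  # "receive" next KV while correction runs below
    out += torch.softmax(partial, dim=-1) @ block  # stand-in for the output correction
torch.cuda.synchronize()
print(out.norm().item())
```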