[bug] fix: logger somehow hangs the program

botbw 2024-07-23 06:17:51 +00:00
parent e31d2ebcf7
commit 91f84f6a5f
2 changed files with 0 additions and 17 deletions
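
This change drops all `get_dist_logger` usage from `HybridParallelPlugin` and `MoeHybridParallelPlugin`, since the distributed logger's rank-filtered `info(..., ranks=[0])` calls hang the program. Below is a minimal sketch (not part of this commit) of a plain rank-guarded print that avoids the distributed logger entirely; `log_rank0` is a hypothetical helper name and only relies on standard `torch.distributed` calls.

```python
# Sketch only, not part of this commit: a plain rank-0 guard that avoids the
# distributed logger. `log_rank0` is a hypothetical helper name.
import torch.distributed as dist


def log_rank0(msg: str) -> None:
    # Print only on global rank 0; fall back to printing when torch.distributed
    # is not available or not yet initialized (e.g. single-process runs).
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(msg, flush=True)
```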

View File

@@ -27,7 +27,6 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1020,8 +1019,6 @@ class HybridParallelPlugin(PipelinePluginBase):
     ) -> None:
         super().__init__()
-        self.logger = get_dist_logger(type(self).__name__)
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -1070,10 +1067,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-        self.logger.info(
-            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
-        )
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1123,10 +1116,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
-        self.logger.info(
-            f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
-            ranks=[0],
-        )
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,

View File

@@ -226,12 +226,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
             )
-        self.logger.info(
-            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
-            ranks=[0],
-        )
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
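
If the removed group-layout logging is still wanted, it can be reproduced from user code after the plugin is constructed, using the group attributes that appear in the diff above. A hedged usage sketch follows; it assumes the distributed backend has already been initialized (e.g. via `colossalai.launch_from_torch`) and the `tp_size`/`pp_size` values are placeholders only.

```python
# Usage sketch (assumptions: process groups already initialized, e.g. via
# colossalai.launch_from_torch; tp_size/pp_size values are placeholders).
import torch.distributed as dist
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
if dist.get_rank() == 0:
    # Same information the removed self.logger.info call printed, emitted from
    # the training script instead of inside the plugin constructor.
    print(
        f"dp_group {dist.get_process_group_ranks(plugin.dp_group)} "
        f"pp_group {dist.get_process_group_ranks(plugin.pp_group)} "
        f"tp_group {dist.get_process_group_ranks(plugin.tp_group)} "
        f"sp_group {dist.get_process_group_ranks(plugin.sp_group)}"
    )
```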