mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-11-04 07:58:42 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			120 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			120 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch.distributed as dist
 | 
						|
 | 
						|
from colossalai.registry import DIST_GROUP_INITIALIZER
 | 
						|
from colossalai.global_variables import moe_env
 | 
						|
from .process_group_initializer import ProcessGroupInitializer
 | 
						|
from ..parallel_mode import ParallelMode
 | 
						|
 | 
						|
 | 
						|
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moemodel(ProcessGroupInitializer):
    """Process-group initializer for the model-parallel dimension of MoE.

    Ranks are split into ``moe_data`` consecutive blocks of ``moe_model``
    ranks each; every block becomes one model-parallel group.

    :param moe_model: Size of moe model parallel
    :param moe_data: Size of moe data parallel
    :param args: Positional arguments forwarded to the base class
    :param kwargs: Keyword arguments forwarded to the base class

    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model, self.moe_data = moe_model, moe_data

    def init_dist_group(self):
        """Create every MoE model-parallel group and pick out the one that
        contains ``self.rank``.

        :return: MoE model parallelism's information; the first four entries
            stay ``None`` when ``self.rank`` belongs to none of the groups
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        mode = ParallelMode.MOE_MODEL
        # (local_rank, group_world_size, process_group, ranks_in_group)
        own_group_info = (None, None, None, None)

        for block_idx in range(self.moe_data):
            first_rank = block_idx * self.moe_model
            members = list(range(first_rank, first_rank + self.moe_model))
            # The group is created on every iteration, whether or not this
            # process is one of its members.
            handle = dist.new_group(members)

            if self.rank in members:
                own_group_info = (members.index(self.rank), len(members), handle, members)

        local_rank, group_world_size, process_group, ranks_in_group = own_group_info
        return local_rank, group_world_size, process_group, ranks_in_group, mode
 | 
						|
 | 
						|
 | 
						|
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moedata(ProcessGroupInitializer):
    """Process-group initializer for the data-parallel dimension of MoE.

    For each offset ``i`` in ``range(moe_model)`` one group is formed from the
    ``moe_data`` ranks ``i, i + moe_model, i + 2 * moe_model, ...``.

    :param moe_model: Size of moe model parallel
    :param moe_data: Size of moe data parallel
    :param args: Positional arguments forwarded to the base class
    :param kwargs: Keyword arguments forwarded to the base class

    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model, self.moe_data = moe_model, moe_data

    def init_dist_group(self):
        """Create every MoE data-parallel group and pick out the one that
        contains ``self.rank``.

        :return: MoE data parallelism's information; the first four entries
            stay ``None`` when ``self.rank`` belongs to none of the groups
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        mode = ParallelMode.MOE_DATA
        # (local_rank, group_world_size, process_group, ranks_in_group)
        own_group_info = (None, None, None, None)

        stride = self.moe_model
        for offset in range(stride):
            # Same offset inside each model-parallel block: offset, offset + stride, ...
            members = list(range(offset, offset + self.moe_data * stride, stride))
            # The group is created on every iteration, whether or not this
            # process is one of its members.
            handle = dist.new_group(members)

            if self.rank in members:
                own_group_info = (members.index(self.rank), len(members), handle, members)

        local_rank, group_world_size, process_group, ranks_in_group = own_group_info
        return local_rank, group_world_size, process_group, ranks_in_group, mode
 | 
						|
 | 
						|
 | 
						|
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moe(ProcessGroupInitializer):
    """Single entry point for MoE parallel initialization.

    Reads the model- and data-parallel sizes from ``moe_env`` and delegates
    group creation to ``Initializer_Moemodel`` and ``Initializer_Moedata``.

    :param args: Args used to initialize ProcessGroupInitializer
    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_env.model_parallel_size
        self.moe_data = moe_env.data_parallel_size

        # Both sub-initializers receive the same sizes and the same base-class arguments.
        sizes = (self.moe_model, self.moe_data)
        self.model_initializer = Initializer_Moemodel(*sizes, *args, **kwargs)
        self.data_initializer = Initializer_Moedata(*sizes, *args, **kwargs)

    def init_dist_group(self):
        """Initialize the MoE model-parallel groups, then the data-parallel groups.

        :return: MoE parallelism's information, model-parallel entry first
        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        model_setting = self.model_initializer.init_dist_group()
        data_setting = self.data_initializer.init_dist_group()
        return [model_setting, data_setting]
 |