* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [example] add test ci
* [example] update readme
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)


class Grok1Policy(Policy):
    """ShardFormer policy that shards Grok-1 attention, MoE experts, and embeddings for tensor parallelism."""

    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        if self.shard_config.enable_tensor_parallelism:
            # The vocabulary must split evenly across ranks for VocabParallelEmbedding1D.
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            # Per-rank attention sizes after the heads are split across tensor-parallel ranks.
            decoder_attribute_replacement = {
                "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
            # Shard the q/k/v projections column-wise and the output projection row-wise.
            decoder_submodule_replacement = [
                SubModuleReplacementDescription(
                    suffix="attn.q_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.k_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.v_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.o_proj",
                    target_module=Linear1D_Row,
                ),
            ]
            # Shard every MoE expert: input projections column-wise, output projection row-wise.
            for i in range(self.model.config.num_experts):
                decoder_submodule_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_v",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_1",
                            target_module=Linear1D_Row,
                        ),
                    ]
                )

            policy["DecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=decoder_submodule_replacement,
            )
            # Shard the token embedding along the vocabulary dimension.
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key="Grok1Model",
            )
        return policy

    def postprocess(self):
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    pass


class Grok1ForCausalLMPolicy(Grok1Policy):
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = super().module_policy()
        # Column-parallel lm_head; gather the output unless parallel_output keeps logits sharded.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=Linear1D_Col,
                kwargs={"gather_output": not self.shard_config.parallel_output},
            ),
            policy=policy,
            target_key="Grok1ModelForCausalLM",
        )
        return policy
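For context, here is a minimal usage sketch (not part of this file) showing how the policy above can drive tensor-parallel Grok-1 inference. It assumes the `hpcai-tech/grok-1` checkpoint with remote modeling code, and a ColossalAI release whose `HybridParallelPlugin` accepts `custom_policy` and `parallel_output`; adapt the launch call and arguments to your version.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from transformers import AutoModelForCausalLM

colossalai.launch_from_torch()  # argument list varies across ColossalAI versions
coordinator = DistCoordinator()

# Assumption: the Hugging Face repo id and dtype follow the accompanying grok-1 example.
model = AutoModelForCausalLM.from_pretrained(
    "hpcai-tech/grok-1", trust_remote_code=True, torch_dtype=torch.bfloat16
)

plugin = HybridParallelPlugin(
    tp_size=coordinator.world_size,  # shard the model across all ranks
    pp_size=1,                       # tensor parallelism only, no pipeline stages
    precision="bf16",
    parallel_output=False,           # gather lm_head outputs on every rank for generation
    custom_policy=Grok1ForCausalLMPolicy(),
)
booster = Booster(plugin=plugin)
model, *_ = booster.boost(model)  # submodules are replaced per the policy above

Launched with torchrun across the tensor-parallel ranks, booster.boost rewrites the q/k/v/o projections, the expert MLPs, the embedding, and lm_head in place, after which model.generate(...) can be called as usual.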