diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f1d2998bb..57afce70e 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -19,6 +19,7 @@ class LlamaPipelineForwards:
     under pipeline setting.
     '''
 
+    @staticmethod
     def llama_model_forward(
         self: LlamaModel,
         input_ids: torch.LongTensor = None,
@@ -169,6 +170,7 @@ class LlamaPipelineForwards:
         # always return dict for imediate stage
         return {'hidden_states': hidden_states}
 
+    @staticmethod
     def llama_for_causal_lm_forward(
         self: LlamaForCausalLM,
         input_ids: torch.LongTensor = None,
@@ -276,6 +278,7 @@ class LlamaPipelineForwards:
             hidden_states = outputs.get('hidden_states')
             return {'hidden_states': hidden_states}
 
+    @staticmethod
     def llama_for_sequence_classification_forward(
         self: LlamaForSequenceClassification,
         input_ids: torch.LongTensor = None,
@@ -388,6 +391,84 @@ class LlamaPipelineForwards:
             return {'hidden_states': hidden_states}
 
 
+class LlamaInferenceForwards:
+    """
+    This class holds forwards for llama inference.
+    """
+
+    @staticmethod
+    def llama_model_forward(
+        self: LlamaModel,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[
+            torch.LongTensor] = None,    # TODO: this can also be removed if we got sin,cos cached in inferinfo
+        past_key_values: Optional[List[
+            torch.FloatTensor]] = None,    #TODO: maybe removed after memory cache manager is done.
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+        inferinfo=None,
+    ):
+        # only keep the basic items
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_key_values_length,
+                                        seq_length + past_key_values_length,
+                                        dtype=torch.long,
+                                        device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length_with_past),
+                                        dtype=torch.bool,
+                                        device=inputs_embeds.device)
+        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
+                                                              past_key_values_length)
+
+        hidden_states = inputs_embeds
+
+        for idx, decoder_layer in enumerate(self.layers):
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+            )
+
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.norm(hidden_states)
+
+        if not return_dict:
+            return hidden_states
+        return BaseModelOutputWithPast(last_hidden_state=hidden_states,)
+
+
 def get_llama_flash_attention_forward():
 
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index eec339c02..43ea1c5ab 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -1,5 +1,6 @@
 import importlib
 from dataclasses import dataclass
+from typing import Optional
 
 import torch.nn as nn
 
@@ -130,6 +131,12 @@ _POLICY_LIST = {
         PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
 }
 
+_INFER_POLICY_LIST = {
+    # LlaMa
+    "transformers.models.llama.modeling_llama.LlamaModel":
+        PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy")
+}
+
 
 def import_policy(policy_location: PolicyLocation) -> Policy:
     """
@@ -151,7 +158,7 @@ def _fullname(obj):
     return module + '.' + klass.__qualname__
 
 
-def get_autopolicy(model: nn.Module) -> Policy:
+def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -162,7 +169,10 @@ def get_autopolicy(model: nn.Module) -> Policy:
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    policy_location = _POLICY_LIST.get(full_name, None)
+    if inference_only:
+        policy_location = _INFER_POLICY_LIST.get(full_name, None)
+    else:
+        policy_location = _POLICY_LIST.get(full_name, None)
 
     if policy_location is None:
         raise NotImplementedError(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5ee95f3be..0de2752cf 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -7,7 +7,7 @@ from torch.nn import Module
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
-from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
+from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -263,3 +263,21 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         """No shared params in llama for sequence classification model"""
         return []
+
+
+class LlamaModelInferPolicy(LlamaPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+        policy = super().module_policy()
+        # configure default shard config for inference
+        self.shard_config._infer()
+
+        infer_forward = LlamaInferenceForwards.llama_model_forward
+        method_replacement = {'forward': partial(infer_forward)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+        return policy
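
With the pieces above in place, get_autopolicy resolves a plain LlamaModel to LlamaModelInferPolicy whenever inference_only=True is passed. Below is a minimal sketch of that lookup, runnable without a GPU or a distributed launch; the tiny LlamaConfig values are placeholders chosen only for illustration:

    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    from colossalai.shardformer.policies.auto_policy import get_autopolicy

    # a deliberately tiny model so the lookup can be exercised quickly
    model = LlamaModel(LlamaConfig(hidden_size=128, intermediate_size=256,
                                   num_hidden_layers=2, num_attention_heads=4))

    # inference_only=True switches the lookup from _POLICY_LIST to _INFER_POLICY_LIST,
    # so LlamaModel resolves to the LlamaModelInferPolicy registered above
    policy = get_autopolicy(model, inference_only=True)
    print(policy)    # expected: a LlamaModelInferPolicy instance
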
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 0c28f115d..35b526456 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -28,6 +28,7 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
+    inference_only: bool = False
 
     # pipeline_parallel_size: int
     # data_parallel_size: int
@@ -57,3 +58,9 @@
         self.enable_fused_normalization = True
         self.enable_flash_attention = True
         self.enable_jit_fused = True
+
+    def _infer(self):
+        """
+        Set default params for inference.
+        """
+        self.pipeline_stage_manager = None
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 0ed745a1f..39704ae5e 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -27,7 +27,8 @@ class ModelSharder(object):
 
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
-        self.policy = get_autopolicy(self.model) if policy is None else policy
+        self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
+        print(self.policy)
         self.shard_config = shard_config
 
     def shard(self) -> List[Dict[int, Tensor]]:
diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py
new file mode 100644
index 000000000..3d56cc348
--- /dev/null
+++ b/tests/test_infer/_utils.py
@@ -0,0 +1,53 @@
+import copy
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch import distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module
+from torch.optim import Adam, Optimizer
+
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.auto_policy import Policy
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+
+
+def build_model(
+    model_fn,
+    enable_fused_normalization=False,
+    enable_tensor_parallelism=False,
+    enable_flash_attention=False,
+    enable_jit_fused=False,
+):
+    # create new model
+    org_model = model_fn()
+
+    # shard model
+    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
+                               enable_tensor_parallelism=enable_tensor_parallelism,
+                               enable_flash_attention=enable_flash_attention,
+                               enable_jit_fused=enable_jit_fused,
+                               inference_only=True)
+    model_copy = copy.deepcopy(org_model)
+    shard_former = ShardFormer(shard_config=shard_config)
+    sharded_model, shared_params = shard_former.optimize(model_copy)
+    return org_model.cuda(), sharded_model.cuda()
+
+
+def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
+    # prepare input
+    data = data_gen_fn()
+    data = {k: v.cuda() for k, v in data.items()}
+    # run forward
+    org_output = original_model(**data)
+    org_output = output_transform_fn(org_output)
+
+    shard_output = sharded_model(**data)
+    shard_output = output_transform_fn(shard_output)
+
+    return org_output, shard_output
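
build_model and run_infer expect a model factory, a data factory, and an output transform; in the test below these come from the model zoo, but a hand-rolled pair would look roughly like this sketch (the config values and shapes are illustrative, and a CUDA device plus a colossalai.launch-initialized context are assumed, as in the test):

    import torch
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    from tests.test_infer._utils import build_model, run_infer

    def model_fn():
        # tiny placeholder config, just large enough to run a forward pass
        return LlamaModel(LlamaConfig(vocab_size=1000, hidden_size=128, intermediate_size=256,
                                      num_hidden_layers=2, num_attention_heads=4))

    def data_gen_fn():
        return {'input_ids': torch.randint(0, 1000, (1, 16)),
                'attention_mask': torch.ones(1, 16, dtype=torch.long)}

    org_model, sharded_model = build_model(model_fn)
    org_out, shard_out = run_infer(org_model, sharded_model, data_gen_fn, lambda output: output)
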
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
new file mode 100644
index 000000000..09a81ef7f
--- /dev/null
+++ b/tests/test_infer/test_llama_infer.py
@@ -0,0 +1,55 @@
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_infer._utils import build_model, run_infer
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+
+
+def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
+    org_model, sharded_model = build_model(model_fn, **test_config)
+
+    org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)
+
+    print('original output', org_output[0])
+    print('infer output', infer_output[0])
+
+
+@parameterize('test_config', [{
+    'enable_flash_attention': False,
+}])
+def run_llama_test(test_config):
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != "transformers_llama":
+            continue
+        check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
+    torch.cuda.empty_cache()
+
+
+def check_llama(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_llama_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama():
+    spawn(check_llama, 1)
+
+
+if __name__ == "__main__":
+    test_llama()
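
Once a LlamaModel has been sharded with inference_only=True, calling it runs the reduced forward from LlamaInferenceForwards. A minimal prefill-style call against the sharded_model returned by build_model above; inferinfo and past_key_values are left as None here because the KV-cache manager is still marked as a TODO in this patch, and the token ids are random placeholders:

    import torch

    input_ids = torch.randint(0, 1000, (1, 16), device='cuda')
    with torch.no_grad():
        # return_dict=False makes the replaced forward return the final hidden states directly;
        # otherwise it returns a BaseModelOutputWithPast carrying only last_hidden_state
        hidden_states = sharded_model(input_ids=input_ids, return_dict=False)
    print(hidden_states.shape)    # (batch_size, seq_length, hidden_size)
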