mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 16:28:10 +00:00

[infer] Infer/llama demo (#4503)

* add
* add infer example
* finish
* finish
* stash
* fix

This commit is contained in:
parent d20dceb9a3
commit c427366024
@@ -19,6 +19,7 @@ class LlamaPipelineForwards:
         under pipeline setting.
     '''
 
+    @staticmethod
     def llama_model_forward(
         self: LlamaModel,
         input_ids: torch.LongTensor = None,
@@ -169,6 +170,7 @@ class LlamaPipelineForwards:
         # always return dict for imediate stage
         return {'hidden_states': hidden_states}
 
+    @staticmethod
     def llama_for_causal_lm_forward(
         self: LlamaForCausalLM,
         input_ids: torch.LongTensor = None,
@@ -276,6 +278,7 @@ class LlamaPipelineForwards:
         hidden_states = outputs.get('hidden_states')
         return {'hidden_states': hidden_states}
 
+    @staticmethod
     def llama_for_sequence_classification_forward(
         self: LlamaForSequenceClassification,
         input_ids: torch.LongTensor = None,
@@ -388,6 +391,84 @@ class LlamaPipelineForwards:
         return {'hidden_states': hidden_states}
 
 
+class LlamaInferenceForwards:
+    """
+    This class holds forwards for llama inference.
+    """
+
+    @staticmethod
+    def llama_model_forward(
+        self: LlamaModel,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[
+            torch.LongTensor] = None,    # TODO: this can also be removed if we got sin,cos cached in inferinfo
+        past_key_values: Optional[List[
+            torch.FloatTensor]] = None,    #TODO: maybe removed after memory cache manager is done.
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+        inferinfo=None,
+    ):
+        # only keep the basic items
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_key_values_length,
+                                        seq_length + past_key_values_length,
+                                        dtype=torch.long,
+                                        device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length_with_past),
+                                        dtype=torch.bool,
+                                        device=inputs_embeds.device)
+        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
+                                                              past_key_values_length)
+
+        hidden_states = inputs_embeds
+
+        for idx, decoder_layer in enumerate(self.layers):
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+            )
+
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.norm(hidden_states)
+
+        if not return_dict:
+            return hidden_states
+        return BaseModelOutputWithPast(last_hidden_state=hidden_states,)
+
+
 def get_llama_flash_attention_forward():
 
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
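For reference, the new inference forward can be exercised on its own, outside of ShardFormer. The sketch below is not part of the commit; it assumes the class lives in colossalai.shardformer.modeling.llama (as the relative import in the policy diff further down suggests), a transformers release contemporary with this commit (one where LlamaModel still defines _prepare_decoder_attention_mask), and a deliberately tiny, hypothetical LlamaConfig.

import torch
from transformers import LlamaConfig, LlamaModel

from colossalai.shardformer.modeling.llama import LlamaInferenceForwards    # module path assumed

# tiny config purely for illustration
config = LlamaConfig(vocab_size=1000, hidden_size=128, intermediate_size=256,
                     num_attention_heads=4, num_hidden_layers=2)
model = LlamaModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    # the forward is a plain function that takes the module itself as `self`
    out = LlamaInferenceForwards.llama_model_forward(model, input_ids=input_ids, return_dict=True)
print(out.last_hidden_state.shape)    # expected: torch.Size([1, 16, 128])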
@@ -1,5 +1,6 @@
 import importlib
 from dataclasses import dataclass
+from typing import Optional
 
 import torch.nn as nn
 
@@ -130,6 +131,12 @@ _POLICY_LIST = {
         PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
 }
 
+_INFER_POLICY_LIST = {
+    # LlaMa
+    "transformers.models.llama.modeling_llama.LlamaModel":
+        PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy")
+}
+
 
 def import_policy(policy_location: PolicyLocation) -> Policy:
     """
@@ -151,7 +158,7 @@ def _fullname(obj):
     return module + '.' + klass.__qualname__
 
 
-def get_autopolicy(model: nn.Module) -> Policy:
+def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -162,7 +169,10 @@ def get_autopolicy(model: nn.Module) -> Policy:
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    policy_location = _POLICY_LIST.get(full_name, None)
+    if inference_only:
+        policy_location = _INFER_POLICY_LIST.get(full_name, None)
+    else:
+        policy_location = _POLICY_LIST.get(full_name, None)
 
     if policy_location is None:
         raise NotImplementedError(
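A short sketch of the new lookup path (not from the commit). The module path matches the colossalai.shardformer.policies.auto_policy import used by the test utilities later in this diff; the tiny LlamaConfig is hypothetical.

from transformers import LlamaConfig, LlamaModel

from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = LlamaModel(LlamaConfig(vocab_size=1000, hidden_size=128, intermediate_size=256,
                               num_attention_heads=4, num_hidden_layers=2))

train_policy = get_autopolicy(model)                        # resolved through _POLICY_LIST
infer_policy = get_autopolicy(model, inference_only=True)   # resolved through _INFER_POLICY_LIST
print(type(infer_policy).__name__)                          # expected: LlamaModelInferPolicy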
@@ -7,7 +7,7 @@ from torch.nn import Module
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
-from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
+from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -263,3 +263,21 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         """No shared params in llama for sequence classification model"""
         return []
+
+
+class LlamaModelInferPolicy(LlamaPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+        policy = super().module_policy()
+        # configure default shard config for inference
+        self.shard_config._infer()
+
+        infer_forward = LlamaInferenceForwards.llama_model_forward
+        method_replacement = {'forward': partial(infer_forward)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+        return policy
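The method_replacement dict above asks the sharder to swap LlamaModel.forward for LlamaInferenceForwards.llama_model_forward. The snippet below is only an illustration of what that substitution amounts to, not ColossalAI's actual replacement code; it assumes the sharder ends up binding the registered callable to the module instance, and it reuses the assumed colossalai.shardformer.modeling.llama import path and a hypothetical tiny config.

from functools import partial
from types import MethodType

from transformers import LlamaConfig, LlamaModel

from colossalai.shardformer.modeling.llama import LlamaInferenceForwards    # module path assumed

model = LlamaModel(LlamaConfig(vocab_size=1000, hidden_size=128, intermediate_size=256,
                               num_attention_heads=4, num_hidden_layers=2))

replacement = partial(LlamaInferenceForwards.llama_model_forward)   # what the policy registers under 'forward'
model.forward = MethodType(replacement, model)                      # bound to the instance; model(...) now hits the inference forward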
@@ -28,6 +28,7 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
+    inference_only: bool = False
 
     # pipeline_parallel_size: int
     # data_parallel_size: int
@@ -57,3 +58,9 @@ class ShardConfig:
             self.enable_fused_normalization = True
             self.enable_flash_attention = True
             self.enable_jit_fused = True
+
+    def _infer(self):
+        """
+        Set default params for inference.
+        """
+        self.pipeline_stage_manager = None
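A minimal sketch of the new config surface (not from the commit). The flags mirror those used by tests/test_infer/_utils.py further down; passing enable_tensor_parallelism=False is assumed to keep the config constructible without an initialized process group.

from colossalai.shardformer import ShardConfig

config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
config._infer()    # normally called by LlamaModelInferPolicy.module_policy(); clears the pipeline stage manager
assert config.pipeline_stage_manager is None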
@@ -27,7 +27,8 @@ class ModelSharder(object):
 
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
-        self.policy = get_autopolicy(self.model) if policy is None else policy
+        self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
+        print(self.policy)
         self.shard_config = shard_config
 
     def shard(self) -> List[Dict[int, Tensor]]:
tests/test_infer/_utils.py    (new file, 53 lines added)
@@ -0,0 +1,53 @@
+import copy
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch import distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module
+from torch.optim import Adam, Optimizer
+
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.auto_policy import Policy
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+
+
+def build_model(
+    model_fn,
+    enable_fused_normalization=False,
+    enable_tensor_parallelism=False,
+    enable_flash_attention=False,
+    enable_jit_fused=False,
+):
+    # create new model
+    org_model = model_fn()
+
+    # shard model
+    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
+                               enable_tensor_parallelism=enable_tensor_parallelism,
+                               enable_flash_attention=enable_flash_attention,
+                               enable_jit_fused=enable_jit_fused,
+                               inference_only=True)
+    model_copy = copy.deepcopy(org_model)
+    shard_former = ShardFormer(shard_config=shard_config)
+    sharded_model, shared_params = shard_former.optimize(model_copy)
+    return org_model.cuda(), sharded_model.cuda()
+
+
+def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
+    # prepare input
+    data = data_gen_fn()
+    data = {k: v.cuda() for k, v in data.items()}
+    # run forward
+    org_output = original_model(**data)
+    org_output = output_transform_fn(org_output)
+
+    shard_output = sharded_model(**data)
+    shard_output = output_transform_fn(shard_output)
+
+    return org_output, shard_output
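A hedged sketch of driving these helpers without the model-zoo registry used by the test below. It assumes a CUDA device and a process group already initialized via colossalai.launch (the test arranges this through spawn), since ShardFormer is expected to require one; the tiny LlamaConfig is hypothetical.

import torch
from transformers import LlamaConfig, LlamaModel

from tests.test_infer._utils import build_model, run_infer


def model_fn():
    return LlamaModel(LlamaConfig(vocab_size=1000, hidden_size=128, intermediate_size=256,
                                  num_attention_heads=4, num_hidden_layers=2))


def data_gen_fn():
    return {'input_ids': torch.randint(0, 1000, (1, 16))}


org_model, sharded_model = build_model(model_fn)    # inference_only=True is hard-coded inside build_model
org_out, infer_out = run_infer(org_model, sharded_model, data_gen_fn, lambda x: x)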
tests/test_infer/test_llama_infer.py    (new file, 55 lines added)
@@ -0,0 +1,55 @@
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_infer._utils import build_model, run_infer
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+
+
+def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config):
+    org_model, sharded_model = build_model(model_fn, **test_config)
+
+    org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn)
+
+    print('original output', org_output[0])
+    print('infer output', infer_output[0])
+
+
+@parameterize('test_config', [{
+    'enable_flash_attention': False,
+}])
+def run_llama_test(test_config):
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != "transformers_llama":
+            continue
+        check_infer(model_fn, data_gen_fn, output_transform_fn, test_config)
+    torch.cuda.empty_cache()
+
+
+def check_llama(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_llama_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama():
+    spawn(check_llama, 1)
+
+
+if __name__ == "__main__":
+    test_llama()