[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,4 +1,4 @@
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager
__all__ = ['MemoryManager', 'TPInferEngine']
__all__ = ["MemoryManager", "TPInferEngine"]

View File

@@ -1,6 +1,5 @@
# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Any
import torch
@@ -31,7 +30,7 @@ class BatchInferState:
decode_mem_index: torch.Tensor = None
decode_layer_id: int = None
device: torch.device = torch.device('cuda')
device: torch.device = torch.device("cuda")
@property
def total_token_num(self):
@@ -43,13 +42,15 @@ class BatchInferState:
self.cache_manager = manager
@staticmethod
def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
alloc_mem_index: torch.Tensor):
""" in-place update block loc mapping based on the sequence length of the inputs in current bath"""
def init_block_loc(
b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
):
"""in-place update block loc mapping based on the sequence length of the inputs in current bath"""
start_index = 0
seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy):
b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
cur_seq_len]
b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
start_index : start_index + cur_seq_len
]
start_index += cur_seq_len
return

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, List, Optional, Union
import torch
import torch.nn as nn
@@ -15,7 +15,7 @@ from .kvcache_manager import MemoryManager
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
class TPInferEngine:
@@ -39,14 +39,16 @@ class TPInferEngine:
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
def __init__(self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:
def __init__(
self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = "cuda",
) -> None:
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
@@ -63,7 +65,7 @@ class TPInferEngine:
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
self.shard_config = shard_config
@@ -74,9 +76,10 @@ class TPInferEngine:
def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)
self.head_num //= self.tp_size # update sharded number of heads
self.cache_manager = MemoryManager(
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)
def _optimize_model(self, model: nn.Module) -> None:
"""
@@ -90,7 +93,7 @@ class TPInferEngine:
self._shard_model_by(shardformer, model)
def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig.
"""Prepare the engine with a given ShardConfig.
Args:
shard_config (ShardConfig): shard config given to specify settings of the engine.
@@ -118,9 +121,10 @@ class TPInferEngine:
return shard_config
def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
""" Shard original model by the given ShardFormer and store the sharded model. """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
"""Shard original model by the given ShardFormer and store the sharded model."""
assert (
self.tp_size == shardformer.shard_config.tensor_parallel_size
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
@@ -147,7 +151,7 @@ class TPInferEngine:
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda()
if 'max_new_tokens' not in generate_kwargs:
if "max_new_tokens" not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
@@ -176,18 +180,18 @@ class TPInferEngine:
attention_mask = None
if isinstance(inputs, (BatchEncoding, dict)):
input_ids_list = inputs['input_ids']
attention_mask = inputs['attention_mask']
input_ids_list = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
else:
input_ids_list = inputs
if isinstance(input_ids_list[0], int): # for a single input
if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
@@ -210,10 +214,10 @@ class TPInferEngine:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda')
batch_infer_state.start_loc = seq_start_indexes.to('cuda')
batch_infer_state.seq_len = seq_lengths.to("cuda")
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
@@ -248,7 +252,7 @@ class TPInferEngine:
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
setattr(model, 'infer_state', batch_infer_state)
setattr(model, "infer_state", batch_infer_state)
outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
@@ -262,14 +266,15 @@ class TPInferEngine:
# as an arg into model.forward.
# It requires rewriting model generate and replacing model forward.
@torch.no_grad()
def _generate_by_pass_infer_state(self,
input_tokens,
max_out_length: int,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def _generate_by_pass_infer_state(
self,
input_tokens,
max_out_length: int,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
raise NotImplementedError("generate by passing BatchInferState is not implemented.")
# might want to use in rewritten generate method: use after model.forward

View File

@@ -19,13 +19,15 @@ class MemoryManager:
device: device used to store the key and value cache
"""
def __init__(self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device('cuda')):
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device("cuda"),
):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
@@ -33,13 +35,13 @@ class MemoryManager:
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device):
""" Initialize tensors used to manage memory states """
"""Initialize tensors used to manage memory states"""
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
""" Initialize key buffer and value buffer on specified device """
"""Initialize key buffer and value buffer on specified device"""
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
@@ -49,10 +51,9 @@ class MemoryManager:
@torch.no_grad()
def alloc(self, required_size):
""" allocate space of required_size by providing indexes representing available physical spaces """
"""allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} "
f"left_size {self.available_size}")
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
@@ -63,23 +64,25 @@ class MemoryManager:
@torch.no_grad()
def alloc_contiguous(self, required_size):
""" allocate contiguous space of required_size """
"""allocate contiguous space of required_size"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} "
f"left_size {self.available_size}")
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
1] + self.mem_state[0:sum_size -
required_size + 1]
can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
loc_sums = (
self.mem_cum_sum[required_size - 1 :]
- self.mem_cum_sum[0 : sum_size - required_size + 1]
+ self.mem_state[0 : sum_size - required_size + 1]
)
can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
self.logger.info(f"No enough contiguous cache: required_size {required_size} "
f"left_size {self.available_size}")
self.logger.info(
f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
)
return None
start_loc = can_used_loc[0]
select_index = self.indexes[start_loc:start_loc + required_size]
select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
@@ -88,13 +91,13 @@ class MemoryManager:
@torch.no_grad()
def free(self, free_index):
""" free memory by updating memory states based on given indexes """
"""free memory by updating memory states based on given indexes"""
self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1
@torch.no_grad()
def free_all(self):
""" free all memory by updating memory states """
"""free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.past_key_values_length = 0

View File

@@ -1,4 +1,4 @@
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards
__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']
__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]

View File

@@ -1,6 +1,6 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16):
"""
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
def get_slopes(n):
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2)
slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2]
slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
return slopes_combined
slopes = get_slopes(n_head)
@@ -72,7 +72,6 @@ class BloomInferenceForwards:
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
@@ -86,8 +85,9 @@ class BloomInferenceForwards:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -122,14 +122,15 @@ class BloomInferenceForwards:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# NOTE determine if BatchInferState is passed in via arg
# if not, get the attr binded to the model
# We might wantto remove setattr later
if infer_state is None:
assert hasattr(self, 'infer_state')
assert hasattr(self, "infer_state")
infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation
@@ -146,10 +147,11 @@ class BloomInferenceForwards:
if use_cache and seq_length != 1:
# prefill stage
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
infer_state.context_mem_index)
BatchInferState.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -182,8 +184,11 @@ class BloomInferenceForwards:
# alibi = generate_alibi(self.num_heads).contiguous().cuda()
tp_size = dist.get_world_size()
curr_tp_rank = dist.get_rank()
alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
self.num_heads].cuda()
alibi = (
generate_alibi(self.num_heads * tp_size)
.contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
.cuda()
)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
@@ -197,7 +202,6 @@ class BloomInferenceForwards:
if self.gradient_checkpointing and self.training:
# NOTE: currently our KV cache manager does not handle this condition
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
@@ -250,32 +254,34 @@ class BloomInferenceForwards:
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents, # should always be (None, None, ..., None)
past_key_values=presents, # should always be (None, None, ..., None)
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@staticmethod
def bloom_for_causal_lm_forward(self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments):
def bloom_for_causal_lm_forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
@@ -289,17 +295,19 @@ class BloomInferenceForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state)
transformer_outputs = BloomInferenceForwards.bloom_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -314,8 +322,9 @@ class BloomInferenceForwards:
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -353,11 +362,13 @@ class BloomInferenceForwards:
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update({
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
})
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
@@ -416,7 +427,7 @@ class BloomInferenceForwards:
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
return outputs # hidden_states, present, attentions
@staticmethod
def bloom_attention_forward(
@@ -431,20 +442,19 @@ class BloomInferenceForwards:
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1
if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1
if infer_state.is_context_stage:
# context process
@@ -471,9 +481,11 @@ class BloomInferenceForwards:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(k)
cache_v.copy_(v)
else:
@@ -486,8 +498,17 @@ class BloomInferenceForwards:
b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len
output = torch.empty_like(q)
token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
token_attention_fwd(
q,
mem_manager.key_buffer[layer_id],
mem_manager.value_buffer[layer_id],
output,
b_loc,
b_start_loc,
b_seq_len,
infer_state.cache_manager.past_key_values_length,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
@@ -504,8 +525,8 @@ class BloomInferenceForwards:
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)

View File

@@ -1,6 +1,5 @@
from typing import List, Optional, Tuple
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
@@ -15,6 +14,7 @@ from colossalai.kernel.triton import (
try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
@@ -29,17 +29,17 @@ except:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -71,8 +71,7 @@ class LlamaInferenceForwards:
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size = input_ids.shape[0] # input_ids.shape[0]
batch_size = input_ids.shape[0] # input_ids.shape[0]
infer_state = self.infer_state
@@ -103,10 +102,11 @@ class LlamaInferenceForwards:
if use_cache and seq_length != 1:
# NOTE assuem prefill stage
# allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
infer_state.context_mem_index)
infer_state.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -129,20 +129,20 @@ class LlamaInferenceForwards:
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if infer_state.is_context_stage:
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1)
position_ids.view(-1).shape[0], -1
)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1)
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
@@ -153,12 +153,13 @@ class LlamaInferenceForwards:
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
@@ -216,7 +217,6 @@ class LlamaInferenceForwards:
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -261,7 +261,6 @@ class LlamaInferenceForwards:
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert use_cache is True, "use_cache should be set to True using this llama attention"
bsz, q_len, _ = hidden_states.size()
@@ -277,8 +276,8 @@ class LlamaInferenceForwards:
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
@@ -299,38 +298,62 @@ class LlamaInferenceForwards:
# first token generation
# copy key and value calculated in current step to memory manager
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
infer_state.cache_manager)
_copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
infer_state.decode_mem_index, infer_state.cache_manager)
_copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -341,7 +364,6 @@ class LlamaInferenceForwards:
def get_llama_vllm_rmsnorm_forward():
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):

View File

@@ -1,4 +1,4 @@
from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy
__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']
__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"]

View File

@@ -9,6 +9,7 @@ from ..modeling.bloom import BloomInferenceForwards
try:
from colossalai.kernel.triton import layer_norm
HAS_TRITON_NORM = True
except:
print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
@@ -27,40 +28,40 @@ def get_triton_layernorm_forward():
class BloomModelInferPolicy(BloomForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
# NOTE set inference mode to shard config
self.shard_config._infer()
method_replacement = {
'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
"forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
"prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomForCausalLM)
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomForCausalLM
)
method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomAttention)
method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomAttention
)
if HAS_TRITON_NORM:
infer_method = get_triton_layernorm_forward()
method_replacement = {'forward': partial(infer_method)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LayerNorm)
method_replacement = {"forward": partial(infer_method)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LayerNorm
)
return policy

View File

@@ -10,6 +10,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw
try:
from colossalai.kernel.triton import rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -28,7 +29,6 @@ def get_triton_rmsnorm_forward():
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
@@ -37,20 +37,20 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaDecoderLayer)
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaAttention)
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = None
if HAS_TRITON_RMSNORM:
@@ -60,9 +60,9 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaRMSNorm)
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
)
return policy