Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-20 20:54:55 +00:00
[inference] decouple pipeline logic for bloom (#5097)
This commit is contained in: parent afe3c78d9a, commit 67a07e6f64
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 from torch.nn import functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.bloom.modeling_bloom import (
     BaseModelOutputWithPastAndCrossAttentions,
     BloomAttention,
@@ -86,6 +87,7 @@ class BloomInferenceForwards:
         **deprecated_arguments,
     ):
         r"""
+        This function is only used when pipeline is enabled.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
@@ -153,6 +155,7 @@ class BloomInferenceForwards:
         tp_group: Optional[dist.ProcessGroup] = None,
         **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+        infer_state = infer_state or getattr(self, "infer_state", None)
         logger = logging.get_logger(__name__)

         # add warnings here
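The line added in this hunk lets the same forward serve both engines: an explicitly passed infer_state wins, otherwise the forward falls back to an infer_state attribute previously attached to the module. A minimal sketch of that fallback idiom, using a hypothetical DummyModel that is not a ColossalAI API:

class DummyModel:
    def __init__(self):
        self.infer_state = None  # an engine may attach state here before calling forward

    def forward(self, x, infer_state=None):
        # prefer the explicit argument, else fall back to the attached attribute
        infer_state = infer_state or getattr(self, "infer_state", None)
        if infer_state is None:
            raise ValueError("no inference state passed or attached")
        return x

model = DummyModel()
model.infer_state = {"seq_len": 1}  # engine-style attachment
model.forward("tokens")             # resolves the attached state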
@@ -183,7 +186,7 @@ class BloomInferenceForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         # first stage
-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             # check inputs and inputs embeds
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
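This guard is the core of the decoupling: stage_manager may now be None, in which case the process holds the whole model and performs both first-stage and last-stage work. A sketch of the pattern, assuming only the stage_manager interface visible in this diff (is_first_stage, is_last_stage, num_stages):

def does_first_stage_work(stage_manager) -> bool:
    # None means no pipeline: this process embeds the inputs itself
    return stage_manager is None or stage_manager.is_first_stage()

def does_last_stage_work(stage_manager) -> bool:
    # None, the final stage, or a single-stage pipeline all finalize here
    return stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1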
@@ -255,10 +258,14 @@ class BloomInferenceForwards:

         infer_state.decode_layer_id = 0

+        if stage_index is None:
+            stage_index = (0, len(self.h))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if past_key_values is None:
             past_key_values = tuple([None] * (end_idx - start_idx + 1))

+        # for HF api compatibility, kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
         for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
             block = self.h[idx]
             outputs = block(
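With the pipeline decoupled, stage_index may also be absent; defaulting it to (0, len(self.h)) makes a standalone run iterate over every BloomBlock, while pipeline stages keep passing their own (start, end) slice. A sketch of how that slice selects layers, with blocks as a stand-in for self.h:

blocks = [f"block_{i}" for i in range(4)]  # pretend 4-layer model

def resolve_stage_index(stage_index, num_layers):
    # None means no pipeline: one process runs all layers
    return stage_index if stage_index is not None else (0, num_layers)

start_idx, end_idx = resolve_stage_index(None, len(blocks))
assert [blocks[i] for i in range(start_idx, end_idx)] == blocks  # all layers
start_idx, end_idx = resolve_stage_index((2, 4), len(blocks))    # a later stage
assert [blocks[i] for i in range(start_idx, end_idx)] == ["block_2", "block_3"]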
@@ -274,8 +281,10 @@ class BloomInferenceForwards:

             infer_state.decode_layer_id += 1
             hidden_states = outputs[0]
+            if use_cache:
+                next_decoder_cache += (outputs[2 if output_attentions else 1],)

-        if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+        if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
             hidden_states = self.ln_f(hidden_states)

             # update indices
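The new next_decoder_cache lines restore the Hugging Face convention of returning the kv-cache: each block's present key/value entry is appended per layer, and its position in the block's output tuple shifts by one when attention weights are also requested. A toy sketch of that accumulation, where fake_block_outputs is a made-up stand-in for a real BloomBlock return value:

output_attentions = False
use_cache = True

next_decoder_cache = () if use_cache else None
for layer_id in range(2):
    # pretend each block returns (hidden_states, present_kv); with
    # output_attentions=True the kv slot would move to index 2
    fake_block_outputs = (f"hidden_{layer_id}", f"kv_{layer_id}")
    if use_cache:
        next_decoder_cache += (fake_block_outputs[2 if output_attentions else 1],)

assert next_decoder_cache == ("kv_0", "kv_1")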
@@ -283,6 +292,12 @@ class BloomInferenceForwards:
         infer_state.seq_len += 1
         infer_state.max_len_in_batch += 1

+        next_cache = next_decoder_cache if use_cache else None
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         # always return dict for intermediate stage
         return {"hidden_states": hidden_states}
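Taken together, the return path now has two regimes: with no stage_manager the forward mimics plain Hugging Face BloomModel output (a tuple, or a BaseModelOutputWithPast when return_dict is set), while pipeline stages keep returning the plain dict that the schedule ships to the next stage. A condensed sketch of that convention, with hidden_states and next_cache as placeholders for the real tensors:

from transformers.modeling_outputs import BaseModelOutputWithPast

def finish_forward(hidden_states, next_cache, stage_manager, return_dict):
    if stage_manager is None:  # standalone: match the HF BloomModel API
        if not return_dict:
            return (hidden_states, next_cache)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
    # pipeline: stages hand a dict to the schedule for the next stage
    return {"hidden_states": hidden_states}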