Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 06:52:46 +00:00)
[feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
This commit is contained in:
parent 9ee80fc828
commit 72b507a7be
@@ -267,26 +267,98 @@ class MixtralPipelineForwards:
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # retrieve input_ids and inputs_embeds
-        print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
-        if stage_manager.is_first_stage():
-            # retrieve input_ids and inputs_embeds
-            if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-            elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
-            elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape
-            else:
-                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            if inputs_embeds is None:
-                inputs_embeds = self.embed_tokens(input_ids)
-            hidden_states = inputs_embeds
-        else:
-            input_shape = hidden_states.shape[:-1]
-            batch_size, seq_length = input_shape
-            device = hidden_states.device
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                # zbv
+                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
+                    # retrieve input_ids and inputs_embeds
+                    if input_ids is not None and inputs_embeds is not None:
+                        raise ValueError(
+                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+                        )
+                    elif input_ids is not None:
+                        batch_size, seq_length = input_ids.shape
+                    elif inputs_embeds is not None:
+                        batch_size, seq_length, _ = inputs_embeds.shape
+                    else:
+                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                    device = input_ids.device if input_ids is not None else inputs_embeds.device
+                    if inputs_embeds is None:
+                        inputs_embeds = self.embed_tokens(input_ids)
+                    hidden_states = inputs_embeds
+                else:
+                    input_shape = hidden_states.shape[:-1]
+                    batch_size, seq_length = input_shape
+                    device = hidden_states.device
+            else:
+                # interleaved
+                if stage_manager.is_first_stage(ignore_chunk=True):
+                    # retrieve input_ids and inputs_embeds
+                    if input_ids is not None and inputs_embeds is not None:
+                        raise ValueError(
+                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+                        )
+                    elif input_ids is not None:
+                        batch_size, seq_length = input_ids.shape
+                    elif inputs_embeds is not None:
+                        batch_size, seq_length, _ = inputs_embeds.shape
+                    else:
+                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                    device = input_ids.device if input_ids is not None else inputs_embeds.device
+                    if inputs_embeds is None:
+                        inputs_embeds = self.embed_tokens(input_ids)
+                    hidden_states = inputs_embeds
+                else:
+                    input_shape = hidden_states.shape[:-1]
+                    batch_size, seq_length = input_shape
+                    device = hidden_states.device
+        else:
+            # 1f1b or None
+            if stage_manager.is_first_stage():  # No ignore_chunk=True for 1f1b
+                # retrieve input_ids and inputs_embeds
+                if input_ids is not None and inputs_embeds is not None:
+                    raise ValueError(
+                        "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+                    )
+                elif input_ids is not None:
+                    batch_size, seq_length = input_ids.shape
+                elif inputs_embeds is not None:
+                    batch_size, seq_length, _ = inputs_embeds.shape
+                else:
+                    raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                if inputs_embeds is None:
+                    inputs_embeds = self.embed_tokens(input_ids)
+                hidden_states = inputs_embeds
+            else:
+                input_shape = hidden_states.shape[:-1]
+                batch_size, seq_length = input_shape
+                device = hidden_states.device
+
+        #######
+        # Attention, we support consider 1f1b, interleaved, zbv
+        #######
+
+        # # retrieve input_ids and inputs_embeds
+        # print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
+        # if stage_manager.is_first_stage():
+        #     # retrieve input_ids and inputs_embeds
+        #     if input_ids is not None and inputs_embeds is not None:
+        #         raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        #     elif input_ids is not None:
+        #         batch_size, seq_length = input_ids.shape
+        #     elif inputs_embeds is not None:
+        #         batch_size, seq_length, _ = inputs_embeds.shape
+        #     else:
+        #         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+        #     device = input_ids.device if input_ids is not None else inputs_embeds.device
+        #     if inputs_embeds is None:
+        #         inputs_embeds = self.embed_tokens(input_ids)
+        #     hidden_states = inputs_embeds
+        # else:
+        #     input_shape = hidden_states.shape[:-1]
+        #     batch_size, seq_length = input_shape
+        #     device = hidden_states.device
 
         seq_length_with_past = seq_length
         past_key_values_length = 0
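Note: the three branches added in this hunk share the same input-retrieval body and differ only in how "this rank/chunk owns the embedding" is decided. Below is a minimal sketch of that predicate, assuming nothing beyond the stage_manager attributes visible in the diff (is_interleave, use_zbv, model_chunk_id, is_first_stage(ignore_chunk=...)); the helper name holds_embedding is hypothetical and not part of the commit.

def holds_embedding(stage_manager) -> bool:
    """Illustrative only: does this rank/chunk hold the token embedding?"""
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            # ZB-V: the embedding sits in chunk 0 of the first pipeline rank.
            return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0
        # Plain interleaved schedule: only the rank matters, the chunk is ignored.
        return stage_manager.is_first_stage(ignore_chunk=True)
    # 1F1B (non-interleaved): chunk ids do not apply.
    return stage_manager.is_first_stage()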
@@ -390,8 +462,22 @@ class MixtralPipelineForwards:
             if output_router_logits:
                 all_router_logits += (layer_outputs[-1],)
 
-        if stage_manager.is_last_stage():
-            hidden_states = self.norm(hidden_states)
+        #######
+        # Attention, we support consider 1f1b, interleaved, zbv
+        #######
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
+                    hidden_states = self.norm(hidden_states)
+            else:
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    hidden_states = self.norm(hidden_states)
+        else:
+            if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
+                hidden_states = self.norm(hidden_states)
+
+        # if stage_manager.is_last_stage():
+        #     hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
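Note: the zbv branch applies the final norm on the *first* pipeline rank (chunk 1) because of the V-shaped chunk placement used by the zero-bubble ZB-V schedule: each rank holds two model chunks, chunk 0 of the model runs down the ranks and chunk 1 folds back, so the last layers land on rank 0 again. A toy sketch of that layout is shown below; it is illustrative only and not ColossalAI API.

def zbv_ranks_for_chunk(chunk_id, num_ranks):
    """Ranks visited, in layer order, by the layers placed in chunk slot `chunk_id`."""
    if chunk_id == 0:
        return list(range(num_ranks))               # down the V: rank 0 -> rank P-1
    if chunk_id == 1:
        return list(range(num_ranks - 1, -1, -1))   # back up the V: rank P-1 -> rank 0
    raise ValueError("ZB-V places exactly two chunks per rank")

assert zbv_ranks_for_chunk(0, 4)[0] == 0    # embedding: chunk 0 on the first rank
assert zbv_ranks_for_chunk(1, 4)[-1] == 0   # final norm: chunk 1 on the first rank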
@@ -400,30 +486,114 @@ class MixtralPipelineForwards:
 
         if output_router_logits and past_router_logits is not None:
             all_router_logits = past_router_logits + all_router_logits
-        if stage_manager.is_last_stage():
-            if not return_dict:
-                return tuple(
-                    v
-                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                    if v is not None
-                )
-            return MoeModelOutputWithPast(
-                last_hidden_state=hidden_states,
-                past_key_values=next_cache,
-                hidden_states=all_hidden_states,
-                attentions=all_self_attns,
-                router_logits=all_router_logits,
-            )
-        else:
-            if output_router_logits:
-                return {
-                    "hidden_states": hidden_states,
-                    "past_router_logits": all_router_logits,
-                }
-            else:
-                return {
-                    "hidden_states": hidden_states,
-                }
+
+        #######
+        # Attention, we support consider 1f1b, interleaved, zbv
+        #######
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                # zbv
+                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
+                    if not return_dict:
+                        return tuple(
+                            v
+                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                            if v is not None
+                        )
+                    return MoeModelOutputWithPast(
+                        last_hidden_state=hidden_states,
+                        past_key_values=next_cache,
+                        hidden_states=all_hidden_states,
+                        attentions=all_self_attns,
+                        router_logits=all_router_logits,
+                    )
+                else:
+                    if output_router_logits:
+                        return {
+                            "hidden_states": hidden_states,
+                            "past_router_logits": all_router_logits,
+                        }
+                    else:
+                        return {
+                            "hidden_states": hidden_states,
+                        }
+            else:
+                # interlearved
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    if not return_dict:
+                        return tuple(
+                            v
+                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                            if v is not None
+                        )
+                    return MoeModelOutputWithPast(
+                        last_hidden_state=hidden_states,
+                        past_key_values=next_cache,
+                        hidden_states=all_hidden_states,
+                        attentions=all_self_attns,
+                        router_logits=all_router_logits,
+                    )
+                else:
+                    if output_router_logits:
+                        return {
+                            "hidden_states": hidden_states,
+                            "past_router_logits": all_router_logits,
+                        }
+                    else:
+                        return {
+                            "hidden_states": hidden_states,
+                        }
+        else:
+            # 1f1b or other
+            if stage_manager.is_last_stage():
+                if not return_dict:
+                    return tuple(
+                        v
+                        for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                        if v is not None
+                    )
+                return MoeModelOutputWithPast(
+                    last_hidden_state=hidden_states,
+                    past_key_values=next_cache,
+                    hidden_states=all_hidden_states,
+                    attentions=all_self_attns,
+                    router_logits=all_router_logits,
+                )
+            else:
+                if output_router_logits:
+                    return {
+                        "hidden_states": hidden_states,
+                        "past_router_logits": all_router_logits,
+                    }
+                else:
+                    return {
+                        "hidden_states": hidden_states,
+                    }
+
+        # if stage_manager.is_last_stage():
+        #     if not return_dict:
+        #         return tuple(
+        #             v
+        #             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+        #             if v is not None
+        #         )
+        #     return MoeModelOutputWithPast(
+        #         last_hidden_state=hidden_states,
+        #         past_key_values=next_cache,
+        #         hidden_states=all_hidden_states,
+        #         attentions=all_self_attns,
+        #         router_logits=all_router_logits,
+        #     )
+        # else:
+        #     if output_router_logits:
+        #         return {
+        #             "hidden_states": hidden_states,
+        #             "past_router_logits": all_router_logits,
+        #         }
+        #     else:
+        #         return {
+        #             "hidden_states": hidden_states,
+        #         }
 
     @staticmethod
     def mixtral_for_causal_lm_forward(
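Note: after this hunk the model forward returns one of three shapes depending on schedule position and return_dict: a MoeModelOutputWithPast (or plain tuple) from the stage that ran the final norm, and a small dict of activations (plus accumulated router logits when requested) from every other stage. The sketch below shows a hypothetical consumer of that contract; ColossalAI's pipeline schedules do their own routing, and unpack_stage_output is illustrative only.

from transformers.modeling_outputs import MoeModelOutputWithPast

def unpack_stage_output(output):
    """Return (hidden_states, router_logits_or_None) for any of the three shapes."""
    if isinstance(output, MoeModelOutputWithPast):
        # Stage that ran the final norm, called with return_dict=True.
        return output.last_hidden_state, output.router_logits
    if isinstance(output, dict):
        # Intermediate stage: raw activations, plus router logits if output_router_logits was set.
        return output["hidden_states"], output.get("past_router_logits")
    # Stage that ran the final norm, called with return_dict=False: a tuple, hidden states first.
    return output[0], None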