Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-10 12:22:28 +00:00)

[fix] Fix the zbv loss-is-None bug in the llama and mixtral benchmarks; update mixtral & llama policy and modeling.

commit 0ca16d5cbe (parent e234dfa236)
@@ -432,7 +432,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         internal_inputs = {} if input_obj is None else input_obj
         internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
         output_obj = model_forward(model_chunk, micro_batch, internal_inputs)

         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
@@ -500,12 +499,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
         output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

-        optimizer.backward_by_grad(
-            tensor=output_obj_,
-            grad=output_obj_grad_,
-            inputs=input_obj_,
-            retain_graph=True,
-        )
+        try:
+            ctx = optimizer.no_sync()
+        except AttributeError:
+            ctx = model_chunk.no_sync()
+
+        with ctx:
+            optimizer.backward_by_grad(
+                tensor=output_obj_,
+                grad=output_obj_grad_,
+                inputs=input_obj_,
+                retain_graph=True,
+            )

         # Format output_obj_grad
         input_obj_grad = {}
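The rewritten backward step above defers gradient synchronization during the zero-bubble backward pass by wrapping backward_by_grad in a no_sync() context, preferring the optimizer's context manager and falling back to the model chunk's. A minimal sketch of that fallback pattern (function and variable names here are illustrative, not part of the ColossalAI API beyond the no_sync attribute shown in the diff):

def grad_sync_paused(optimizer, model_chunk):
    # Prefer the optimizer-level no_sync() (e.g. a ZeRO-wrapped optimizer);
    # fall back to the module-level no_sync() (DDP-style) if the optimizer
    # does not expose one.
    try:
        return optimizer.no_sync()
    except AttributeError:
        return model_chunk.no_sync()

# usage sketch:
# with grad_sync_paused(optimizer, model_chunk):
#     optimizer.backward_by_grad(tensor=output_obj_, grad=output_obj_grad_, inputs=input_obj_, retain_graph=True)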
@@ -267,98 +267,25 @@ class MixtralPipelineForwards:
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
-                    # retrieve input_ids and inputs_embeds
-                    if input_ids is not None and inputs_embeds is not None:
-                        raise ValueError(
-                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                        )
-                    elif input_ids is not None:
-                        batch_size, seq_length = input_ids.shape
-                    elif inputs_embeds is not None:
-                        batch_size, seq_length, _ = inputs_embeds.shape
-                    else:
-                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                    device = input_ids.device if input_ids is not None else inputs_embeds.device
-                    if inputs_embeds is None:
-                        inputs_embeds = self.embed_tokens(input_ids)
-                    hidden_states = inputs_embeds
-                else:
-                    input_shape = hidden_states.shape[:-1]
-                    batch_size, seq_length = input_shape
-                    device = hidden_states.device
-            else:
-                # interleaved
-                if stage_manager.is_first_stage(ignore_chunk=True):
-                    # retrieve input_ids and inputs_embeds
-                    if input_ids is not None and inputs_embeds is not None:
-                        raise ValueError(
-                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                        )
-                    elif input_ids is not None:
-                        batch_size, seq_length = input_ids.shape
-                    elif inputs_embeds is not None:
-                        batch_size, seq_length, _ = inputs_embeds.shape
-                    else:
-                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                    device = input_ids.device if input_ids is not None else inputs_embeds.device
-                    if inputs_embeds is None:
-                        inputs_embeds = self.embed_tokens(input_ids)
-                    hidden_states = inputs_embeds
-                else:
-                    input_shape = hidden_states.shape[:-1]
-                    batch_size, seq_length = input_shape
-                    device = hidden_states.device
-        else:
-            # 1f1b or None
-            if stage_manager.is_first_stage():  # No ignore_chunk=True for 1f1b
-                # retrieve input_ids and inputs_embeds
-                if input_ids is not None and inputs_embeds is not None:
-                    raise ValueError(
-                        "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                    )
-                elif input_ids is not None:
-                    batch_size, seq_length = input_ids.shape
-                elif inputs_embeds is not None:
-                    batch_size, seq_length, _ = inputs_embeds.shape
-                else:
-                    raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-                if inputs_embeds is None:
-                    inputs_embeds = self.embed_tokens(input_ids)
-                hidden_states = inputs_embeds
-            else:
-                input_shape = hidden_states.shape[:-1]
-                batch_size, seq_length = input_shape
-                device = hidden_states.device
-
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-
-        # # retrieve input_ids and inputs_embeds
-        # print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
-        # if stage_manager.is_first_stage():
-        #     # retrieve input_ids and inputs_embeds
-        #     if input_ids is not None and inputs_embeds is not None:
-        #         raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        #     elif input_ids is not None:
-        #         batch_size, seq_length = input_ids.shape
-        #     elif inputs_embeds is not None:
-        #         batch_size, seq_length, _ = inputs_embeds.shape
-        #     else:
-        #         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-        #     device = input_ids.device if input_ids is not None else inputs_embeds.device
-        #     if inputs_embeds is None:
-        #         inputs_embeds = self.embed_tokens(input_ids)
-        #     hidden_states = inputs_embeds
-        # else:
-        #     input_shape = hidden_states.shape[:-1]
-        #     batch_size, seq_length = input_shape
-        #     device = hidden_states.device
+        # retrieve input_ids and inputs_embeds
+        if stage_manager.is_first_stage():
+            # retrieve input_ids and inputs_embeds
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
+            else:
+                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            if inputs_embeds is None:
+                inputs_embeds = self.embed_tokens(input_ids)
+            hidden_states = inputs_embeds
+        else:
+            input_shape = hidden_states.shape[:-1]
+            batch_size, seq_length = input_shape
+            device = hidden_states.device

         seq_length_with_past = seq_length
         past_key_values_length = 0
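The rewrite above collapses the separate zbv / interleaved / 1f1b branches into a single stage_manager.is_first_stage() check; the assumption (consistent with the policy and scheduler changes elsewhere in this commit) is that the stage manager already answers "first stage" with respect to the schedule and the active model chunk, so the modeling code needs no per-schedule cases. A compact sketch of the resulting control flow, with illustrative names:

def resolve_embeddings(stage_manager, embed_tokens, input_ids=None, inputs_embeds=None, hidden_states=None):
    # First pipeline stage: build embeddings from input_ids / inputs_embeds.
    # Any later stage: pass through the hidden_states received from the previous stage.
    if stage_manager.is_first_stage():
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("specify either input_ids or inputs_embeds, not both")
        if inputs_embeds is None:
            inputs_embeds = embed_tokens(input_ids)
        return inputs_embeds
    return hidden_states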
@@ -462,22 +389,8 @@ class MixtralPipelineForwards:
             if output_router_logits:
                 all_router_logits += (layer_outputs[-1],)

-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    hidden_states = self.norm(hidden_states)
-            else:
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    hidden_states = self.norm(hidden_states)
-        else:
-            if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
-                hidden_states = self.norm(hidden_states)
-
-        # if stage_manager.is_last_stage():
-        #     hidden_states = self.norm(hidden_states)
+        if stage_manager.is_last_stage():
+            hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -487,113 +400,30 @@ class MixtralPipelineForwards:
         if output_router_logits and past_router_logits is not None:
             all_router_logits = past_router_logits + all_router_logits

-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    if not return_dict:
-                        return tuple(
-                            v
-                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                            if v is not None
-                        )
-                    return MoeModelOutputWithPast(
-                        last_hidden_state=hidden_states,
-                        past_key_values=next_cache,
-                        hidden_states=all_hidden_states,
-                        attentions=all_self_attns,
-                        router_logits=all_router_logits,
-                    )
-                else:
-                    if output_router_logits:
-                        return {
-                            "hidden_states": hidden_states,
-                            "past_router_logits": all_router_logits,
-                        }
-                    else:
-                        return {
-                            "hidden_states": hidden_states,
-                        }
-            else:
-                # interlearved
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    if not return_dict:
-                        return tuple(
-                            v
-                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                            if v is not None
-                        )
-                    return MoeModelOutputWithPast(
-                        last_hidden_state=hidden_states,
-                        past_key_values=next_cache,
-                        hidden_states=all_hidden_states,
-                        attentions=all_self_attns,
-                        router_logits=all_router_logits,
-                    )
-                else:
-                    if output_router_logits:
-                        return {
-                            "hidden_states": hidden_states,
-                            "past_router_logits": all_router_logits,
-                        }
-                    else:
-                        return {
-                            "hidden_states": hidden_states,
-                        }
-        else:
-            # 1f1b or other
-            if stage_manager.is_last_stage():
-                if not return_dict:
-                    return tuple(
-                        v
-                        for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                        if v is not None
-                    )
-                return MoeModelOutputWithPast(
-                    last_hidden_state=hidden_states,
-                    past_key_values=next_cache,
-                    hidden_states=all_hidden_states,
-                    attentions=all_self_attns,
-                    router_logits=all_router_logits,
-                )
-            else:
-                if output_router_logits:
-                    return {
-                        "hidden_states": hidden_states,
-                        "past_router_logits": all_router_logits,
-                    }
-                else:
-                    return {
-                        "hidden_states": hidden_states,
-                    }
-
-        # if stage_manager.is_last_stage():
-        #     if not return_dict:
-        #         return tuple(
-        #             v
-        #             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-        #             if v is not None
-        #         )
-        #     return MoeModelOutputWithPast(
-        #         last_hidden_state=hidden_states,
-        #         past_key_values=next_cache,
-        #         hidden_states=all_hidden_states,
-        #         attentions=all_self_attns,
-        #         router_logits=all_router_logits,
-        #     )
-        # else:
-        #     if output_router_logits:
-        #         return {
-        #             "hidden_states": hidden_states,
-        #             "past_router_logits": all_router_logits,
-        #         }
-        #     else:
-        #         return {
-        #             "hidden_states": hidden_states,
-        #         }
+        if stage_manager.is_last_stage():
+            if not return_dict:
+                return tuple(
+                    v
+                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                    if v is not None
+                )
+            return MoeModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+                router_logits=all_router_logits,
+            )
+        else:
+            if output_router_logits:
+                return {
+                    "hidden_states": hidden_states,
+                    "past_router_logits": all_router_logits,
+                }
+            else:
+                return {
+                    "hidden_states": hidden_states,
+                }

     @staticmethod
     def mixtral_for_causal_lm_forward(
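After this change the Mixtral model forward has one schedule-agnostic output contract: the last stage returns the usual tuple or MoeModelOutputWithPast, and every other stage returns a plain dict carrying hidden_states (plus past_router_logits when router logits are requested) for the next stage. A small sketch of a consumer of that intermediate-stage dict (illustrative only; the real hand-off is done by the pipeline schedule):

def next_stage_kwargs(stage_outputs: dict, output_router_logits: bool) -> dict:
    # Intermediate stages forward their hidden states; router logits are
    # accumulated across stages so the aux (load-balancing) loss can be
    # computed at the end of the pipeline.
    kwargs = {"hidden_states": stage_outputs["hidden_states"]}
    if output_router_logits:
        kwargs["past_router_logits"] = stage_outputs["past_router_logits"]
    return kwargs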
@@ -679,201 +509,51 @@ class MixtralPipelineForwards:
         )
         past_key_values = None

-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    hidden_states = outputs[0]
-                    logits = self.lm_head(hidden_states)
-                    logits = logits.float()
-
-                    loss = None
-                    if labels is not None:
-                        # Shift so that tokens < n predict n
-                        shift_logits = logits[..., :-1, :].contiguous()
-                        shift_labels = labels[..., 1:].contiguous()
-                        # Flatten the tokens
-                        loss_fct = CrossEntropyLoss()
-                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                        shift_labels = shift_labels.view(-1)
-                        # Enable model parallelism
-                        shift_labels = shift_labels.to(shift_logits.device)
-                        loss = loss_fct(shift_logits, shift_labels)
-
-                    aux_loss = None
-                    if output_router_logits:
-                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                        if labels is not None:
-                            loss += self.router_aux_loss_coef * aux_loss
-
-                    if not return_dict:
-                        output = (logits,) + outputs[1:]
-                        if output_router_logits:
-                            output = (aux_loss,) + output
-                        return (loss,) + output if loss is not None else output
-
-                    return MoeCausalLMOutputWithPast(
-                        loss=loss,
-                        aux_loss=aux_loss,
-                        logits=logits,
-                        past_key_values=None,
-                        hidden_states=outputs[0],
-                        attentions=None,
-                        router_logits=outputs[-1],
-                    )
-                else:
-                    out = {}
-                    hidden_states = outputs.get("hidden_states")
-                    out["hidden_states"] = hidden_states
-                    if output_router_logits:
-                        out["past_router_logits"] = outputs["past_router_logits"]
-                    return out
-            else:
-                # interleaved
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    hidden_states = outputs[0]
-                    logits = self.lm_head(hidden_states)
-                    logits = logits.float()
-
-                    loss = None
-                    if labels is not None:
-                        # Shift so that tokens < n predict n
-                        shift_logits = logits[..., :-1, :].contiguous()
-                        shift_labels = labels[..., 1:].contiguous()
-                        # Flatten the tokens
-                        loss_fct = CrossEntropyLoss()
-                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                        shift_labels = shift_labels.view(-1)
-                        # Enable model parallelism
-                        shift_labels = shift_labels.to(shift_logits.device)
-                        loss = loss_fct(shift_logits, shift_labels)
-
-                    aux_loss = None
-                    if output_router_logits:
-                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                        if labels is not None:
-                            loss += self.router_aux_loss_coef * aux_loss
-
-                    if not return_dict:
-                        output = (logits,) + outputs[1:]
-                        if output_router_logits:
-                            output = (aux_loss,) + output
-                        return (loss,) + output if loss is not None else output
-
-                    return MoeCausalLMOutputWithPast(
-                        loss=loss,
-                        aux_loss=aux_loss,
-                        logits=logits,
-                        past_key_values=None,
-                        hidden_states=outputs[0],
-                        attentions=None,
-                        router_logits=outputs[-1],
-                    )
-                else:
-                    out = {}
-                    hidden_states = outputs.get("hidden_states")
-                    out["hidden_states"] = hidden_states
-                    if output_router_logits:
-                        out["past_router_logits"] = outputs["past_router_logits"]
-                    return out
-        else:
-            # 1f1b or otherwise
-            if stage_manager.is_last_stage():
-                hidden_states = outputs[0]
-                logits = self.lm_head(hidden_states)
-                logits = logits.float()
-
-                loss = None
-                if labels is not None:
-                    # Shift so that tokens < n predict n
-                    shift_logits = logits[..., :-1, :].contiguous()
-                    shift_labels = labels[..., 1:].contiguous()
-                    # Flatten the tokens
-                    loss_fct = CrossEntropyLoss()
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    shift_labels = shift_labels.view(-1)
-                    # Enable model parallelism
-                    shift_labels = shift_labels.to(shift_logits.device)
-                    loss = loss_fct(shift_logits, shift_labels)
-
-                aux_loss = None
-                if output_router_logits:
-                    aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                    if labels is not None:
-                        loss += self.router_aux_loss_coef * aux_loss
-
-                if not return_dict:
-                    output = (logits,) + outputs[1:]
-                    if output_router_logits:
-                        output = (aux_loss,) + output
-                    return (loss,) + output if loss is not None else output
-
-                return MoeCausalLMOutputWithPast(
-                    loss=loss,
-                    aux_loss=aux_loss,
-                    logits=logits,
-                    past_key_values=None,
-                    hidden_states=outputs[0],
-                    attentions=None,
-                    router_logits=outputs[-1],
-                )
-            else:
-                out = {}
-                hidden_states = outputs.get("hidden_states")
-                out["hidden_states"] = hidden_states
-                if output_router_logits:
-                    out["past_router_logits"] = outputs["past_router_logits"]
-                return out
-
-        # if stage_manager.is_last_stage():
-        #     hidden_states = outputs[0]
-        #     logits = self.lm_head(hidden_states)
-        #     logits = logits.float()
-
-        #     loss = None
-        #     if labels is not None:
-        #         # Shift so that tokens < n predict n
-        #         shift_logits = logits[..., :-1, :].contiguous()
-        #         shift_labels = labels[..., 1:].contiguous()
-        #         # Flatten the tokens
-        #         loss_fct = CrossEntropyLoss()
-        #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        #         shift_labels = shift_labels.view(-1)
-        #         # Enable model parallelism
-        #         shift_labels = shift_labels.to(shift_logits.device)
-        #         loss = loss_fct(shift_logits, shift_labels)
-
-        #     aux_loss = None
-        #     if output_router_logits:
-        #         aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-        #         if labels is not None:
-        #             loss += self.router_aux_loss_coef * aux_loss
-
-        #     if not return_dict:
-        #         output = (logits,) + outputs[1:]
-        #         if output_router_logits:
-        #             output = (aux_loss,) + output
-        #         return (loss,) + output if loss is not None else output
-
-        #     return MoeCausalLMOutputWithPast(
-        #         loss=loss,
-        #         aux_loss=aux_loss,
-        #         logits=logits,
-        #         past_key_values=None,
-        #         hidden_states=outputs[0],
-        #         attentions=None,
-        #         router_logits=outputs[-1],
-        #     )
-        # else:
-        #     out = {}
-        #     hidden_states = outputs.get("hidden_states")
-        #     out["hidden_states"] = hidden_states
-        #     if output_router_logits:
-        #         out["past_router_logits"] = outputs["past_router_logits"]
-        #     return out
+        if stage_manager.is_last_stage():
+            hidden_states = outputs[0]
+            logits = self.lm_head(hidden_states)
+            logits = logits.float()
+            loss = None
+            if labels is not None:
+                # Shift so that tokens < n predict n
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                shift_labels = shift_labels.to(shift_logits.device)
+                loss = loss_fct(shift_logits, shift_labels)
+
+            aux_loss = None
+            if output_router_logits:
+                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+                if labels is not None:
+                    loss += self.router_aux_loss_coef * aux_loss
+
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                if output_router_logits:
+                    output = (aux_loss,) + output
+                return (loss,) + output if loss is not None else output
+
+            return MoeCausalLMOutputWithPast(
+                loss=loss,
+                aux_loss=aux_loss,
+                logits=logits,
+                past_key_values=None,
+                hidden_states=outputs[0],
+                attentions=None,
+                router_logits=outputs[-1],
+            )
+        else:
+            out = {}
+            hidden_states = outputs.get("hidden_states")
+            out["hidden_states"] = hidden_states
+            if output_router_logits:
+                out["past_router_logits"] = outputs["past_router_logits"]
+            return out


 def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
@@ -343,18 +343,10 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                if stage_manager.is_first_stage(ignore_chunk=True):
-                    held_layers.append(self.model.lm_head)
-            else:
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    held_layers.append(self.model.lm_head)
-        else:
-            if stage_manager.is_last_stage():
-                held_layers.append(self.model.lm_head)
-        # if stage_manager.is_last_stage():
-        #     held_layers.append(self.model.lm_head)
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
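The simplified policy reflects the V-shaped chunk placement of the zbv schedule: each rank holds two model chunks, chunk 0 running down the ranks and chunk 1 running back up, so rank 0 owns both the embedding end and the lm_head end of the model, and the loss is produced on the first stage instead of the last. A tiny sketch of that layout, assuming 4 pipeline ranks (illustrative numbers, not taken from the commit):

def v_schedule_rank_order(pp_size: int) -> dict:
    # chunk 0: rank 0 -> rank pp_size-1 (embeddings live on rank 0)
    # chunk 1: rank pp_size-1 -> rank 0 (lm_head and the loss live on rank 0)
    return {
        "chunk0": list(range(pp_size)),
        "chunk1": list(reversed(range(pp_size))),
    }

print(v_schedule_rank_order(4))
# {'chunk0': [0, 1, 2, 3], 'chunk1': [3, 2, 1, 0]}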
@@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig

 warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@ def main():
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)

-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -137,6 +138,11 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     use_empty_init = True
     if args.plugin == "gemini":
         plugin = GeminiPlugin(
@@ -210,6 +216,23 @@ def main():
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
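The zbv branch above derives the V schedule offline from rough per-pass cost and memory coefficients before handing the nodes to HybridParallelPlugin. A hedged numeric example of those memory terms, using illustrative values (hidden_size=4096, 32 attention heads, max_length=4096 are assumptions, not values from the commit):

hidden_size = 4096          # illustrative
num_attention_heads = 32    # illustrative
max_length = 4096           # illustrative

mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length  # forward pass: activation memory retained
mem_w = -32 * hidden_size                                        # weight-grad pass: releases memory
mem_b = -mem_w - mem_f                                           # input-grad pass: the three passes net to zero

print(mem_f, mem_w, mem_b)  # 794624 -131072 -663552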
@@ -256,10 +280,6 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)

-    if args.config in MODEL_CONFIGS:
-        config = MODEL_CONFIGS[args.config]
-    else:
-        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
     torch.cuda.manual_seed(42)
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@@ -334,8 +354,12 @@ def main():
                 return_loss=True,
             )
             loss = outputs["loss"]
-            if dist.get_rank() == dist.get_world_size() - 1:
-                print(f"Step {step} loss: {loss}")
+            if args.pp_style == "zbv":
+                if dist.get_rank() == 0:
+                    print(f"Step {step} loss: {loss}")
+            else:
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
             optimizer.step()
             optimizer.zero_grad()
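This is the benchmark-side half of the zbv loss-is-None fix (the same change is repeated in the second benchmark script below): under the V schedule the criterion runs on rank 0, so printing the loss from the last rank shows None. A small helper capturing the rule (a sketch; the flag name mirrors --pp_style from this script):

import torch.distributed as dist

def loss_owner_rank(pp_style: str) -> int:
    # zbv: the last model chunk, and therefore the loss, sits on rank 0.
    # 1f1b / interleaved: the loss lives on the last pipeline rank.
    return 0 if pp_style == "zbv" else dist.get_world_size() - 1

# usage sketch:
# if dist.get_rank() == loss_owner_rank(args.pp_style):
#     print(f"Step {step} loss: {loss}")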
@@ -227,7 +227,6 @@ def main():
     )

     optimizer = HybridAdam(model.parameters())
-    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

@@ -258,8 +257,12 @@ def main():
                 return_loss=True,
             )
             loss = outputs["loss"]
-            if dist.get_rank() == dist.get_world_size() - 1:
-                print(f"Step {step} loss: {loss}")
+            if args.pp_style == "zbv":
+                if dist.get_rank() == 0:
+                    print(f"Step {step} loss: {loss}")
+            else:
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
             optimizer.step()
             optimizer.zero_grad()