[fix] fix the None-loss bug in the llama and mixtral ZBV benchmarks; update mixtral & llama policy and modeling

duanjunwen 2024-10-11 07:32:43 +00:00
parent e234dfa236
commit 0ca16d5cbe
5 changed files with 134 additions and 430 deletions

View File

@@ -432,7 +432,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         internal_inputs = {} if input_obj is None else input_obj
         internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
         output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
@@ -500,12 +499,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
         output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
-        optimizer.backward_by_grad(
-            tensor=output_obj_,
-            grad=output_obj_grad_,
-            inputs=input_obj_,
-            retain_graph=True,
-        )
+        try:
+            ctx = optimizer.no_sync()
+        except AttributeError:
+            ctx = model_chunk.no_sync()
+
+        with ctx:
+            optimizer.backward_by_grad(
+                tensor=output_obj_,
+                grad=output_obj_grad_,
+                inputs=input_obj_,
+                retain_graph=True,
+            )
         # Format output_obj_grad
         input_obj_grad = {}
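The try/except introduced above uses whichever object actually exposes no_sync(): the optimizer (e.g. a ZeRO-wrapped optimizer) or, failing that, the model chunk itself, so the extra backward pass does not trigger a premature gradient synchronization. A minimal sketch of the same fallback pattern; grad_sync_ctx and the commented call are illustrative, not part of the diff:

def grad_sync_ctx(optimizer, model_chunk):
    # Prefer the optimizer's no_sync() (ZeRO-style wrapper); fall back to the
    # module's no_sync() (DDP-style wrapper) when the optimizer has none.
    try:
        return optimizer.no_sync()
    except AttributeError:
        return model_chunk.no_sync()

# usage:
# with grad_sync_ctx(optimizer, model_chunk):
#     optimizer.backward_by_grad(tensor=output_obj_, grad=output_obj_grad_, inputs=input_obj_, retain_graph=True)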

View File

@@ -267,98 +267,25 @@ class MixtralPipelineForwards:
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
-                    # retrieve input_ids and inputs_embeds
-                    if input_ids is not None and inputs_embeds is not None:
-                        raise ValueError(
-                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                        )
-                    elif input_ids is not None:
-                        batch_size, seq_length = input_ids.shape
-                    elif inputs_embeds is not None:
-                        batch_size, seq_length, _ = inputs_embeds.shape
-                    else:
-                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                    device = input_ids.device if input_ids is not None else inputs_embeds.device
-                    if inputs_embeds is None:
-                        inputs_embeds = self.embed_tokens(input_ids)
-                    hidden_states = inputs_embeds
-                else:
-                    input_shape = hidden_states.shape[:-1]
-                    batch_size, seq_length = input_shape
-                    device = hidden_states.device
-            else:
-                # interleaved
-                if stage_manager.is_first_stage(ignore_chunk=True):
-                    # ... same input_ids / inputs_embeds handling as the zbv branch above ...
-                else:
-                    input_shape = hidden_states.shape[:-1]
-                    batch_size, seq_length = input_shape
-                    device = hidden_states.device
-        else:
-            # 1f1b or None
-            if stage_manager.is_first_stage():  # No ignore_chunk=True for 1f1b
-                # ... same input_ids / inputs_embeds handling as the zbv branch above ...
-            else:
-                input_shape = hidden_states.shape[:-1]
-                batch_size, seq_length = input_shape
-                device = hidden_states.device
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        # # retrieve input_ids and inputs_embeds
-        # print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
-        # ... commented-out copy of the original first-stage handling, also removed ...
+        # retrieve input_ids and inputs_embeds
+        if stage_manager.is_first_stage():
+            # retrieve input_ids and inputs_embeds
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
+            else:
+                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            if inputs_embeds is None:
+                inputs_embeds = self.embed_tokens(input_ids)
+            hidden_states = inputs_embeds
+        else:
+            input_shape = hidden_states.shape[:-1]
+            batch_size, seq_length = input_shape
+            device = hidden_states.device
 
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -462,22 +389,8 @@ class MixtralPipelineForwards:
             if output_router_logits:
                 all_router_logits += (layer_outputs[-1],)
 
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    hidden_states = self.norm(hidden_states)
-            else:
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    hidden_states = self.norm(hidden_states)
-        else:
-            if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
-                hidden_states = self.norm(hidden_states)
-        # if stage_manager.is_last_stage():
-        #     hidden_states = self.norm(hidden_states)
+        if stage_manager.is_last_stage():
+            hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -487,113 +400,30 @@ class MixtralPipelineForwards:
         if output_router_logits and past_router_logits is not None:
             all_router_logits = past_router_logits + all_router_logits
 
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    if not return_dict:
-                        return tuple(
-                            v
-                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                            if v is not None
-                        )
-                    return MoeModelOutputWithPast(
-                        last_hidden_state=hidden_states,
-                        past_key_values=next_cache,
-                        hidden_states=all_hidden_states,
-                        attentions=all_self_attns,
-                        router_logits=all_router_logits,
-                    )
-                else:
-                    if output_router_logits:
-                        return {
-                            "hidden_states": hidden_states,
-                            "past_router_logits": all_router_logits,
-                        }
-                    else:
-                        return {
-                            "hidden_states": hidden_states,
-                        }
-            else:
-                # interlearved
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    # ... same tuple / MoeModelOutputWithPast return as the zbv branch above ...
-                else:
-                    # ... same hidden_states / past_router_logits dict return as the zbv branch above ...
-        else:
-            # 1f1b or other
-            if stage_manager.is_last_stage():
-                # ... same tuple / MoeModelOutputWithPast return as the zbv branch above ...
-            else:
-                # ... same hidden_states / past_router_logits dict return as the zbv branch above ...
-        # if stage_manager.is_last_stage():
-        #     ... commented-out copy of the same last-stage handling, also removed ...
+        if stage_manager.is_last_stage():
+            if not return_dict:
+                return tuple(
+                    v
+                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                    if v is not None
+                )
+            return MoeModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+                router_logits=all_router_logits,
+            )
+        else:
+            if output_router_logits:
+                return {
+                    "hidden_states": hidden_states,
+                    "past_router_logits": all_router_logits,
+                }
+            else:
+                return {
+                    "hidden_states": hidden_states,
+                }
 
     @staticmethod
     def mixtral_for_causal_lm_forward(
@@ -679,201 +509,51 @@ class MixtralPipelineForwards:
         )
         past_key_values = None
 
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    hidden_states = outputs[0]
-                    logits = self.lm_head(hidden_states)
-                    logits = logits.float()
-                    loss = None
-                    if labels is not None:
-                        # Shift so that tokens < n predict n
-                        shift_logits = logits[..., :-1, :].contiguous()
-                        shift_labels = labels[..., 1:].contiguous()
-                        # Flatten the tokens
-                        loss_fct = CrossEntropyLoss()
-                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                        shift_labels = shift_labels.view(-1)
-                        # Enable model parallelism
-                        shift_labels = shift_labels.to(shift_logits.device)
-                        loss = loss_fct(shift_logits, shift_labels)
-                    aux_loss = None
-                    if output_router_logits:
-                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                        if labels is not None:
-                            loss += self.router_aux_loss_coef * aux_loss
-                    if not return_dict:
-                        output = (logits,) + outputs[1:]
-                        if output_router_logits:
-                            output = (aux_loss,) + output
-                        return (loss,) + output if loss is not None else output
-                    return MoeCausalLMOutputWithPast(
-                        loss=loss,
-                        aux_loss=aux_loss,
-                        logits=logits,
-                        past_key_values=None,
-                        hidden_states=outputs[0],
-                        attentions=None,
-                        router_logits=outputs[-1],
-                    )
-                else:
-                    out = {}
-                    hidden_states = outputs.get("hidden_states")
-                    out["hidden_states"] = hidden_states
-                    if output_router_logits:
-                        out["past_router_logits"] = outputs["past_router_logits"]
-                    return out
-            else:
-                # interleaved
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    # ... same logits / loss / aux_loss / return handling as the zbv branch above ...
-                else:
-                    # ... same hidden_states / past_router_logits dict return as the zbv branch above ...
-        else:
-            # 1f1b or otherwise
-            if stage_manager.is_last_stage():
-                # ... same logits / loss / aux_loss / return handling as the zbv branch above ...
-            else:
-                # ... same hidden_states / past_router_logits dict return as the zbv branch above ...
-        # if stage_manager.is_last_stage():
-        #     ... commented-out copy of the same last-stage handling, also removed ...
+        if stage_manager.is_last_stage():
+            hidden_states = outputs[0]
+            logits = self.lm_head(hidden_states)
+            logits = logits.float()
+            loss = None
+            if labels is not None:
+                # Shift so that tokens < n predict n
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                shift_labels = shift_labels.to(shift_logits.device)
+                loss = loss_fct(shift_logits, shift_labels)
+
+            aux_loss = None
+            if output_router_logits:
+                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+                if labels is not None:
+                    loss += self.router_aux_loss_coef * aux_loss
+
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                if output_router_logits:
+                    output = (aux_loss,) + output
+                return (loss,) + output if loss is not None else output
+
+            return MoeCausalLMOutputWithPast(
+                loss=loss,
+                aux_loss=aux_loss,
+                logits=logits,
+                past_key_values=None,
+                hidden_states=outputs[0],
+                attentions=None,
+                router_logits=outputs[-1],
+            )
+        else:
+            out = {}
+            hidden_states = outputs.get("hidden_states")
+            out["hidden_states"] = hidden_states
+            if output_router_logits:
+                out["past_router_logits"] = outputs["past_router_logits"]
+            return out
 
 
 def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
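The rewrite above drops the explicit 1f1b / interleaved / zbv branching and relies on the PipelineStageManager to answer "does this rank hold the model's first/last layers" for every schedule. A rough sketch of what that check amounts to, based on the conditions that were removed; holds_last_model_layers is illustrative, not the manager's real API:

def holds_last_model_layers(stage_manager) -> bool:
    # Under a ZBV schedule the last model chunk (chunk 1) ends on pipeline rank 0,
    # so the "last stage" of the model is rank 0 while it is executing chunk 1.
    if getattr(stage_manager, "use_zbv", False):
        return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1
    # 1f1b / interleaved: the usual last pipeline rank (last chunk when interleaving).
    return stage_manager.is_last_stage()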

View File

@@ -343,18 +343,10 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                if stage_manager.is_first_stage(ignore_chunk=True):
-                    held_layers.append(self.model.lm_head)
-            else:
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    held_layers.append(self.model.lm_head)
-        else:
-            if stage_manager.is_last_stage():
-                held_layers.append(self.model.lm_head)
-        # if stage_manager.is_last_stage():
-        #     held_layers.append(self.model.lm_head)
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
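Under a zero-bubble V (ZBV) schedule each pipeline rank holds two model chunks laid out in a V, so the model's final layers, and with them lm_head, land on pipeline rank 0 (first stage, chunk 1) rather than on the last rank; that is what the use_zbv branch above encodes. A small sketch of that layout; zbv_chunk_to_stages is a hypothetical helper, not part of the policy:

def zbv_chunk_to_stages(pp_size: int) -> dict:
    # chunk 0 flows through stages 0 .. pp_size-1, chunk 1 flows back through
    # pp_size-1 .. 0, so the last layers (and lm_head) end up on stage 0.
    return {
        0: list(range(pp_size)),
        1: list(range(pp_size - 1, -1, -1)),
    }

# zbv_chunk_to_stages(4) -> {0: [0, 1, 2, 3], 1: [3, 2, 1, 0]}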

View File

@@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@ def main():
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -137,6 +138,11 @@ def main():
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     use_empty_init = True
     if args.plugin == "gemini":
         plugin = GeminiPlugin(
@@ -210,6 +216,23 @@ def main():
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
@@ -227,6 +250,7 @@ def main():
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
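For --pp_style zbv the benchmark builds the V-schedule up front and hands it to the plugin via scheduler_nodes. f_cost/b_cost/w_cost are rough relative costs of the forward, backward-input and backward-weight passes, c_cost a communication cost, and mem_f/mem_b/mem_w per-layer memory deltas the scheduler uses to bound in-flight microbatches. A standalone sketch of the same call; the concrete numbers are the benchmark's heuristics under assumed example sizes, not requirements:

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# assumed example shape: 4 stages, 8 microbatches, hidden_size 4096, 32 heads, max_length 4096
mem_f = 34 * 4096 + 5 * 32 * 4096  # rough activation memory added by one forward (F) pass
mem_w = -32 * 4096                 # memory delta of a weight-grad (W) pass (negative: it frees memory)
mem_b = -mem_w - mem_f             # memory delta of an input-grad (B) pass, so F + B + W nets to zero

scheduler_nodes = PipelineGraph(
    n_stage=4,
    n_micro=8,
    f_cost=1000,
    b_cost=1000,
    w_cost=1000,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()
# then: HybridParallelPlugin(..., scheduler_nodes=scheduler_nodes)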
@@ -256,10 +280,6 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)
 
-    if args.config in MODEL_CONFIGS:
-        config = MODEL_CONFIGS[args.config]
-    else:
-        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
-
     torch.cuda.manual_seed(42)
 
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@@ -334,8 +354,12 @@ def main():
                 return_loss=True,
             )
             loss = outputs["loss"]
-            if dist.get_rank() == dist.get_world_size() - 1:
-                print(f"Step {step} loss: {loss}")
+            if args.pp_style == "zbv":
+                if dist.get_rank() == 0:
+                    print(f"Step {step} loss: {loss}")
+            else:
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
             optimizer.step()
             optimizer.zero_grad()
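With the V-schedule, the last model chunk (and therefore the returned loss) lives on pipeline rank 0, so the zbv branch above prints from rank 0 while the other styles keep printing from the last rank; this is the benchmark-side half of the None-loss fix. The same rule as a tiny helper; loss_rank is illustrative, not part of the script:

def loss_rank(pp_style: str, world_size: int) -> int:
    # ZBV: the last chunk (and the loss) ends on rank 0; otherwise the last rank holds it.
    return 0 if pp_style == "zbv" else world_size - 1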

View File

@@ -227,7 +227,6 @@ def main():
     )
 
     optimizer = HybridAdam(model.parameters())
-    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
 
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
@@ -258,8 +257,12 @@ def main():
                 return_loss=True,
             )
             loss = outputs["loss"]
-            if dist.get_rank() == dist.get_world_size() - 1:
-                print(f"Step {step} loss: {loss}")
+            if args.pp_style == "zbv":
+                if dist.get_rank() == 0:
+                    print(f"Step {step} loss: {loss}")
+            else:
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
             optimizer.step()
             optimizer.zero_grad()