diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 5c25c5bfa..c928a207c 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -432,7 +432,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): internal_inputs = {} if input_obj is None else input_obj internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - # last layer in model if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): loss = criterion(output_obj, micro_batch) / self.num_microbatch @@ -500,12 +499,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad_, - inputs=input_obj_, - retain_graph=True, - ) + try: + ctx = optimizer.no_sync() + except AttributeError: + ctx = model_chunk.no_sync() + + with ctx: + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + inputs=input_obj_, + retain_graph=True, + ) # Format output_obj_grad input_obj_grad = {} diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 3709af54c..a783b5c5e 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -267,98 +267,25 @@ class MixtralPipelineForwards: ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if stage_manager.is_interleave: - if stage_manager.use_zbv: - # zbv - if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0: - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape else: - # interleaved - if stage_manager.is_first_stage(ignore_chunk=True): - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise 
ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds else: - # 1f1b or None - if stage_manager.is_first_stage(): # No ignore_chunk=True for 1f1b - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - ####### - # Attention, we support consider 1f1b, interleaved, zbv - ####### - - # # retrieve input_ids and inputs_embeds - # print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}") - # if stage_manager.is_first_stage(): - # # retrieve input_ids and inputs_embeds - # if input_ids is not None and inputs_embeds is not None: - # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - # elif input_ids is not None: - # batch_size, seq_length = input_ids.shape - # elif inputs_embeds is not None: - # batch_size, seq_length, _ = inputs_embeds.shape - # else: - # raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # device = input_ids.device if input_ids is not None else inputs_embeds.device - # if inputs_embeds is None: - # inputs_embeds = self.embed_tokens(input_ids) - # hidden_states = inputs_embeds - # else: - # input_shape = hidden_states.shape[:-1] - # batch_size, seq_length = input_shape - # device = hidden_states.device + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device seq_length_with_past = seq_length past_key_values_length = 0 @@ -462,22 +389,8 @@ class MixtralPipelineForwards: if output_router_logits: all_router_logits += (layer_outputs[-1],) - ####### - # Attention, we support consider 1f1b, interleaved, zbv - ####### - if stage_manager.is_interleave: - if stage_manager.use_zbv: - if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1: - hidden_states = self.norm(hidden_states) - else: - if stage_manager.is_last_stage(ignore_chunk=True): - hidden_states = self.norm(hidden_states) - else: - if stage_manager.is_last_stage(): # No ignore_chunk=True for 1f1b - hidden_states = self.norm(hidden_states) - - # if stage_manager.is_last_stage(): - # hidden_states = self.norm(hidden_states) + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) # 
add hidden states from the last decoder layer if output_hidden_states: @@ -487,113 +400,30 @@ class MixtralPipelineForwards: if output_router_logits and past_router_logits is not None: all_router_logits = past_router_logits + all_router_logits - ####### - # Attention, we support consider 1f1b, interleaved, zbv - ####### - if stage_manager.is_interleave: - if stage_manager.use_zbv: - # zbv - if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1: - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) - else: - if output_router_logits: - return { - "hidden_states": hidden_states, - "past_router_logits": all_router_logits, - } - else: - return { - "hidden_states": hidden_states, - } - else: - # interlearved - if stage_manager.is_last_stage(ignore_chunk=True): - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) - else: - if output_router_logits: - return { - "hidden_states": hidden_states, - "past_router_logits": all_router_logits, - } - else: - return { - "hidden_states": hidden_states, - } - else: - # 1f1b or other - if stage_manager.is_last_stage(): - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + else: + if output_router_logits: + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } else: - if output_router_logits: - return { - "hidden_states": hidden_states, - "past_router_logits": all_router_logits, - } - else: - return { - "hidden_states": hidden_states, - } - - # if stage_manager.is_last_stage(): - # if not return_dict: - # return tuple( - # v - # for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - # if v is not None - # ) - # return MoeModelOutputWithPast( - # last_hidden_state=hidden_states, - # past_key_values=next_cache, - # hidden_states=all_hidden_states, - # attentions=all_self_attns, - # router_logits=all_router_logits, - # ) - # else: - # if output_router_logits: - # return { - # "hidden_states": hidden_states, - # "past_router_logits": all_router_logits, - # } - # else: - # return { - # "hidden_states": hidden_states, - # } + return { + "hidden_states": hidden_states, + } @staticmethod def mixtral_for_causal_lm_forward( @@ -679,201 +509,51 @@ class MixtralPipelineForwards: ) 
past_key_values = None - ####### - # Attention, we support consider 1f1b, interleaved, zbv - ####### - if stage_manager.is_interleave: - if stage_manager.use_zbv: - # zbv - if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1: - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss - - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, - logits=logits, - past_key_values=None, - hidden_states=outputs[0], - attentions=None, - router_logits=outputs[-1], - ) - else: - out = {} - hidden_states = outputs.get("hidden_states") - out["hidden_states"] = hidden_states - if output_router_logits: - out["past_router_logits"] = outputs["past_router_logits"] - return out - else: - # interleaved - if stage_manager.is_last_stage(ignore_chunk=True): - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss - - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, - logits=logits, - past_key_values=None, - hidden_states=outputs[0], - attentions=None, - router_logits=outputs[-1], - ) - else: - out = {} - hidden_states = outputs.get("hidden_states") - out["hidden_states"] = hidden_states - if output_router_logits: - 
out["past_router_logits"] = outputs["past_router_logits"] - return out - else: - # 1f1b or otherwise - if stage_manager.is_last_stage(): - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss += self.router_aux_loss_coef * aux_loss - aux_loss = None + if not return_dict: + output = (logits,) + outputs[1:] if output_router_logits: - aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, - logits=logits, - past_key_values=None, - hidden_states=outputs[0], - attentions=None, - router_logits=outputs[-1], - ) - else: - out = {} - hidden_states = outputs.get("hidden_states") - out["hidden_states"] = hidden_states - if output_router_logits: - out["past_router_logits"] = outputs["past_router_logits"] - return out - - # if stage_manager.is_last_stage(): - # hidden_states = outputs[0] - # logits = self.lm_head(hidden_states) - # logits = logits.float() - - # loss = None - # if labels is not None: - # # Shift so that tokens < n predict n - # shift_logits = logits[..., :-1, :].contiguous() - # shift_labels = labels[..., 1:].contiguous() - # # Flatten the tokens - # loss_fct = CrossEntropyLoss() - # shift_logits = shift_logits.view(-1, self.config.vocab_size) - # shift_labels = shift_labels.view(-1) - # # Enable model parallelism - # shift_labels = shift_labels.to(shift_logits.device) - # loss = loss_fct(shift_logits, shift_labels) - - # aux_loss = None - # if output_router_logits: - # aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) - # if labels is not None: - # loss += self.router_aux_loss_coef * aux_loss - - # if not return_dict: - # output = (logits,) + outputs[1:] - # if output_router_logits: - # output = (aux_loss,) + output - # return (loss,) + output if loss is not None else output - - # return MoeCausalLMOutputWithPast( - # loss=loss, - # aux_loss=aux_loss, - # logits=logits, - # past_key_values=None, - # hidden_states=outputs[0], - # attentions=None, - # router_logits=outputs[-1], - # ) - # else: - # out = {} - # hidden_states = outputs.get("hidden_states") - # out["hidden_states"] = hidden_states - # if output_router_logits: - # out["past_router_logits"] = outputs["past_router_logits"] - # return out + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + router_logits=outputs[-1], + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = 
hidden_states + if output_router_logits: + out["past_router_logits"] = outputs["past_router_logits"] + return out def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index a8cd49dc1..9d8d2b54b 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -343,18 +343,10 @@ class MixtralForCausalLMPolicy(MixtralPolicy): """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_interleave: - if stage_manager.use_zbv: - if stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) - else: - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) - else: - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) - # if stage_manager.is_last_stage(): - # held_layers.append(self.model.lm_head) + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 0e88fabf1..0f418edb6 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig warnings.filterwarnings("ignore") @@ -91,7 +92,7 @@ def main(): parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) - parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument( @@ -137,6 +138,11 @@ def main(): # ============================== # Initialize Booster # ============================== + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + use_empty_init = True if args.plugin == "gemini": plugin = GeminiPlugin( @@ -210,6 +216,23 @@ def main(): fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": + if args.pp_style == "zbv": + mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length + mem_w = -32 * config.hidden_size + mem_b = -mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage=args.pp, + n_micro=args.batch_size // args.mbs, + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + else: + scheduler_nodes = None plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, @@ -227,6 +250,7 @@ def main(): 
overlap_allgather=args.overlap_allgather, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, + scheduler_nodes=scheduler_nodes, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": @@ -256,10 +280,6 @@ def main(): # ============================== dp_size = getattr(plugin, "dp_size", coordinator.world_size) - if args.config in MODEL_CONFIGS: - config = MODEL_CONFIGS[args.config] - else: - config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) torch.cuda.manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size @@ -334,8 +354,12 @@ def main(): return_loss=True, ) loss = outputs["loss"] - if dist.get_rank() == dist.get_world_size() - 1: - print(f"Step {step} loss: {loss}") + if args.pp_style == "zbv": + if dist.get_rank() == 0: + print(f"Step {step} loss: {loss}") + else: + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad() diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py index 2685afced..0334bd81c 100644 --- a/examples/language/mixtral/benchmark.py +++ b/examples/language/mixtral/benchmark.py @@ -227,7 +227,6 @@ def main(): ) optimizer = HybridAdam(model.parameters()) - # optimizer = torch.optim.SGD(model.parameters(), lr=1) torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) @@ -258,8 +257,12 @@ def main(): return_loss=True, ) loss = outputs["loss"] - if dist.get_rank() == dist.get_world_size() - 1: - print(f"Step {step} loss: {loss}") + if args.pp_style == "zbv": + if dist.get_rank() == 0: + print(f"Step {step} loss: {loss}") + else: + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad()
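
Usage sketch (illustrative, outside the diff hunks): the snippet below shows how the ZBV pieces added above fit together — building a zero-bubble V schedule with PipelineGraph and handing it to HybridParallelPlugin via scheduler_nodes, the same way examples/language/llama/benchmark.py now does. The model dimensions, batch sizes, pp_style="zbv" and num_model_chunks=2 are illustrative assumptions, not values taken from the patch.

# Minimal sketch, assuming a HybridParallelPlugin build that accepts
# pp_style="zbv" and scheduler_nodes (as the benchmark change above passes).
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Placeholder model/config values; in the benchmark these come from the model
# config and the CLI arguments (hidden_size, num_attention_heads, max_length, ...).
hidden_size, num_attention_heads, max_length = 4096, 32, 4096
pp_size, global_batch_size, micro_batch_size = 4, 32, 1

# Per-microbatch memory coefficients for the forward (F), backward-for-input (B)
# and backward-for-weight (W) passes, using the same heuristic as the benchmark.
mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length
mem_w = -32 * hidden_size
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=global_batch_size // micro_batch_size,
    f_cost=1000,  # relative F/B/W time costs; equal costs give a symmetric V schedule
    b_cost=1000,
    w_cost=1000,
    c_cost=1,     # inter-stage communication cost
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    pp_style="zbv",                   # assumed value, mirroring the new benchmark CLI choice
    num_model_chunks=2,               # assumption: ZBV places two model chunks per pipeline rank (the "V")
    microbatch_size=micro_batch_size,
    scheduler_nodes=scheduler_nodes,  # the V schedule built above
)

Because the V schedule folds back on itself, the final model chunk (and the lm_head, per the MixtralForCausalLMPolicy change) sits on pipeline rank 0, which is why the benchmarks above print the loss from rank 0 when pp_style is "zbv".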