From 0b2d55c4ab518bd2e6e66195aaead28d7311ab8f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 2 Aug 2024 06:51:38 +0000 Subject: [PATCH 01/16] Support overall loss, update KTO logging --- .../coati/dataset/tokenization_utils.py | 19 +++- .../ColossalChat/coati/models/loss.py | 16 ++- .../ColossalChat/coati/trainer/dpo.py | 9 ++ .../ColossalChat/coati/trainer/kto.py | 37 ++++++- .../ColossalChat/coati/trainer/orpo.py | 12 ++ .../ColossalChat/coati/trainer/ppo.py | 12 +- .../ColossalChat/coati/trainer/sft.py | 14 ++- applications/ColossalChat/examples/README.md | 1 + .../examples/inference/inference.py | 4 +- .../ColossalChat/examples/inference/round.txt | 104 ------------------ .../examples/training_scripts/train_dpo.py | 2 + .../examples/training_scripts/train_kto.py | 2 + .../examples/training_scripts/train_orpo.py | 2 + .../examples/training_scripts/train_ppo.py | 2 + .../examples/training_scripts/train_sft.py | 2 + 15 files changed, 119 insertions(+), 119 deletions(-) delete mode 100644 applications/ColossalChat/examples/inference/round.txt diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 9eb2eba87..4f890ffc9 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -49,6 +49,10 @@ def tokenize_sft( messages = data_point["messages"] template = deepcopy(conversation_template) + + if messages[0]["from"] == "system": + template.system_message = str(messages[0]["content"]) + messages.pop(0) template.messages = [] for idx, mess in enumerate(messages): if mess["from"] != template.roles[idx % 2]: @@ -148,11 +152,14 @@ def tokenize_prompt( template = deepcopy(conversation_template) template.messages = [] + if messages[0]["from"] == "system": + template.system_message = str(messages[0]["content"]) + messages.pop(0) + for idx, mess in enumerate(messages): if mess["from"] != template.roles[idx % 2]: raise ValueError( - f"Message should iterate between user and assistant and starts with a \ - line from the user. Got the following data:\n{messages}" + f"Message should iterate between user and assistant and starts with a line from the user. 
Got the following data:\n{messages}" ) template.append_message(mess["from"], mess["content"]) @@ -225,6 +232,10 @@ def tokenize_rlhf( template = deepcopy(conversation_template) template.clear() + if context[0]["from"] == "system": + template.system_message = str(context[0]["content"]) + context.pop(0) + for idx, mess in enumerate(context): if mess["from"] != template.roles[idx % 2]: raise ValueError( @@ -345,6 +356,10 @@ def tokenize_kto( template = deepcopy(conversation_template) template.clear() + if prompt[0]["from"] == "system": + template.system_message = str(prompt[0]["content"]) + prompt.pop(0) + if prompt[0].get("from", None) != "user": raise ValueError("conversation should start with user") if completion.get("from", None) != "assistant": diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py index 840cca074..bd0bbd36b 100755 --- a/applications/ColossalChat/coati/models/loss.py +++ b/applications/ColossalChat/coati/models/loss.py @@ -46,7 +46,10 @@ class PolicyLoss(nn.Module): action_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: skip = False - ratio_ = ((log_probs - old_log_probs) * action_mask).exp() + if action_mask is None: + ratio_ = (log_probs - old_log_probs).exp() + else: + ratio_ = ((log_probs - old_log_probs) * action_mask).exp() # note that if dropout is disabled (recommanded), ratio will always be 1. if ratio_.mean() > self.skip_threshold: @@ -56,7 +59,10 @@ class PolicyLoss(nn.Module): surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages loss = -torch.min(surr1, surr2) - loss = masked_mean(loss, action_mask) + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) loss = loss.mean() return loss, skip, ratio_.max() @@ -81,8 +87,10 @@ class ValueLoss(nn.Module): values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) surr1 = (values_clipped - returns) ** 2 surr2 = (values - returns) ** 2 - loss = torch.max(surr1, surr2) / torch.sum(action_mask) - loss = torch.sum(loss * action_mask) + if action_mask is not None: + loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask) + else: + loss = torch.mean(torch.max(surr1, surr2)) return 0.5 * loss diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index c7ef2be8f..24ddca654 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer): beta: float = 0.1, gamma: float = 0.0, length_normalization: bool = False, + apply_loss_mask: bool = True, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer): self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.actor_loss_fn = DpoLoss(beta, gamma) + self.apply_loss_mask = apply_loss_mask self.save_interval = save_interval self.coordinator = coordinator self.save_dir = save_dir @@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) + batch_size = chosen_input_ids.size()[0] actor_all_logits = self.model( @@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + if not self.apply_loss_mask: + chosen_loss_mask = 
chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) batch_size = chosen_input_ids.size()[0] diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 8ab0bc66b..6462ba816 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -6,7 +6,7 @@ import os from typing import Any, Optional import torch -import torch.distributed +import torch.distributed as dist from coati.models.loss import KTOLoss from coati.models.utils import calc_masked_log_probs from coati.trainer.utils import all_reduce_mean @@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer): beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0, + apply_loss_mask: bool = True, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer): self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight) + self.apply_loss_mask = apply_loss_mask self.save_interval = save_interval self.coordinator = coordinator self.save_dir = save_dir @@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer): batch["kl_attention_mask"], batch["kl_loss_mask"], ) + if not self.apply_loss_mask: + loss_mask = loss_mask.fill_(1.0) + kl_loss_mask = kl_loss_mask.fill_(1.0) + batch_size = input_ids.size()[0] # actor logits @@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer): # sync loss_mean = all_reduce_mean(tensor=loss) - chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean()) - rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean()) + chosen_reward_mean = chosen_rewards.mean() + chosen_rewards_list = [ + torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size()) + ] + dist.all_gather(chosen_rewards_list, chosen_reward_mean) + rejected_reward_mean = rejected_rewards.mean() + rejected_rewards_list = [ + torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size()) + ] + dist.all_gather(rejected_rewards_list, rejected_reward_mean) + chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()] + rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()] + chosen_rewards_mean = ( + torch.stack(chosen_rewards_list).mean() + if len(chosen_rewards_list) > 0 + else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device) + ) + rejected_rewards_mean = ( + torch.stack(rejected_rewards_list).mean() + if len(rejected_rewards_list) > 0 + else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device) + ) self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item()) @@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer): batch["kl_attention_mask"], batch["kl_loss_mask"], ) + + if not self.apply_loss_mask: + loss_mask = loss_mask.fill_(1.0) + kl_loss_mask = kl_loss_mask.fill_(1.0) + batch_size = input_ids.size()[0] # actor logits diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index b039da4af..c2f75771c 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer): 
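The KTO logging update above replaces a plain `all_reduce_mean` of the rewards with an explicit `all_gather`, so ranks whose micro-batch contained no chosen (or no rejected) samples, and therefore report a NaN mean, can be dropped before averaging. A minimal standalone sketch of that reduction, assuming `torch.distributed` is already initialized (the helper name is illustrative, not part of the codebase):

```python
import torch
import torch.distributed as dist


def nan_safe_mean_across_ranks(value: torch.Tensor) -> torch.Tensor:
    # Gather the per-rank scalar means onto every rank.
    gathered = [torch.zeros_like(value) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, value)
    # Drop ranks that had nothing to average (their local mean is NaN).
    valid = [v for v in gathered if not torch.isnan(v)]
    if len(valid) == 0:
        return torch.full_like(value, float("nan"))
    return torch.stack(valid).mean()
```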
tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, lam: float = 0.1, + apply_loss_mask: bool = True, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer): self.save_dir = save_dir self.num_train_step = 0 self.lam = lam + self.apply_loss_mask = apply_loss_mask self.accumulation_steps = accumulation_steps self.device = get_current_device() self.accumulative_meter = AccumulativeMeanMeter() @@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) + batch_size = chosen_input_ids.size()[0] actor_out = self.model( input_ids=torch.cat([chosen_input_ids, reject_input_ids]), @@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) + batch_size = chosen_input_ids.size()[0] actor_out = self.model( input_ids=torch.cat([chosen_input_ids, reject_input_ids]), diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py index 287767669..63c813b39 100755 --- a/applications/ColossalChat/coati/trainer/ppo.py +++ b/applications/ColossalChat/coati/trainer/ppo.py @@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer): sample_buffer: bool = False, dataloader_pin_memory: bool = True, offload_inference_models: bool = True, + apply_loss_mask: bool = True, accumulation_steps: int = 1, save_interval: int = 0, save_dir: str = None, @@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer): self.actor_optim = actor_optim self.critic_optim = critic_optim self.save_interval = save_interval + self.apply_loss_mask = apply_loss_mask self.coordinator = coordinator self.actor_save_dir = os.path.join(save_dir, "actor") self.critic_save_dir = os.path.join(save_dir, "critic") @@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer): action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) actor_loss, to_skip, max_ratio = self.actor_loss_fn( - action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask if self.apply_loss_mask else None, ) actor_loss = (1 - self.ptx_coef) * actor_loss if not to_skip: @@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer): input_ids=experience.sequences, attention_mask=experience.attention_mask ) # [batch size, prompt_length + response_length] critic_loss = self.critic_loss_fn( - values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask + values[:, -num_actions:], + experience.values, + experience.advantages, + action_mask=experience.action_mask if self.apply_loss_mask else None, ) critic_loss = critic_loss * self.vf_coef self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim) diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index c09d61034..d37676ada 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer): lr_scheduler: _LRScheduler, max_epochs: int = 2, accumulation_steps: int = 8, + apply_loss_mask: bool = True, start_epoch=0, 
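With `--disable_loss_mask`, the PPO trainer above passes `action_mask=None`, and `PolicyLoss`/`ValueLoss` fall back from a masked average over response tokens to a plain mean over every token. A small sketch of the two averaging paths, assuming `loss` has shape `[batch, num_actions]` and `mask` marks valid response tokens (this helper is illustrative, not the project's `masked_mean`):

```python
from typing import Optional

import torch


def average_loss(loss: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
    if mask is not None:
        # Masked path: average only over unmasked (generated response) tokens.
        per_sample = (loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    else:
        # Mask disabled: every token position contributes equally.
        per_sample = loss.mean(dim=1)
    return per_sample.mean()
```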
save_interval: int = None, save_dir: str = None, @@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer): self.coordinator = coordinator self.num_train_step = 0 self.num_eval_step = 0 + self.apply_loss_mask = apply_loss_mask self.accumulative_meter = AccumulativeMeanMeter() def _before_fit( @@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer): for i, batch in enumerate(self.train_dataloader): batch = to_device(batch, torch.cuda.current_device()) batch_size = batch["input_ids"].size(0) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) loss = outputs.loss self.booster.backward(loss=loss, optimizer=self.optimizer) @@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer): ) for batch in self.eval_dataloader: batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) loss_mean = all_reduce_mean(tensor=outputs.loss) self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) step_bar.update() diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index b749f197e..904d69cfc 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai - save_dir: path to store the model checkpoints. - max_length: input will be padded/truncated to max_length before feeding to the model. - max_epochs: number of epochs to train. +- disable_loss_mask: whether to use the loss mask to mask the loss or not. For example, in SFT, if the loss mask is disabled, the model will compute the loss across all tokens in the sequence, if the loss mask is applied, only tokens correspond to the assistant responses will contribute to the final loss. - batch_size: training batch size. - mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility. - save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes. diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py index 103bd8d95..5f59ba452 100755 --- a/applications/ColossalChat/examples/inference/inference.py +++ b/applications/ColossalChat/examples/inference/inference.py @@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs tuple: A tuple containing the loaded model and tokenizer. 
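The SFT change above selects between `batch["labels"]` and the raw `batch["input_ids"]` as the language-modeling target, which is what the new `disable_loss_mask` flag documented in the README toggles. The sketch below illustrates why this works for a Hugging Face causal LM, assuming the preprocessed labels mark prompt tokens with `-100` (the cross-entropy ignore index) while the raw input ids carry no such masking:

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Standard next-token shift used by causal LMs.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Positions labeled -100 (e.g. prompt tokens) are ignored; passing the raw
    # input ids as labels therefore spreads the loss over every token instead.
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```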
""" - model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token model.to(device) diff --git a/applications/ColossalChat/examples/inference/round.txt b/applications/ColossalChat/examples/inference/round.txt deleted file mode 100644 index ba02074c1..000000000 --- a/applications/ColossalChat/examples/inference/round.txt +++ /dev/null @@ -1,104 +0,0 @@ - - -========== -round 1: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] Great, let’s hear a story. - -========== - - -========== -round 2: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 - -========== - - -========== -round 3: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 [INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. - -========== - - -========== -round 1: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -who is the first president of the USA? [/INST] The first president of the United States was George Washington. - -========== - - -========== -round 2: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -who is the first president of the USA? [/INST] The first president of the United States was George Washington. [INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. - -========== - - -========== -round 1: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? - -========== - - -========== -round 2: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
- - -<> - -tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? - -========== - - -========== -round 3: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? [INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. - -========== diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index 44131f572..d88750aeb 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -278,6 +278,7 @@ def train(args): beta=args.beta, gamma=args.gamma, length_normalization=args.length_normalization, + apply_loss_mask=not args.disable_loss_mask, ) trainer.fit( @@ -346,6 +347,7 @@ if __name__ == "__main__": default=False, help="Disable the reference model (enabled by default)", ) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py index d063b82bb..598fd8062 100755 --- a/applications/ColossalChat/examples/training_scripts/train_kto.py +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -297,6 +297,7 @@ def train(args): beta=args.beta, desirable_weight=args.desirable_weight, undesirable_weight=args.undesirable_weight, + apply_loss_mask=not args.disable_loss_mask, ) trainer.fit( @@ -341,6 +342,7 @@ if __name__ == "__main__": parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss") parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss") parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss") + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true") diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py index f06524507..87860f7ea 100755 --- 
a/applications/ColossalChat/examples/training_scripts/train_orpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -259,6 +259,7 @@ def train(args): save_dir=args.save_dir, coordinator=coordinator, lam=args.lam, + apply_loss_mask=not args.disable_loss_mask, ) trainer.fit( @@ -301,6 +302,7 @@ if __name__ == "__main__": parser.add_argument("--pp", type=int, default=1) parser.add_argument("--sp", type=int, default=1) parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss") + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true") diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py index 333be9963..c10418394 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.py +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -411,6 +411,7 @@ def train(args): use_cache=True, do_sample=True, temperature=0.7, + apply_loss_mask=not args.disable_loss_mask, accumulation_steps=args.accumulation_steps, save_dir=args.save_path, save_interval=args.save_interval, @@ -498,6 +499,7 @@ if __name__ == "__main__": parser.add_argument("--critic_lr", type=float, default=9e-6) parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--ptx_coef", type=float, default=0.0) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--max_length", type=int, default=2048) parser.add_argument("--max_seq_len", type=int, default=256) parser.add_argument("--log_dir", default="logs", type=str) diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index 6007a8599..c4ef3b783 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -272,6 +272,7 @@ def train(args): lr_scheduler=lr_scheduler, max_epochs=args.max_epochs, accumulation_steps=args.accumulation_steps, + apply_loss_mask=not args.disable_loss_mask, start_epoch=start_epoch, save_interval=args.save_interval, save_dir=args.save_path, @@ -317,6 +318,7 @@ if __name__ == "__main__": parser.add_argument("--tp", type=int, default=1) parser.add_argument("--pp", type=int, default=1) parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true") From 9179d4088e378c437178432168dea9f32fbf739f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 7 Aug 2024 13:53:48 +0800 Subject: [PATCH 02/16] [Docs] clarify launch port Co-authored-by: Edenzzzz --- docs/source/en/basics/launch_colossalai.md | 7 ++++--- docs/source/zh-Hans/basics/launch_colossalai.md | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md index 8a6028d6c..32748dae1 100644 --- 
a/docs/source/en/basics/launch_colossalai.md +++ b/docs/source/en/basics/launch_colossalai.md @@ -131,17 +131,18 @@ with one simple command. There are two ways you can launch multi-node jobs. This is suitable when you only have a few nodes. Let's say I have two nodes, namely `host1` and `host2`, I can start multi-node training with the following command. Compared to single-node training, you must specify the `master_addr` -option, which is auto-set to localhost if running on a single node only. +option, which is auto-set to localhost if running on a single node only. \ +Additionally, you must also ensure that all nodes share the same open ssh port, which can be specified using --ssh-port. :::caution -`master_addr` cannot be localhost when running on multiple nodes, it should be the hostname or IP address of a node. +`master_addr` cannot be localhost when running on multiple nodes, it should be the **hostname or IP address** of a node. ::: ```shell # run on these two nodes -colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py +colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22 ``` - Run with `--hostfile` diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md index a80d16717..9e40f64c2 100644 --- a/docs/source/zh-Hans/basics/launch_colossalai.md +++ b/docs/source/zh-Hans/basics/launch_colossalai.md @@ -116,17 +116,17 @@ colossalai run --nproc_per_node 4 --master_port 29505 test.py - 通过`--hosts`来启动 这个方式适合节点数不多的情况。假设我们有两个节点,分别为`host`和`host2`。我们可以用以下命令进行多节点训练。 -比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。 +比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。同时,你需要确保每个节点都使用同一个ssh port。可以通过--ssh-port设置。 :::caution -多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的名字或者IP地址。 +多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的**名字或者IP地址**。 ::: ```shell # 在两个节点上训练 -colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py +colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22 ``` From ad3fa4f49cee16579e2997b55ff3ffd89577419d Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 8 Aug 2024 18:04:47 +0800 Subject: [PATCH 03/16] [Hotfix] README link (#5966) * update ignore * update readme * run style * update readme --- applications/ColossalChat/.gitignore | 1 + applications/ColossalChat/README.md | 2 +- applications/README.md | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore index 757cbb5da..7b361d38e 100755 --- a/applications/ColossalChat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -151,6 +151,7 @@ examples/training_scripts/wandb examples/training_scripts/output examples/awesome-chatgpt-prompts/ +examples/inference/round.txt temp/ # ColossalChat diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md index de27ebaf6..3604fab10 100755 --- a/applications/ColossalChat/README.md +++ b/applications/ColossalChat/README.md @@ -121,7 +121,7 @@ cd $COLOSSAL_AI_ROOT BUILD_EXT=1 pip install . # Install ColossalChat -cd $COLOSSAL_AI_ROOT/applications/Chat +cd $COLOSSAL_AI_ROOT/applications/ColossalChat pip install . 
``` diff --git a/applications/README.md b/applications/README.md index 5b8b5e501..9957300ae 100644 --- a/applications/README.md +++ b/applications/README.md @@ -14,9 +14,9 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models +- [X] [ColossalChat](./ColossalChat/): Replication of ChatGPT with RLHF. - [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervisied Fine-tuning of LLaMA2 / LLaMA3. - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. -- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF. - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters. - [X] [ColossalQA](./ColossalQA/README.md): Document Retrieval Conversation System - [X] [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): Breaks the Length Limit of LLM Inference for Multi-Round Conversations From b4d2377d4c482960af21bb77bf5ff78099865b02 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 9 Aug 2024 18:17:09 +0800 Subject: [PATCH 04/16] [Hotfix] Avoid fused RMSnorm import error without apex (#5985) Co-authored-by: Edenzzzz --- colossalai/shardformer/layer/normalization.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 59e1da9fc..043bf6aeb 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -42,7 +42,7 @@ try: return output except ImportError: - warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel") + warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, @@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm): Returns: nn.Module: FusedRMSNorm module. 
""" - try: - pass - except ImportError: - raise ImportError( - "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" - ) LazyInitContext.materialize(module) @@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm): eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps elementwise_affine = getattr(module, "elementwise_affine", True) - rmsnorm = FusedRMSNormWithHook( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - ) + try: + rmsnorm = FusedRMSNormWithHook( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + ) + except ImportError: + warnings.warn( + "Module replacement failed.\ + Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" + ) + return module rmsnorm.weight = module.weight From ed97d3a5d3bb8cd3b2ff62b0097c96bf0991df92 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:55:17 +0800 Subject: [PATCH 05/16] [Chat] fix readme (#5989) * fix readme * fix readme, tokenization fully tested * fix readme, tokenization fully tested * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: root Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../coati/dataset/tokenization_utils.py | 2 +- .../ColossalChat/coati/models/utils.py | 7 ++-- applications/ColossalChat/examples/README.md | 28 +++++++-------- .../examples/inference/inference.py | 1 - .../examples/training_scripts/train_ppo.py | 2 +- applications/ColossalChat/requirements.txt | 2 +- applications/ColossalChat/tests/test_train.sh | 36 +++++++++---------- 7 files changed, 38 insertions(+), 40 deletions(-) diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 4f890ffc9..020432b9e 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -169,7 +169,7 @@ def tokenize_prompt( template.messages = template.messages[:-1] # Prepare data - prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True) + prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True) tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] if tokenizer.bos_token_id is not None: diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py index 8ed8d3401..c583f057a 100755 --- a/applications/ColossalChat/coati/models/utils.py +++ b/applications/ColossalChat/coati/models/utils.py @@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module): Returns: None """ - for module in model.modules(): - if isinstance(module, torch.nn.Dropout): - module.p = 0.0 + if model is not None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0.0 diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index 904d69cfc..fec7bc061 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -462,26 +462,24 @@ Stage1 is supervised instructs fine-tuning (SFT). 
This step is a crucial part of #### Step 1: Data Collection -The first step in Stage 1 is to collect a dataset of human demonstrations of the following format. +The first step in Stage 1 is to collect a dataset of human demonstrations of the following JSONL format. ```json -[ - {"messages": - [ - { - "from": "user", - "content": "what are some pranks with a pen i can do?" - }, - { - "from": "assistant", - "content": "Are you looking for practical joke ideas?" - }, - ... - ] +{"messages": + [ + { + "from": "user", + "content": "what are some pranks with a pen i can do?" + }, + { + "from": "assistant", + "content": "Are you looking for practical joke ideas?" }, ... -] + ] +}, +... ``` diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py index 5f59ba452..32310cce9 100755 --- a/applications/ColossalChat/examples/inference/inference.py +++ b/applications/ColossalChat/examples/inference/inference.py @@ -151,7 +151,6 @@ def main(args): chat_io.prompt_for_output("assistant") prompt = conv.get_prompt(add_generation_prompt=True) - print(prompt + "") input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to( torch.cuda.current_device() ) diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py index c10418394..a0a10e239 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.py +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -502,7 +502,7 @@ if __name__ == "__main__": parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--max_length", type=int, default=2048) parser.add_argument("--max_seq_len", type=int, default=256) - parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt index 2188de12f..ac40ae821 100755 --- a/applications/ColossalChat/requirements.txt +++ b/applications/ColossalChat/requirements.txt @@ -2,7 +2,7 @@ transformers==4.39.3 tqdm datasets==2.14.7 loralib -colossalai==0.4.0 +colossalai>=0.4.0 torch>=2.1.0 langchain tokenizers diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index c26b25c83..69036de63 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" } -set_n_least_used_CUDA_VISIBLE_DEVICES 4 +set_n_least_used_CUDA_VISIBLE_DEVICES 2 set -xu @@ -119,11 +119,11 @@ for lora_rank in ${LORA_RANK[@]}; do lora_config="" fi if [[ $plugin == "3d" ]]; then - tp='4' + tp='2' bs='8' fi if [[ $plugin == "tp_zero2" ]]; then - tp='4' + tp='2' bs='8' zero_stage='2' plugin='3d' @@ -136,13 +136,13 @@ for lora_rank in ${LORA_RANK[@]}; do fi if [[ $plugin == "pp" ]]; then bs='8' - pp='4' + pp='2' plugin='3d' fi if [[ $plugin == "sp_split_gather" ]]; then enable_sequence_parallelism='--enable_sequence_parallelism' sp_mode='split_gather' - tp='4' + tp='2' sp='1' bs='8' plugin='3d' @@ -150,7 +150,7 @@ for lora_rank in 
${LORA_RANK[@]}; do if [[ $plugin == "sp_ring" ]]; then enable_sequence_parallelism='--enable_sequence_parallelism' sp_mode='ring' - tp='4' + tp='2' sp='1' bs='8' plugin='3d' @@ -159,7 +159,7 @@ for lora_rank in ${LORA_RANK[@]}; do enable_sequence_parallelism='--enable_sequence_parallelism' sp_mode='all_to_all' tp='1' - sp='4' + sp='2' bs='8' plugin='3d' fi @@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") done - colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ @@ -242,7 +242,7 @@ for lora_rank in ${LORA_RANK[@]}; do lora_config="" fi if [[ $plugin == "3d" ]]; then - tp='4' + tp='2' bs='8' fi grad_accu='2' @@ -256,7 +256,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split") done - colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \ + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \ --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ @@ -325,7 +325,7 @@ for lora_rank in ${LORA_RANK[@]}; do lora_config="" fi if [[ $plugin == "3d" ]]; then - tp='4' + tp='2' bs='16' ebs='32' fi @@ -350,7 +350,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") done - colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \ + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \ --pretrain $pretrain \ --rm_pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ @@ -417,7 +417,7 @@ for lora_rank in ${LORA_RANK[@]}; do tp='1' bs='2' if [[ $plugin == "3d" ]]; then - tp='4' + tp='2' bs='8' fi if [[ $plugin == "zero2" ]]; then @@ -442,7 +442,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split") done - colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \ + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \ --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ @@ -500,7 +500,7 @@ for lora_rank in ${LORA_RANK[@]}; do tp='1' bs='2' if [[ $plugin == "3d" ]]; then - tp='4' + tp='2' bs='8' fi if [[ $plugin == "zero2" ]]; then @@ -525,7 +525,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split") done - colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \ + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \ --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ @@ -583,7 +583,7 @@ for lora_rank in ${LORA_RANK[@]}; do tp='1' bs='2' if [[ $plugin == "3d" ]]; then - tp='4' + tp='2' bs='8' fi if [[ $plugin == "zero2" ]]; then @@ -608,7 +608,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 
0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split") done - colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \ + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \ --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ From ceb1e262e765242c1f130aa72ab9d5e2289162be Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 14 Aug 2024 11:22:39 +0800 Subject: [PATCH 06/16] fix sync condition (#6000) --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d2933a4af..e5acdb051 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase): ) # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + if ( + model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) + or not torch.is_grad_enabled() ): return outputs From 406f984063423042e25d0723258530ba506a44a9 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 15 Aug 2024 10:41:22 +0800 Subject: [PATCH 07/16] [plugin] add cast inputs option for zero (#6003) --- colossalai/booster/plugin/low_level_zero_plugin.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 66491821c..e4c386a22 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -62,7 +62,9 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: + def __init__( + self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True + ) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -74,7 +76,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None: + if self.dtype is not None and cast_inputs: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -334,6 +336,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, + cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -360,6 +363,7 @@ class LowLevelZeroPlugin(DPPluginBase): ) self.lora_enabled = False self.verbose = verbose + self.cast_inputs = cast_inputs # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -474,7 +478,10 @@ class LowLevelZeroPlugin(DPPluginBase): if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + model, + self.precision, + 
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], + cast_inputs=self.cast_inputs, ) # TODO: Support Galore + ZeRO From 4dd03999ecb1016b0919c090a065e5bf425432ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:40:03 +0800 Subject: [PATCH 08/16] [pre-commit.ci] pre-commit autoupdate (#5995) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9088d0e1b..a4132a507 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: name: sort all imports (python) - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black name: black formatter From 887d2d579b522cadab12571f2357d9e2cbd23aed Mon Sep 17 00:00:00 2001 From: Haze188 Date: Thu, 15 Aug 2024 14:40:26 +0800 Subject: [PATCH 09/16] [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991) --- colossalai/shardformer/modeling/deepseek.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index a84a30972..429c4350c 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -666,6 +666,9 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code. 
+ self._use_flash_attention_2 = shard_config.enable_flash_attention + self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None From f5c84af0b01bcd2e993d38dc628793f7f0a8ba64 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 13:56:38 +0800 Subject: [PATCH 10/16] [Feature] Zigzag Ring attention (#5905) * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements --------- Co-authored-by: Edenzzzz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 1 + .../booster/plugin/hybrid_parallel_plugin.py | 50 +- .../hybrid_parallel_checkpoint_io.py | 4 - colossalai/lazy/pretrained.py | 4 - .../moe/openmoe/model/openmoe_policy.py | 2 +- .../legacy/nn/layer/parallel_1d/_operation.py | 3 + colossalai/logging/logger.py | 5 +- .../pipeline/schedule/interleaved_pp.py | 1 - colossalai/pipeline/schedule/one_f_one_b.py | 1 + colossalai/shardformer/layer/__init__.py | 4 +- colossalai/shardformer/layer/_operation.py | 38 +- colossalai/shardformer/layer/attn.py | 904 +++++++++++++++++- colossalai/shardformer/layer/linear.py | 14 +- colossalai/shardformer/layer/loss.py | 165 +++- colossalai/shardformer/layer/utils.py | 198 +++- colossalai/shardformer/modeling/command.py | 8 +- colossalai/shardformer/modeling/llama.py | 143 ++- .../shardformer/policies/base_policy.py | 1 + colossalai/shardformer/policies/command.py | 31 +- colossalai/shardformer/policies/deepseek.py | 2 +- colossalai/shardformer/policies/llama.py | 45 +- colossalai/shardformer/policies/mistral.py | 2 +- colossalai/shardformer/policies/mixtral.py | 2 +- colossalai/shardformer/policies/qwen2.py | 2 +- colossalai/shardformer/shard/shard_config.py | 13 +- examples/language/llama/benchmark.py | 29 +- examples/language/opt/README.md | 2 +- examples/language/performance_evaluator.py | 24 +- examples/tutorial/opt/opt/README.md | 2 +- .../flash_attention_dao_cuda.py | 8 +- tests/kit/model_zoo/__init__.py | 4 +- tests/kit/model_zoo/transformers/command.py | 12 +- tests/kit/model_zoo/transformers/llama.py | 41 +- tests/kit/model_zoo/transformers/mistral.py | 2 +- tests/kit/model_zoo/transformers/qwen2.py | 12 +- .../test_plugin/test_3d_plugin.py | 2 +- .../test_plugin/test_low_level_zero_plugin.py | 2 +- .../test_gemini_checkpoint_io.py | 2 +- .../test_gemini_torch_compability.py | 2 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 2 +- .../test_low_level_zero_checkpoint_io.py | 2 +- .../test_plugins_huggingface_compatibility.py | 2 +- tests/test_lora/test_lora.py | 2 +- .../test_schedule/test_interleaved.py | 17 +- 
.../test_schedule/test_oneF_oneB.py | 17 +- .../test_shardformer/test_flash_attention.py | 3 + .../test_layer/test_ring_attn.py | 186 ++++ tests/test_shardformer/test_model/_utils.py | 27 +- .../test_model/test_shard_command.py | 4 +- .../test_model/test_shard_llama.py | 147 ++- 50 files changed, 1870 insertions(+), 326 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_ring_attn.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4132a507..f7217a8f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,7 @@ repos: hooks: - id: isort name: sort all imports (python) + args: ["--profile", "black"] # avoid conflict with black - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.8.0 diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index e5acdb051..63427192f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -32,7 +32,7 @@ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackw from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer -from colossalai.shardformer.layer.utils import SeqParallelUtils +from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -42,7 +42,7 @@ from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_hand from .pp_plugin_base import PipelinePluginBase -SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -72,7 +72,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.dp_group = dp_group self.tp_group = tp_group self.sp_group = sp_group - self.use_dpp = use_ddp + self.use_ddp = use_ddp self.require_grad_sync = True self.overlap_allgather = overlap_allgather @@ -139,8 +139,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): # Disable automatic gradient synchronization. self.require_grad_sync = False try: - if self.use_dpp: - # If using data parallel processing (use_dpp), disable synchronization too. + if self.use_ddp: + # If using data parallel processing (use_ddp), disable synchronization too. with self.module.no_sync(): yield else: @@ -188,7 +188,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): """ if self.shard_config.enable_sequence_parallelism: - if self.shard_config.sequence_parallelism_mode == "all_to_all": + if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: return if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: @@ -970,6 +970,9 @@ class HybridParallelPlugin(PipelinePluginBase): enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. 
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn". + It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default. + """ def __init__( @@ -1017,6 +1020,7 @@ class HybridParallelPlugin(PipelinePluginBase): dp_outside: bool = True, overlap_p2p: bool = True, overlap_allgather: bool = False, + inner_ring_size: int = None, ) -> None: super().__init__() @@ -1041,9 +1045,11 @@ class HybridParallelPlugin(PipelinePluginBase): ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) - elif self.sequence_parallelism_mode in ["all_to_all"]: + elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: self.sp_size = 1 if sp_size is None else sp_size self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + if self.sequence_parallelism_mode == "ring_attn": + enable_flash_attention = True else: self.dp_size = dist.get_world_size() // (tp_size * pp_size) assert ( @@ -1063,10 +1069,21 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_sequence_parallelism = enable_sequence_parallelism if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + # Swap tp and sp since 2D Ring has better inter-node latency + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) else: self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + if sequence_parallelism_mode == "ring_attn": + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None @@ -1108,6 +1125,8 @@ class HybridParallelPlugin(PipelinePluginBase): ) else: raise NotImplementedError() + if sequence_parallelism_mode == "ring_attn": + assert parallel_output, "Ring Attention doesn't support gathering output yet." 
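For context on the new `ring_attn` mode configured above: zigzag Ring Attention shards each sequence so that rank `i` of an `sp_size`-way group holds chunks `i` and `2*sp_size - 1 - i`, pairing a cheap early-causal chunk with an expensive late one to balance work across ranks. A minimal sketch of that split, assuming the sequence length is divisible by `2 * sp_size` (the function name is illustrative, not the actual API):

```python
import torch


def zigzag_shard(batch: torch.Tensor, sp_rank: int, sp_size: int) -> torch.Tensor:
    # batch: [batch_size, seq_len, ...]; seq_len must be divisible by 2 * sp_size.
    chunks = batch.chunk(2 * sp_size, dim=1)
    # Pair chunk i with its mirror chunk so every rank gets a balanced share of
    # the causal-attention workload.
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=1)
```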
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) @@ -1132,6 +1151,7 @@ class HybridParallelPlugin(PipelinePluginBase): parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + inner_ring_size=inner_ring_size, ) self.amp_config = dict( initial_scale=initial_scale, @@ -1216,15 +1236,15 @@ class HybridParallelPlugin(PipelinePluginBase): zero_stage = 0 if not isinstance(model, ModelWrapper): + # Shouldn't use pp (frequent grad accumulation) with torch ddp use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( - self.dp_size == 1 - and self.pp_size == 1 - and self.enable_sequence_parallelism - and self.sequence_parallelism_mode == "all_to_all" + self.dp_size == 1 and self.pp_size == 1 ) - # sync gradients across DP * SP ranks - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + + # Apply Hybrid ZeRO across DP * SP ranks + if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(dp_group) else: dp_group = self.dp_group model = HybridParallelModule( diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 0310df548..043e5c2b0 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -203,7 +203,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): return Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group share the same copies of model. # So only let the device with dp_rank == 0 save the model. if self.dp_rank != 0: @@ -643,14 +642,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() model = model.unwrap() - if self.dp_rank != 0: return # The logic of collecting parameter shards along tp degree # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. state_dict = model.state_dict() - if self.pp_size == 1: # When pipeline is not used, let master rank directly save the collected state_dict. if self.tp_rank == 0: @@ -660,7 +657,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_dict_list = [None for _ in range(self.pp_size)] dist.barrier(self.pp_group) dist.all_gather_object(state_dict_list, state_dict, self.pp_group) - # Only the master rank do the saving. 
if self.coordinator.is_master(): complete_state_dict = dict() diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 736ffc5e4..226951598 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -62,7 +62,6 @@ def new_from_pretrained( config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -116,7 +115,6 @@ def new_from_pretrained( cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, @@ -195,7 +193,6 @@ def new_from_pretrained( "cache_dir": cache_dir, "force_download": force_download, "proxies": proxies, - "resume_download": resume_download, "local_files_only": local_files_only, "use_auth_token": use_auth_token, "user_agent": user_agent, @@ -312,7 +309,6 @@ def new_from_pretrained( pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, diff --git a/colossalai/legacy/moe/openmoe/model/openmoe_policy.py b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py index ccd566b08..d5824afcb 100644 --- a/colossalai/legacy/moe/openmoe/model/openmoe_policy.py +++ b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py @@ -171,7 +171,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm # TODO: recursively assign ep group foe all modules new_item = { OpenMoeForCausalLM: ModulePolicyDescription( diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index 8b8f04ccf..e892336bc 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -81,6 +81,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function): handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated + # TODO: This seems to only work if you add torch.cuda.Event.wait() + + # _ = torch.zeros(1, device=grad_output.device) grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index eb5f28e2a..9f4b7a7b0 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -64,7 +64,10 @@ class DistributedLogger: self._logger.propagate = False DistributedLogger.__instances[name] = self - self.rank = dist.get_rank() if dist.is_initialized() else 0 + + @property + def rank(self): + return dist.get_rank() if dist.is_initialized() else 0 @staticmethod def __get_call_info(): diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a21b45c44..412f3896f 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -286,7 +286,6 @@ class 
InterleavedSchedule(PipelineSchedule): # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used - with self.stage_manager.switch_model_chunk_id(model_chunk_id): if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 7f0d0e349..03df67ae7 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 331e49729..8882a33c1 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,5 @@ from ._operation import all_to_all_comm -from .attn import AttnMaskType, ColoAttention +from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D @@ -31,5 +31,7 @@ __all__ = [ "VocabParallelLMHead1D", "AttnMaskType", "ColoAttention", + "RingAttention", + "get_pad_info", "all_to_all_comm", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 19da348e7..25983e0a9 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -2,6 +2,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from .utils import is_share_sp_tp + try: import fused_mix_prec_layer_norm_cuda except: @@ -93,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py grad_weight = total_input.t().matmul(grad_output) @@ -143,7 +145,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + _ = torch.zeros(1, device=grad_input.device) + + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: @@ -331,7 +335,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_.shape, dtype=input_parallel.dtype, device=input_parallel.device ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to 
have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: @@ -646,8 +650,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_.shape, dtype=input_parallel.dtype, device=input_parallel.device ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have - # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function): Args: input_: input matrix. - parallel_mode: parallel mode. + process_group: communication group. + """ @staticmethod - def forward(ctx, input_, process_group): + def forward(ctx, input_, process_group, grad_scale=None): + ctx.grad_scale = grad_scale return _reduce(input_, process_group) @staticmethod def backward(ctx, grad_output): - return grad_output, None + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + return grad_output, None, None class _ReduceBackward(torch.autograd.Function): @@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) -def reduce_forward(input_, process_group): - return _ReduceForward.apply(input_, process_group) +def reduce_forward(input_, process_group, grad_scale=None): + return _ReduceForward.apply(input_, process_group, grad_scale) def reduce_backward(input_, process_group): @@ -989,3 +997,13 @@ def reduce_backward(input_, process_group): def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) + + +def gather_sp_output(hidden_states, sp_group, sp_mode): + """ + Gather the output of the last layer for cross entropy computation + """ + # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) + scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale) + return hidden_states diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5872c6485..6dab17ec0 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -2,7 +2,10 @@ from enum import Enum from typing import Callable, Dict, Optional, Tuple import torch +import torch.distributed +import torch.distributed as dist import torch.nn.functional as F +from einops import rearrange from colossalai.kernel.kernel_loader import ( FlashAttentionForFloatAndCustomMaskLoader, @@ -10,12 +13,18 @@ from colossalai.kernel.kernel_loader import ( FlashAttentionWithCustomMaskLoader, KernelLoader, ) +from colossalai.logging import get_dist_logger + +from .utils import RingComm, get_half_index, split_varlen_zigzag __all__ = [ "AttnMaskType", "ColoAttention", ] +_flash_attn_forward = _flash_attn_backward = None +_unpad_input = _pad_input = None + class AttnMaskType(Enum): CUSTOM = 0 @@ -38,20 +47,32 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor: # adapted from 
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]: +def get_pad_info( + padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True +) -> Tuple[int, torch.Tensor, torch.Tensor]: """Get padding information from padding mask. Args: - padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S] + padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv] + invert (Optional[bool], optional): Whether to reverse the padding mask. + return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens. Returns: - Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices) + max_seqlen_in_batch (int): Maximum sequence length in the batch. + cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch. + indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence. """ + if invert: + padding_mask = padding_mask.logical_not() seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + if return_indices: + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return max_seqlen_in_batch, cu_seqlens, indices + if return_indices: + return max_seqlen_in_batch, cu_seqlens, indices + return max_seqlen_in_batch, cu_seqlens class ColoAttention: @@ -107,6 +128,7 @@ class ColoAttention: q_padding_mask: Optional[torch.Tensor] = None, kv_padding_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + invert: bool = True, ) -> Dict[str, torch.Tensor]: """Return a dictionary of keyword arguments for attention function. It supports 4 mask type. 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. @@ -124,7 +146,7 @@ class ColoAttention: The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token. If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None. is_causal (bool, optional): Whether to use causal attention mask. Defaults to False. - + invert_mask (bool, optional): Whether to invert the mask. Defaults to True. Returns: Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function. """ @@ -154,7 +176,7 @@ class ColoAttention: assert kv_padding_mask.shape == ( b, s_kv, - ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. 
({shape_4d})" + ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})" attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) outputs.update( { @@ -172,7 +194,8 @@ class ColoAttention: attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) else: outputs["attention_mask_type"] = AttnMaskType.PADDED - attention_mask = invert_mask(attention_mask).unsqueeze(1) + if invert: + attention_mask = invert_mask(attention_mask).unsqueeze(1) outputs["attention_mask"] = attention_mask return outputs @@ -191,6 +214,7 @@ class ColoAttention: kv_indices: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: Optional[float] = None, + **kwargs, ) -> torch.Tensor: """Flash Attention function. It supports 4 mask type. 1. custom mask: recv attention_mask @@ -199,9 +223,9 @@ class ColoAttention: 4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices Args: - q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D] - k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D] - v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D] + q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D] + v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D] attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None. attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM. cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths @@ -218,7 +242,7 @@ class ColoAttention: scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None. Returns: - torch.Tensor: Output tensor. Shape should be [B, N, Sq, D] + torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D] """ # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan # this case is usaul when padding mask is used and self attention is performed @@ -252,6 +276,7 @@ class ColoAttention: else: # if attention_mask is None, attention_mask_type should be the default value assert attention_mask_type == AttnMaskType.CUSTOM + # kernel dispatch mask_type = attention_mask_type if attention_mask is not None else None attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type) @@ -274,3 +299,858 @@ class ColoAttention: q_indices=q_indices, kv_indices=kv_indices, ) + + +def _load_varlen_helpers(): + """Helper to load functions for padding and unpadding packed sequences. + Use only when flash attn is installed + """ + global _pad_input, _unpad_input + # Flash attn claims this is more efficient than torch's bool indexing due to avoiding + # broadcast + if _pad_input is None or _unpad_input is None: + try: + from flash_attn.bert_padding import index_first_axis, pad_input + + def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor): + return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices) + + _pad_input = pad_input + _unpad_input = unpad_input + except ImportError as e: + raise RuntimeError( + f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'" + ) from e + + +def _load_flash_attn(): + """A light-weight loader to check whether flash-attn is installed. 
+ Can't use ColoAttention._dispatch_kernel because we mutate the backward pass + """ + global _flash_attn_forward, _flash_attn_backward + if _flash_attn_forward is None or _flash_attn_backward is None: + try: + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward + from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward + except ImportError as e: + raise RuntimeError( + f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'" + ) from e + + _load_varlen_helpers() + + +# NOTE: This can cause spawned processes to hang on exit +# with python 3.9 +@torch.compile() +def _rescale_out_lse(out, block_out, lse, block_lse): + """ + Compute the new attention denominator: + exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) + Args: + out: (T, H, D) + block_out: (T, H, D) + lse: (H, T, 1) + block_lse: (H, T, 1) + """ + + # min_scale = torch.min(lse, block_lse) + # max_scale = torch.max(lse, block_lse) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + + # NOTE: directly assigning to .data here is buggy + # probably due to casting dtypes/strides + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + new_block_lse = torch.exp(block_lse - new_lse) + out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out) + lse = new_lse + + # Equivalent to the above + # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + # out = (out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse = (lse - F.logsigmoid(lse - block_lse)) + return out, lse + + +class RingAttention(torch.autograd.Function): + """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` + (https://arxiv.org/abs/2310.01889). + For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main + For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper, + which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370; + implemented in Jax and not optimized). + We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available + NICs on each node, by computing attention within a inner ring first and then sending all KVs to the next + ring at once. + """ + + # Globle cache to avoid recomputation for same-lengthed sequences + CU_SEQLENS: torch.Tensor = None # [B+1] + TOTAL_SEQLEN: int = None + HALF_INDICES: Tuple = None + SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL) + ATTN_DONE: torch.cuda.Event = None + SP_STREAM: torch.cuda.Stream = None + SP_GROUP: dist.ProcessGroup = None + # duplicate process group for concurrent NCCL streams + # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7) + # against this, in practice it seems to work fine. + INNER_RING_GROUP: dist.ProcessGroup = None + INNER_RING_GROUP_COPY: dist.ProcessGroup = None + INTER_RING_GROUP: dist.ProcessGroup = None + INTER_RING_GROUP_COPY: dist.ProcessGroup = None + + @staticmethod + def get_double_ring_groups(sp_group, inner_ring_size=None): + """ + Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size + shouldn't be larger than the number of NICs on each node. 
+ Args: + sp_group (dist.ProcessGroup): Process group for sequence parallelism + inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None. + Returns: + Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + + if inner_ring_size is None: + if torch.cuda.device_count() >= dist.get_world_size(): + # single node, no need to consider NICs + return sp_group, sp_group + if sp_size <= 4: + inner_ring_size = min(2, sp_size) + else: + inner_ring_size = min(4, sp_size) + else: + assert ( + inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 + ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + + if inner_ring_size == sp_size: + return sp_group, sp_group + assert ( + sp_size % inner_ring_size == 0 + ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + + logger = get_dist_logger() + logger.info( + f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", + ranks=[0], + ) + num_rings = sp_size // inner_ring_size + inner_ring_group = None + inter_ring_group = None + + # Create inner ring groups + for i in range(inner_ring_size): + ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inner_ring_group = group + + # Create inter ring groups + for i in range(num_rings): + ranks = list(range(i, sp_size, num_rings)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inter_ring_group = group + + return inner_ring_group, inter_ring_group + + @staticmethod + def attention( + q, # (B, H, Sq, D) + k, + v, + sp_group, + attention_mask_type, + cu_seqlens=None, + max_seqlen=None, + valid_indices=None, + dropout_p=0.0, + softmax_scale=None, + deterministic=False, + return_softmax=False, + inner_ring_size=None, + **kwargs, + ): + """ + Ring Attention forward pass supporting variable-length sequences. When using varlen mode, + each sequence in the batch should have length divisible by sp_size * 2. + + Args: + q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] + v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D] + sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism + sp_tream (torch.cuda.Stream): An different stream for output correction. + cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into q. + Shape should be [B+1]. + max_seqlen (Optional[int], optional): Maximum query sequence length in the batch. + valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info. + Shape should be [t]. + dropout_p (float, optional): Dropout probability. Defaults to 0.0. + softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax. + deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349 + return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp). + inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide. + + Returns: + out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False. 
+ softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp). + Shape should be [total_q_seqlen, nHeads] + """ + # Check input args + _load_flash_attn() + if RingAttention.ATTN_DONE is None: + RingAttention.ATTN_DONE = torch.cuda.Event() + if RingAttention.SP_STREAM is None: + RingAttention.SP_STREAM = torch.cuda.Stream() + + assert ( + q.shape[2] == k.shape[2] + ), "Q, K and V having different sequence lengths (inference or cross-attn)\ + is not supported yet in training." + assert ( + attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES + ), f"Mask type {attention_mask_type} is not supported yet." + + clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) + + if RingAttention.SP_GROUP is not sp_group: + RingAttention.SP_GROUP = sp_group + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) + RingAttention.INNER_RING_GROUP = inner_ring_group + RingAttention.INTER_RING_GROUP = inter_ring_group + else: + inner_ring_group = RingAttention.INNER_RING_GROUP + inter_ring_group = RingAttention.INTER_RING_GROUP + + # (B, H, Sq, D) -> (B, Sq, H, D) + q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)] + pad_output = q.dim() == 4 + + # Get sequence length info for varlen forward + if attention_mask_type == AttnMaskType.CAUSAL: + # All sequences share the same length + b, sq, h, d = q.shape + max_seqlen = sq + # Cache to avoid recreation for a single sequence + if sq * b == RingAttention.TOTAL_SEQLEN: + cu_seqlens = RingAttention.CU_SEQLENS + else: + cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32) + RingAttention.TOTAL_SEQLEN = b * sq + + # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D] + elif attention_mask_type == AttnMaskType.PADDED_CAUSAL: + assert ( + cu_seqlens is not None and max_seqlen is not None and valid_indices is not None + ), "Packed mode requires pre-computed cu_seqlens and max_seq_len." + if pad_output: + b, sq, h, d = q.shape + q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)] + + out, softmax_lse = RingAttention.apply( + q, + k, + v, + sp_group, + RingAttention.SP_STREAM, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + deterministic, + return_softmax, + attention_mask_type == AttnMaskType.PADDED_CAUSAL, + inner_ring_group, + inter_ring_group, + ) + + if attention_mask_type == AttnMaskType.PADDED_CAUSAL: + if pad_output: + out = _pad_input(out, valid_indices, b, sq) # (T, ...) -> (B, Sq, ...) 
+ out = out.transpose(1, 2) # (B, Sq, H, D) -> (B, H, Sq, D) + else: + out = out.transpose(1, 2) + + if return_softmax: + return out, softmax_lse + return out + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sp_group: dist.ProcessGroup, + sp_stream: torch.cuda.Stream, + cu_seqlens: torch.Tensor, + max_seqlen: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + deterministic: Optional[bool] = False, + return_softmax: Optional[bool] = False, + is_packed: Optional[bool] = False, + inner_ring_group: Optional[dist.ProcessGroup] = None, + inter_ring_group: Optional[dist.ProcessGroup] = None, + ): + + cu_seqlens_q = cu_seqlens_kv = cu_seqlens + max_seqlen_q = max_seqlen_kv = max_seqlen + cu_seqlens_half = cu_seqlens // 2 + max_seqlen_half = max_seqlen // 2 + + misc_kwargs = { + "window_size": (-1, -1), + "alibi_slopes": None, + "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, + "dropout_p": dropout_p, + "block_table": None, + "softcap": 0.0, + "return_softmax": False, + } + + if ( + RingAttention.HALF_INDICES is not None + and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape + and (cu_seqlens == RingAttention.CU_SEQLENS).all() + ): + half_idx_front, half_idx_back = RingAttention.HALF_INDICES + else: + half_idx_front = get_half_index(cu_seqlens, front=True) + half_idx_back = get_half_index(cu_seqlens, front=False) + RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) + RingAttention.CU_SEQLENS = cu_seqlens + + if is_packed: + t, h, d = q.shape + else: + b, sq, h, d = q.shape + t = b * sq + # Be careful about GQA/MQA in reshape + q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)] + + if inner_ring_group is None or inter_ring_group is None: + # Use one ring if not specified + inner_ring_group = inter_ring_group = sp_group + + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + # Attempt to achieve concurrent comm in the two-stream forward + local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)] + inter_ring_comm = RingComm(inter_ring_group) + local_sp_size = dist.get_world_size(inner_ring_group) + local_sp_rank = dist.get_rank(inner_ring_group) + inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0 + num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1 + + # Non-contiguous indexing copies to a new contiguous tensor, + # so only do it once + if sp_rank != sp_size - 1: + q1 = q[half_idx_back] + + # Pre-allocate double buffer for overlapping and receiving next step's inputs + kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) + kv_buffers.append(torch.empty_like(kv_buffers[0])) + + # outputs + out = None + block_out = [None, None] + softmax_lse = [None, None] + block_softmax_lse = [None, None] # log sum exp, the denominator of softmax in attention + rng_states = [None for _ in range(sp_size)] + sp_streams = [torch.cuda.current_stream(), sp_stream] + + def _forward(q, k, v, causal): + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) + return out, softmax_lse, rng_state + + def _local_ring_forward(): + # (Hopefully) overlap output correction with next flash 
attn + for i in range(local_sp_size): + with torch.cuda.stream(sp_streams[i % 2]): + # Wait for current kv from prev rank + # NOTE: waiting outside the current stream will NOT correctly synchronize. + if i > 0: + local_kv_comms[(i + 1) % 2].wait() + + # Avoid overwriting attn input when it shares mem with buffer + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + if i == 0: + # Compute with local KV; no mask + kv_block = kv_buffers[0] + q_block = q + (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T) + q_block, kv_block[0], kv_block[1], causal=True + ) + elif i <= local_sp_rank: + # Received the "surrounding" kv chunks + # Drop the second half of received kv + # (2, t // 2, H, D) + kv_block = kv_buffers[i % 2][:, half_idx_front] + q_block = q + ( + block_out[i % 2], # (T, H, D) + block_softmax_lse[i % 2], # (H, T) + rng_states[i], + ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) + else: + # Received the inner kv chunks + # Drop the first half of q + kv_block = kv_buffers[i % 2] + q_block = q1 + ( + block_out[i % 2], # (T, H, D) + block_softmax_lse[i % 2], # (H, T) + rng_states[i], + ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) # (H, T) -> (T, H, 1) + assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] + # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel. + # In reality this always finishes before next flash attn; no need for extra sync. 
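+                    # The running (out, softmax_lse) pair is folded with the new block using the
+                    # online-softmax identity new_lse = lse + log(1 + exp(block_lse - lse)); both
+                    # partial outputs are then re-weighted by exp(lse - new_lse) and
+                    # exp(block_lse - new_lse) inside _rescale_out_lse. For steps i > local_sp_rank,
+                    # only the second half of q attended to the received KV, so just the
+                    # half_idx_back slices of out and softmax_lse are updated.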
+ if i == 0: + out = block_out[0] + softmax_lse = block_softmax_lse[0] + elif i <= local_sp_rank: + out, softmax_lse = _rescale_out_lse( + out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2] + ) + else: + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) + + torch.cuda.current_stream().wait_stream(sp_stream) + return out, softmax_lse + + def _other_ring_forward(ring_num_idx, out, softmax_lse): + # Loop through the inner ring after receiving + # all new KVs from the previous inner ring + for i in range(local_sp_size): + with torch.cuda.stream(sp_streams[i % 2]): + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + # Send & recv KV + if i > 0: + local_kv_comms[(i + 1) % 2].wait() + + if ring_num_idx > inter_ring_rank: + kv_block = kv_buffers[i % 2] + ( + block_out[i % 2], + block_softmax_lse[i % 2], + rng_states[i + local_sp_size * ring_num_idx], + ) = _forward(q1, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) + else: + kv_block = kv_buffers[i % 2][:, half_idx_front] + ( + block_out[i % 2], + block_softmax_lse[i % 2], + rng_states[i + local_sp_size * ring_num_idx], + ) = _forward(q, kv_block[0], kv_block[1], causal=False) + RingAttention.ATTN_DONE.record() + block_softmax_lse[i % 2] = ( + block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() + ) + out, softmax_lse = _rescale_out_lse( + out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2] + ) + + torch.cuda.current_stream().wait_stream(sp_stream) + return out, softmax_lse + + # Send and recv KV between rings at once to maximize NIC util. 
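+        # Each outer iteration kicks off one bulk async send/recv of the current KV set to the next
+        # inner ring over the inter-ring group, then overlaps it with a full pass over the local ring
+        # (_local_ring_forward for the KV originating from this rank's own ring, _other_ring_forward
+        # for KV received from other rings). With num_rings == 1 this reduces to the plain
+        # single-ring zigzag schedule.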
+ inter_ring_kv = None + for ring_num_idx in range(num_rings): + if ring_num_idx > 0: + inter_ring_comm.wait() + # Reset indices + kv_buffers[0] = inter_ring_kv + + if ring_num_idx < num_rings - 1: + if ring_num_idx == 0: + to_send = kv_buffers[0] + else: + # The last received KV + to_send = kv_buffers[(local_sp_size - 1) % 2] + inter_ring_kv = inter_ring_comm.send_recv(to_send) + + if ring_num_idx == 0: + out, softmax_lse = _local_ring_forward() + else: + out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse) + + out = out.to(q.dtype) + if not is_packed: + out = out.view(b, sq, h, d) + q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D) + softmax_lse = softmax_lse.squeeze(-1) + + ctx.sp_group = sp_group + ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen + misc_kwargs["deterministic"] = deterministic + del misc_kwargs["return_softmax"] + ctx.misc_kwargs = misc_kwargs + ctx.is_packed = is_packed + + ctx.kv_group = inner_ring_group + ctx.inter_kv_group = inter_ring_group + + ctx.save_for_backward( + q, + k, + v, + out, + softmax_lse.transpose(0, 1).contiguous(), # (T, H) -> (H, T) + cu_seqlens_q, + cu_seqlens_kv, + half_idx_front, + half_idx_back, + *rng_states, + ) + + if return_softmax: + return out, softmax_lse + return out, None + + def backward(ctx, dout, _): + """ + During backward, we accumulate q grads on each rank locally, but iterate kv and their grads + over all ranks for accumulation. + """ + (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9] + rng_states = ctx.saved_tensors[9:] + + is_packed = ctx.is_packed + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_kv = ctx.max_seqlen_kv + cu_seqlens_half = cu_seqlens_q // 2 + max_seqlen_half = max_seqlen_q // 2 + misc_kwargs = ctx.misc_kwargs + del misc_kwargs["block_table"] + + assert ( + out.shape == dout.shape == q.shape + ), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})." + + if is_packed: + t, h, d = q.shape + else: + b, sq, h, d = q.shape + t = b * sq + q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)] + + # Sequence parallel args + sp_group = ctx.sp_group + local_kv_group = ctx.kv_group + inter_kv_group = ctx.inter_kv_group + + local_sp_rank = dist.get_rank(sp_group) + sp_size = dist.get_world_size(sp_group) + # Using separate streams (pg) for concurrent kv and dkv comm may + # cause NCCL "software caused connection abort" here... 
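+        # So the four communicators below (KV / dKV, intra-ring / inter-ring) all share the default
+        # stream: KV blocks circulate around the rings exactly as in the forward pass, while each
+        # block's dKV accumulator is forwarded along the same ring, so its gradient contributions
+        # from every rank have been summed by the time dk/dv arrive back at their owner.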
+ local_kv_comm = RingComm(local_kv_group) + local_dkv_comm = RingComm(local_kv_group) + inter_kv_comm = RingComm(inter_kv_group) + inter_dkv_comm = RingComm(inter_kv_group) + local_sp_size = dist.get_world_size(local_kv_group) + local_sp_rank = dist.get_rank(local_kv_group) + + if dist.get_world_size(inter_kv_group) != sp_size: + num_rings = dist.get_world_size(inter_kv_group) + inter_ring_rank = dist.get_rank(inter_kv_group) + else: + num_rings = 1 + inter_ring_rank = 0 + + if local_sp_rank != sp_size - 1: + softmax_lse1 = softmax_lse[:, half_idx_back] + dout = dout.contiguous() + + # Double comm buffers for sending and receiving kv + kv_buffers = [torch.stack((k, v))] # (2, T, H, D) + kv_buffers.append(torch.empty_like(kv_buffers[0])) + + dq = None # (T, H, D) + # Intermediate outputs + dq_block = torch.empty_like(q) # (T, H, D) + dk_block = torch.empty_like(k) # (T, H, D) + dv_block = torch.empty_like(v) # (T, H, D) + dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D) + del k, v + + def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half, + max_seqlen_q if dq.shape[0] == t else max_seqlen_half, + max_seqlen_kv if dk.shape[0] == t else max_seqlen_half, + causal=causal, + rng_state=rng_state, + **misc_kwargs, + ) + + # NOTE: We avoid using two streams due to doubled buffers + # and that backward is more communication intensive. + def _local_ring_backward(): + for i in range(local_sp_size): + if i > 0: + local_kv_comm.wait() + + if i < local_sp_size - 1: + # Send kv to next rank for backward + local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + if i == 0: + # Backward with local kv + k_, v_ = kv_buffers[i % 2] + q_, dout_, out_ = q, dout, out + dq_, dk_, dv_ = dq_block, dk_block, dv_block + _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True) + + elif i <= local_sp_rank: + # Drop the second half of kv + # (T, H, D) -> (T // 2, H, D) + k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]] + dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)] + dq_, q_, out_, dout_ = (dq_block, q, out, dout) + _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False) + + else: + # Drop the first half of q + k_, v_ = kv_buffers[i % 2] + dk_, dv_ = dk_block, dv_block + q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)] + dq_ = dq_block[: t // 2] + _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False) + + # Accumulate grads + if i == 0: + dq = dq_block.float() + dkv_buffers[i % 2][0] = dk_block.float() + dkv_buffers[i % 2][1] = dv_block.float() + else: + # Accumulate local dq + if i <= local_sp_rank: + dq += dq_ # (T, H, D) + else: + dq[half_idx_back] += dq_ + + # Wait for mobile kv grad accumulators + local_dkv_comm.wait() + + if i <= local_sp_rank: + # q blocks "surrounded" by kv blocks + dkv_buffers[i % 2][0][half_idx_front] += dk_ + dkv_buffers[i % 2][1][half_idx_front] += dv_ + else: + # q blocks "surrounding" kv blocks + dkv_buffers[i % 2][0] += dk_ + dkv_buffers[i % 2][1] += dv_ + local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2]) + + local_dkv_comm.wait() + dkv_recv = dkv_buffers[local_sp_size % 2] + dkv_send = dkv_buffers[(local_sp_size - 1) % 2] + return 
dq, dkv_recv, dkv_send + + def _other_ring_backward(ring_num_idx, dq): + if ring_num_idx > inter_ring_rank: + # Indexing is expensive + q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)] + else: + q_, out_, dout_ = (q, out, dout) + + for i in range(local_sp_size): + if i > 0: + local_kv_comm.wait() + + if i < local_sp_size - 1: + local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + + rng_state = rng_states[i + local_sp_size * ring_num_idx] + if ring_num_idx > inter_ring_rank: + k_, v_ = kv_buffers[i % 2] + dk_, dv_ = dk_block, dv_block + dq_ = dq_block[: t // 2] + _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False) + + dq[half_idx_back] += dq_ + if i > 0: + local_dkv_comm.wait() + else: + inter_dkv_comm.wait() + + dkv_buffers[i % 2][0] += dk_ + dkv_buffers[i % 2][1] += dv_ + else: + k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]] + dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)] + dq_ = dq_block + _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False) + + dq += dq_ + if i > 0: + local_dkv_comm.wait() + else: + inter_dkv_comm.wait() + + dkv_buffers[i % 2][0][half_idx_front] += dk_ + dkv_buffers[i % 2][1][half_idx_front] += dv_ + + local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2]) + + local_dkv_comm.wait() + dkv_recv = dkv_buffers[local_sp_size % 2] + dkv_send = dkv_buffers[(local_sp_size - 1) % 2] + return dq, dkv_recv, dkv_send + + inter_ring_kv = None + for ring_num_idx in range(num_rings): + if ring_num_idx > 0: + inter_kv_comm.wait() + kv_buffers[0] = inter_ring_kv + + if ring_num_idx < num_rings - 1: + # Re-allocate a buffer in each inter-ring step + inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0]) + + if ring_num_idx == 0: + dq, dkv_recv, dkv_send = _local_ring_backward() + else: + dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq) + + if num_rings > 1: + # Reuse the local buffers + inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send) + # Reset indices + dkv_buffers[0] = dkv_send + dkv_buffers[1] = dkv_recv + if ring_num_idx == num_rings - 1: + inter_dkv_comm.wait() + dkv_recv = dkv_buffers[0] + + dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)] + if not is_packed: + dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] + + return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None) + + @staticmethod + def prepare_varlen_batch( + attention_mask: torch.Tensor, + sp_group: dist.ProcessGroup, + inputs_embeds: torch.Tensor = None, + position_ids: Optional[torch.Tensor] = None, + is_label: bool = False, + is_2d: bool = True, + ): + """ + Preprocess a batch of padded sequence by splitting input sequence by sp_size + sequence-wise and packing them into one sequence. Updates the mask info accordingly. + Args: + attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked. + sp_group (dist.ProcessGroup): Process group for sequence parallelism + inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...] + position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None. + is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first + token of each sequence. + is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten + the batch dim to a packed 1d sequence. 
Contingent on model forward shape definitions. + + Returns: + inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...]. + mask_info: A dictionary of mask info. + position_ids: Packed position ids of shape [..., Sq // sp_size]. + + """ + _load_varlen_helpers() + sp_size = dist.get_world_size(group=sp_group) + sp_rank = dist.get_rank(group=sp_group) + mask_info = {} + mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False) + + # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size) + # Split mask to compute local nonzero position indices + # (B, Sq) -> (B, max_seqlen // sp_size) + attention_mask = attention_mask[:, : mask_info["max_seqlen"]] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]] + inputs_embeds = split_varlen_zigzag( + inputs_embeds, + mask_info["cu_seqlens"], + sp_group, + mask_info["max_seqlen"], + is_2d=is_2d, + is_label=is_label, + ) + attention_mask = split_varlen_zigzag( + attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d + ) + + if position_ids is not None: + indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device) + position_ids = ( + position_ids[..., : mask_info["max_seqlen"]] # unpad + .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2)) + .index_select(-2, indices) + .view(-1, mask_info["max_seqlen"] // sp_size) + ) + + mask_info["max_seqlen"] //= sp_size + mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + mask_info["cu_seqlens"] //= sp_size + mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL + return inputs_embeds, mask_info, position_ids diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 37c754241..020e793af 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -200,9 +200,7 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( input_parallel, self.process_group, self.seq_parallel_dim ) @@ -211,6 +209,8 @@ class Linear1D_Col(ParallelModule): output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. 
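Before the Linear1D_Row hunk below, a brief aside on how the RingAttention pieces above (prepare_varlen_batch and RingAttention.attention) are meant to fit together. The sketch is illustrative rather than taken from the patch: the helper name ring_attn_forward is hypothetical, and in a real model policy prepare_varlen_batch would be called once on the embeddings and position ids before the decoder layers, not per attention call.

# Hedged sketch of the expected call pattern for sequence_parallelism_mode == "ring_attn".
from colossalai.shardformer.layer import AttnMaskType, RingAttention


def ring_attn_forward(q, k, v, attention_mask, sp_group):
    # q, k, v: (B, nHeads, Sq // sp_size, D); the sequence dim is already split zigzag-style.
    if attention_mask is None or bool(attention_mask.bool().all()):
        # Uniform-length batch: plain causal mode, cu_seqlens are derived internally.
        return RingAttention.attention(q, k, v, sp_group, AttnMaskType.CAUSAL)

    # Padded batch: cu_seqlens / max_seqlen / valid_indices come from prepare_varlen_batch,
    # which also splits and re-pads the inputs.
    _, mask_info, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
    return RingAttention.attention(
        q,
        k,
        v,
        sp_group,
        mask_info["attention_mask_type"],
        cu_seqlens=mask_info["cu_seqlens"],
        max_seqlen=mask_info["max_seqlen"],
        valid_indices=mask_info["valid_indices"],
    )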
@@ -416,10 +416,7 @@ class Linear1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim @@ -432,6 +429,9 @@ class Linear1D_Row(ParallelModule): dim=self.seq_parallel_dim, ring=True, ) + else: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index cea2da03f..12df824d1 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -4,10 +4,15 @@ from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss +from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig +from .utils import is_share_sp_tp + __all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +_IGNORE_IDX = -100 + class DistCrossEntropy(Function): r""" @@ -26,11 +31,12 @@ class DistCrossEntropy(Function): process_group: ProcessGroup, vocab_size: int, dtype=torch.float32, + mode="mean", ): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) - and can be rewrite as: + and can be rewriten as: loss = log(sum(exp(x[i])) - x[class] To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] @@ -44,12 +50,10 @@ class DistCrossEntropy(Function): Returns: :class:`torch.Tensor`: The cross entropy loss """ + assert mode in ["mean", "sum"] # get the max logits_max = torch.max(vocab_logits, dim=-1)[0] - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) - - # minus the max to avoid the result of sum of exp is too large and the log is nan - vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) # mask the target in the local device rank = dist.get_rank(group=process_group) @@ -70,24 +74,25 @@ class DistCrossEntropy(Function): mask = (target < down_threshold) | (target >= up_threshold) masked_target = target.clone() - down_threshold masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + # minus the max to avoid the result of sum of exp is too large and the log is nan + handle.wait() + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # reshape the logits and target # reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the labels to [bath_size * seq_len] self_vocab_size = vocab_logits.size()[-1] logits_2d = vocab_logits.view(-1, self_vocab_size) - masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero - pred_logits_1d = logits_2d[ - torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d - ] - pred_logits_1d = pred_logits_1d.clone().contiguous() + idx = 
torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device) + pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous() pred_logits = pred_logits_1d.view_as(target) pred_logits[mask] = 0.0 - # allreduce the get all x(i,y) - dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) + # all-reduce to get full x[i, y] + handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True) exp_logits = vocab_logits torch.exp(vocab_logits, out=exp_logits) sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) @@ -95,23 +100,29 @@ class DistCrossEntropy(Function): # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] + handle.wait() loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) - num_non_zero = torch.sum(loss != 0.0) - ctx.inv_num_non_zero = 1.0 / num_non_zero - loss = torch.sum(loss).div_(num_non_zero) + if mode == "mean": + num_non_zero = torch.sum(loss != 0.0) + ctx.inv_num_non_zero = 1.0 / num_non_zero + loss = torch.sum(loss).div_(num_non_zero) + else: + loss = torch.sum(loss) # calculate the softmax exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype) exp_logits[target == ignore_index] = 0.0 ctx.save_for_backward(exp_logits, mask, masked_target_1d) ctx.dtype = dtype + ctx.mode = mode return loss @staticmethod def backward(ctx, grad_output): # retrieve the saved tensors - grad_output = grad_output * ctx.inv_num_non_zero + if ctx.mode == "mean": + grad_output = grad_output * ctx.inv_num_non_zero exp_logits, mask, masked_target_1d = ctx.saved_tensors # use exp logits as the input grad @@ -123,55 +134,113 @@ class DistCrossEntropy(Function): grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None, None, None, None + return grad_logits, None, None, None, None, None, None def cross_entropy_1d( vocab_logits: torch.Tensor, labels: torch.Tensor, - ignore_index: int = -100, + ignore_index: int = _IGNORE_IDX, process_group: ProcessGroup = None, vocab_size: int = None, dtype: torch.dtype = None, + mode: str = "mean", ) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) def dist_cross_entropy( - labels: torch.Tensor, - logits: torch.Tensor, + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, out_features: int, vocab_size: int, dtype: torch.dtype, + seq_dim: int = 1, ) -> torch.Tensor: """ - Helper to compute cross entropy loss for most shardformer models, - compatible with PP, TP and SP. + Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. 
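+    Handles the label shifting and splitting required by the sequence-parallel modes (for
+    "ring_attn", labels are expected to have been split and shifted by
+    RingAttention.prepare_varlen_batch) and, when labels are split along the sequence dimension,
+    sum-reduces the loss over the SP group before averaging it over non-ignored tokens.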
""" - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - # Cross entropy with all-reduce for TP - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=out_features, - dtype=dtype, - ) - else: - # NOTE if use TP and not parallel_output, the output is gathered. - # see VocabParallelLMHead1D - shift_logits = shift_logits.view(-1, vocab_size) - loss = loss_fct(shift_logits, shift_labels) + # Split labels if not gather output + sp_group = shard_config.sequence_parallel_process_group + sp_rank = dist.get_rank(sp_group) + sp_size = shard_config.sequence_parallel_size + sp_mode = shard_config.sequence_parallelism_mode + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + is_packed = labels.dim() == 2 + if is_packed: + bs, seq_len = labels.shape + else: + # padded sequence + seq_len = labels.shape[-1] + logits = logits.reshape(-1, *logits.shape[2:]) + seq_dim = 0 - return loss + # Shift labels to predict the next token, and remove the tail logit predicting + is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) + split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward + + if sp_mode == "ring_attn": + # For Zigzag Ring Attention, labels should've been split and + # shifted by RingAttention.prepare_varlen_batch() + if sp_rank == 0: + logits = logits[..., :-1, :] + logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim) + elif is_sp: + # Shift only once: either before splitting or in the last rank without splitting + if split_labels_here or (sp_rank == sp_size - 1): + labels = labels[..., 1:] + if split_labels_here: + labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] + + if sp_rank == sp_size - 1: + logits = logits[..., :-1, :] + # Pad logits and labels to the same shape across all ranks for TP all_reduce + if is_tp and parallel_output: + # If is packed sequence (label dim is 1), then each seq already has the end label token padded. + # torch.cat is faster than F.pad... 
+ pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:]) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device) + logits = torch.cat([logits, padding], dim=seq_dim) + pad_shape = (labels.shape[0], 1) if is_packed else (1,) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device) + labels = torch.cat([labels, padding], dim=seq_dim) + else: + labels = labels[..., 1:] + logits = logits[..., :-1, :] + labels = labels.contiguous() + logits = logits.contiguous() + num_nonzero = (labels != _IGNORE_IDX).sum() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") + labels = labels.view(-1) + + if is_tp and parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + logits = logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=out_features, + dtype=dtype, + mode="sum", + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D + logits = logits.view(-1, vocab_size) + loss = loss_fct(logits, labels) + + # Reduce loss instead of gathering logits over seq dim for savings + if split_labels_here or sp_mode == "ring_attn": + # Get the global non-zero count + loss = torch.stack((loss, num_nonzero)) + # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin + loss = reduce_forward(loss, sp_group, grad_scale=sp_size) + loss, num_nonzero = loss[0], loss[1].detach() + loss = (loss / num_nonzero).squeeze() + return loss diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 9c6ced445..c1a73ce05 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import List +from typing import List, Optional, Union import torch import torch.distributed as dist @@ -289,3 +289,199 @@ def create_randomizer_with_offset( Randomizer.increment_index() return Randomizer(seed=base_seed) + + +def split_batch_zigzag( + batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False +) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask + in the causal setting will result in the preceding ranks having much less workload. + We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2). + For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. + + Args: + batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. + sp_group (ProcessGroup): The process group for sequence parallelism. + seq_dim (int): The sequence dimension to split. + is_label (bool): If True, mask and shift the tensor for next token prediction. 
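For intuition, the zigzag layout described above can be reproduced without any process group. A small sketch for sp_size = 4 and seq_len = 8, printing each rank's local chunk indices:

    import torch

    sp_size, seq_len = 4, 8
    x = torch.arange(seq_len)                                   # s0 .. s7
    chunks = x.view(2 * sp_size, seq_len // (2 * sp_size))      # fold into 2 * sp_size pieces
    for sp_rank in range(sp_size):
        idx = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank])
        print(sp_rank, chunks.index_select(0, idx).flatten().tolist())
    # 0 [0, 7]
    # 1 [1, 6]
    # 2 [2, 5]
    # 3 [3, 4]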
+ + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + if isinstance(batch, torch.Tensor): + batch = [batch] + seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 + + if sp_size > 1: + for idx, tensor in enumerate(batch): + assert ( + tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0 + ), f"Bro, the seq length {tensor.shape[seq_dim]} for tensor {idx} can't be split by {sp_size * 2}!" + if is_label: + assert tensor.dim() == 2, "Label shape should be (B, Seqlen)" + tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1) + + tensor = tensor.view( + *tensor.shape[:seq_dim], + 2 * sp_size, + tensor.shape[seq_dim] // (2 * sp_size), + *tensor.shape[seq_dim + 1 :], + ) + indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) + tensor = tensor.index_select(seq_dim, indices).contiguous() + # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) + batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) + + if len(batch) == 1: + return batch[0] + return batch + + +def split_varlen_zigzag( + batch: Union[List[torch.Tensor], torch.Tensor], + cu_seqlens: torch.Tensor, + sp_group: ProcessGroup, + max_seqlen: int = 0, + is_2d: bool = False, + is_label: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + """Split each sequence in a batch of packed sequences in a zigzag fashion. + For each tensor in batch, return packed sequences if is_2d is False; + else return a padded batch of sequences. + + Args: + batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d. + cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting. + sp_group (ProcessGroup): The process group for sequence parallelism. + max_seqlen (int): The maximum sequence length in the batch before splitting. + is_2d (bool): If True, then input has batch size and sequence length split into two dimensions. + is_label (bool): If True, mask out the first token in each sequence (). + + Returns: + batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size) + or (B, max_seqlen // sp_size, ...) 
if is_2d + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + if is_2d: + assert max_seqlen > 0, "max_seqlen must be provided for 2D input" + + if isinstance(batch, torch.Tensor): + batch = [batch] + for i, packed_seq in enumerate(batch): + device = packed_seq.device + dtype = packed_seq.dtype + + if is_2d: + assert max_seqlen % (sp_size * 2) == 0 + # Recreate a padded tensor with the new max seqlen + shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) + local_seq = torch.zeros(shape, dtype=dtype, device=device) + else: + total_seqlen = cu_seqlens[-1] + assert ( + total_seqlen % (2 * sp_size) == 0 + ), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}" + local_seq = [] + + for j in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[j], cu_seqlens[j + 1] + seqlen = end - start + assert ( + seqlen % (2 * sp_size) == 0 + ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" + + if is_2d: + seq = packed_seq[j][:seqlen] + if is_label: + # Shift one position to the right for next token prediction + seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)]) + + seq = seq.chunk(2 * sp_size, dim=0) + half = seqlen // sp_size // 2 + local_seq[j][:half] = seq[sp_rank] + local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank] + else: + seq = packed_seq[start:end] + if is_label: + seq = torch.cat(seq[1:], torch.tensor([-100], dtype=dtype, device=device)) + seq = seq.chunk(sp_size * 2) + local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]]) + + if is_2d: + batch[i] = local_seq.contiguous() + else: + batch[i] = torch.cat(local_seq, dim=0) + + if len(batch) == 1: + batch = batch[0] + return batch + + +def is_share_sp_tp(sp_mode: str): + """sp_mode "ring" and "split_gather" use the TP group as SP group + to split both the vocab and sequence, so we must gather the sequence + to correctly get logits at each positions. 
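The packed (varlen) variant above performs the same zigzag selection per sequence, using cu_seqlens as the boundaries. A hedged single-process sketch (no SP group; sp_size = 2, rank 0, two packed sequences of length 8):

    import torch

    sp_size, sp_rank = 2, 0
    cu_seqlens = torch.tensor([0, 8, 16])        # two sequences packed back to back
    packed = torch.arange(16)
    local = []
    for j in range(len(cu_seqlens) - 1):
        seq = packed[cu_seqlens[j]:cu_seqlens[j + 1]].chunk(2 * sp_size)
        local.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
    print(torch.cat(local).tolist())             # [0, 1, 6, 7, 8, 9, 14, 15]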
+ """ + return sp_mode in ["ring", "split_gather"] + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = [] + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, + send_tensor: torch.Tensor, + recv_tensor: Optional[torch.Tensor] = None, + commit: bool = True, + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(send_tensor) + else: + res = recv_tensor + + # looks like batch_isend_irecv doesn't deadlock even + # when we don't swap send recv ops based on rank + send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.extend([send_op, recv_op]) + + if commit: + self._reqs = dist.batch_isend_irecv(self._ops) + return res + + def commit(self): + assert len(self._ops) > 0, "No ops to commit" + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + assert len(self._reqs) > 0, "No requests to wait for" + for req in self._reqs: + req.wait() + self._reqs = [] + self._ops = [] + + +@torch.jit.script +def get_half_index(cu_seqlens, *, front: bool): + index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device) + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + if front: + end = (start + end) // 2 + else: + start = (start + end) // 2 + index[start:end] = True + return index diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 5b36fc7db..67c20eed8 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -26,6 +26,8 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] + class CommandPipelineForwards: """ @@ -349,7 +351,7 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} -def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, @@ -362,7 +364,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if sp_mode is not None: - assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" @@ -459,7 +461,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None return forward -def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( diff --git 
a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9ffbca517..af610500a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,8 +1,9 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -24,14 +25,14 @@ from transformers.models.llama.modeling_llama import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer import AttnMaskType +from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward +from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, dist_cross_entropy +from ..layer import ColoAttention, RingAttention, dist_cross_entropy + +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] class LlamaPipelineForwards: @@ -57,6 +58,10 @@ class LlamaPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + # Split output only when computing cross entropy using llama_for_causal_lm_forward + # or get_lm_forward_with_dist_cross_entropy + # Default to True to avoid bug when calling classification forward from huggingface + force_sp_output_gather: bool = True, ): logger = logging.get_logger(__name__) @@ -97,7 +102,7 @@ class LlamaPipelineForwards: sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): - # For correct positions ids. The states will be gather along the seq dim in the attention layer later. + # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later. 
seq_length *= sp_size past_seen_tokens = 0 @@ -127,22 +132,36 @@ class LlamaPipelineForwards: position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if shard_config.enable_flash_attention: + if not stage_manager.is_first_stage() and sp_mode == "ring_attn": + _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) + elif shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( + attn_kwargs = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True, + invert=(sp_mode != "ring_attn"), ) else: - attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) # Support SP + PP + # TODO: support padded casual cu_seqlens across stages if stage_manager.is_first_stage(): - if sp_mode in ["ring", "split_gather"]: + # Ring Attention zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, hidden_states, position_ids + ) + else: + hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) + + elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all": hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) @@ -177,12 +196,11 @@ class LlamaPipelineForwards: for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) - if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + attn_kwargs, position_ids, past_key_values, output_attentions, @@ -192,14 +210,13 @@ class LlamaPipelineForwards: else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=attn_kwargs, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) - hidden_states = layer_outputs[0] if use_cache: @@ -209,10 +226,8 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -298,6 +313,15 @@ class LlamaPipelineForwards: logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False + if 
shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: + # Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1) + else: + # [B, max_seqlen // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( self.model, @@ -315,6 +339,7 @@ class LlamaPipelineForwards: hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) past_key_values = None @@ -457,11 +482,11 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Union[torch.Tensor, Dict]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, @@ -470,7 +495,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if sp_mode is not None: - assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" @@ -481,7 +506,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is ring - if sp_mode in ["split_gather", "ring"]: + if is_share_sp_tp(sp_mode): q_len *= sp_size if self.config.pretraining_tp > 1: @@ -526,6 +551,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -537,12 +563,21 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if shard_config.enable_flash_attention: + if sp_mode == "ring_attn": + attn_output = RingAttention.attention( + query_states, + key_states, + value_states, + sp_group, + **attention_mask, + inner_ring_size=shard_config.inner_ring_size, + ) + + elif shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." 
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" @@ -588,7 +623,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, return forward -def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( @@ -603,6 +638,10 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + # Split output only when computing cross entropy using llama_for_causal_lm_forward + # or get_lm_forward_with_dist_cross_entropy + # Default to True to avoid bug when calling classification forward from huggingface + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -629,32 +668,45 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= past_seen_tokens = 0 seq_len = inputs_embeds.shape[1] + batch_size = inputs_embeds.shape[0] if use_cache: # kept for BC (cache positions) if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - # in this case, attention_mask is a dict rather than a tensor if shard_config.enable_flash_attention: - mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len) - attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len) + attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( mask_shape, inputs_embeds.dtype, inputs_embeds.device, q_padding_mask=attention_mask, is_causal=True, + invert=(sp_mode != "ring_attn"), ) - else: - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - if sp_mode in ["ring", "split_gather"]: + else: + attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + # Ring Attention zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." 
+ if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, inputs_embeds, position_ids + ) + else: + inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) + attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors + + elif is_share_sp_tp(sp_mode): inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) @@ -672,7 +724,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + attn_kwargs, position_ids, past_key_values, output_attentions, @@ -683,7 +735,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=attn_kwargs, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -700,11 +752,9 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) - - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + # Cases that don't support parallelizing cross entropy computation along sequence + if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -777,6 +827,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: + # Special processing: Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) + else: + # [B, max_seq_len // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -789,6 +848,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + force_sp_output_gather=False, ) hidden_states = outputs[0] @@ -799,7 +859,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): else: logits = self.lm_head(hidden_states) logits = logits.float() - loss = dist_cross_entropy( labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype ) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 282cf0464..7c1e6f0d7 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -75,6 +75,7 @@ 
class Policy(ABC): def __init__(self) -> None: self.shard_config: Optional[ShardConfig] = None self.model: Optional[Module] = None + self.is_causal = None # Whether we're doing causal lm, i.e. using cross entropy def set_model(self, model: nn.Module) -> None: r""" diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index a9b915d10..1efd3d017 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -69,13 +69,18 @@ class CommandPolicy(Policy): sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "ring_attn" and not self.is_causal: + raise ValueError("Ring attention is only meant for causal language modeling.") + tp_size = self.shard_config.tensor_parallel_size or None + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) if sp_mode == "all_to_all": - decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, - } - if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + num_q_heads //= sp_size + decoder_attribute_replacement = {"num_heads": num_q_heads} + if num_kv_heads: + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -104,21 +109,18 @@ class CommandPolicy(Policy): if self.shard_config.enable_tensor_parallelism: assert ( - self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + num_q_heads % tp_size == 0 ), f"The number of attention heads must be divisible by tensor parallel size." if hasattr(self.model.config, "num_key_value_heads"): assert ( - self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size - and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + num_kv_heads >= tp_size and num_kv_heads % tp_size == 0 ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." 
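The head bookkeeping above composes: with all_to_all sequence parallelism the query and KV heads are first divided across the SP group, and the tensor-parallel attribute replacement below divides them again across the TP group. A hypothetical example with 32 query heads and 8 KV heads:

    num_q_heads, num_kv_heads = 32, 8     # hypothetical model config
    sp_size, tp_size = 2, 4

    # all_to_all SP branch
    num_q_heads //= sp_size               # 16 query heads per SP rank
    num_kv_heads //= sp_size              # 4 KV heads per SP rank

    # TP divisibility checks and per-shard attributes
    assert num_q_heads % tp_size == 0
    assert num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
    print(num_q_heads // tp_size, num_kv_heads // tp_size)   # 4 query heads, 1 KV head per TP shard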
decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads // tp_size, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( - self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size - ) + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size policy[CohereDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -290,10 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): from transformers import CohereForCausalLM + self.is_causal = True policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { CohereForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 605f69c4a..ea68649d5 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -298,7 +298,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { "DeepseekForCausalLM": ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 36491b4b5..f72a72df0 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -69,13 +69,20 @@ class LlamaPolicy(Policy): sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "ring_attn" and not self.is_causal: + raise ValueError("Ring attention is only meant for causal language modeling.") + + tp_size = self.shard_config.tensor_parallel_size + # Modified by SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) if sp_mode == "all_to_all": - decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, - } - if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + num_q_heads //= sp_size + decoder_attribute_replacement = {"num_heads": num_q_heads} + if num_kv_heads: + num_kv_heads //= sp_size + decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -104,21 +111,20 @@ class LlamaPolicy(Policy): if self.shard_config.enable_tensor_parallelism: assert ( - self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + num_q_heads % tp_size == 0 ), f"The number of attention heads must be divisible by tensor parallel size." 
if hasattr(self.model.config, "num_key_value_heads"): assert ( - self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size - and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + num_kv_heads >= tp_size and num_kv_heads % tp_size == 0 ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." + num_q_heads //= tp_size decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( - self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size - ) + num_kv_heads //= tp_size + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -295,10 +301,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): from transformers import LlamaForCausalLM + self.is_causal = True policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ @@ -313,10 +320,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ], ) } - if self.shard_config.parallel_output: - new_item[LlamaForCausalLM].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) - } else: new_item = { LlamaForCausalLM: ModulePolicyDescription( @@ -336,7 +339,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): self.set_pipeline_forward( model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy ) - + elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism: + # Compute loss distributedly along the sequence dimension + new_item[LlamaForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } return policy def get_held_layers(self) -> List[Module]: diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c5a0277a5..6ea27e210 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -271,7 +271,7 @@ class MistralForCausalLMPolicy(MistralPolicy): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { MistralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 10df143c9..e11edae9f 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -275,7 +275,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { MixtralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git 
a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 362c14060..235dc7d56 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -313,7 +313,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy): setattr(self.shard_config, "causal_lm", True) if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { Qwen2ForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 163d7a7bb..70eb271c9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -10,7 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] -SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] @dataclass @@ -29,6 +29,8 @@ class ShardConfig: enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim. + For SP: set to True to NOT gather the output along the seq dim. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -47,10 +49,11 @@ class ShardConfig: gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # For ring attention + inner_ring_size: Optional[int] = None # for moe related moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None - # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @@ -80,9 +83,9 @@ class ShardConfig: self.enable_tensor_parallelism ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" elif self.sequence_parallelism_mode in ["all_to_all"]: - assert ( - not self.enable_tensor_parallelism - ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" + # assert ( + # not self.enable_tensor_parallelism + # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" if self.enable_sequence_overlap: self.enable_sequence_overlap = False warnings.warn( diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index e530e2d6a..093377e7a 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -28,6 +28,7 @@ warnings.filterwarnings("ignore") # Constants # ============================== +# We have lots of llamas for your choice! 
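With "ring_attn" now accepted in SUPPORT_SP_MODE, a typical way to enable it end to end is through the hybrid parallel plugin, as the benchmark below does via --sp_mode. A hedged usage sketch (the argument names follow HybridParallelPlugin in this repo; the sizes are placeholders):

    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        sp_size=2,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="ring_attn",
        enable_flash_attention=True,   # ring_attn requires flash attention
        precision="bf16",
    )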
MODEL_CONFIGS = { "100m": LlamaConfig( max_position_embeddings=4096, @@ -36,6 +37,7 @@ MODEL_CONFIGS = { intermediate_size=2048, hidden_size=1024, ), + "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8), "7b": LlamaConfig(max_position_embeddings=4096), "13b": LlamaConfig( hidden_size=5120, @@ -68,9 +70,6 @@ def main(): default="gemini", help="Choose which plugin to use", ) - parser.add_argument( - "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel." - ) parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") @@ -94,11 +93,24 @@ def main(): parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) - parser.add_argument("--profile", action="store_true", help="Profile the code", default=False) + parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. \ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument( + "--sp_mode", + default="all_to_all", + choices=["all_to_all", "ring_attn", "ring", "split_gather"], + help="Sequence parallelism mode", + ) args = parser.parse_args() colossalai.launch_from_torch() @@ -195,12 +207,12 @@ def main(): num_model_chunks=args.n_chunks, zero_stage=args.zero, sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, enable_sequence_parallelism=args.sp > 1, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", - overlap_p2p=args.overlap, enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, **hybrid_kwargs, @@ -218,7 +230,6 @@ def main(): microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", - overlap_p2p=args.overlap, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -295,6 +306,7 @@ def main(): args.ignore_steps, 1, # avoid creating massive log files save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + nsys=args.nsys, ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader) @@ -320,13 +332,16 @@ def main(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] + del outputs # free memory + + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() performance_evaluator.on_step_end(**batch) prof.step() - performance_evaluator.on_fit_end() coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") diff --git 
a/examples/language/opt/README.md b/examples/language/opt/README.md index af1e79437..694c5cf91 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -17,7 +17,7 @@ limitations under the License. ## OPT Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. -The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost. ## Our Modifications diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index ca4a02cd2..f5ad1d23d 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float: return tensor.item() -def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False): class DummyProfiler: def __init__(self): self.step_number = 0 @@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): def __exit__(self, exc_type, exc_value, traceback): pass + class NsysProfiler: + def __init__(self, warmup_steps, active_steps): + self.step_number = 0 + self.warmup_steps = warmup_steps + self.active_steps = active_steps + + def step(self): + if self.step_number == self.warmup_steps: + torch.cuda.cudart().cudaProfilerStart() + elif self.step_number == self.warmup_steps + self.active_steps: + torch.cuda.cudart().cudaProfilerStop() + self.step_number += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + if enable_flag: + if nsys: + return NsysProfiler(warmup_steps, active_steps) + return profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md index a01209cbd..3776e0c64 100644 --- a/examples/tutorial/opt/opt/README.md +++ b/examples/tutorial/opt/opt/README.md @@ -19,7 +19,7 @@ limitations under the License. ## OPT Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. -The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning causal Language Modelling at low cost. We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). 
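Putting the new nsys switch together: get_profile_context now returns a lightweight NsysProfiler that brackets the active steps with cudaProfilerStart/Stop, so the process must be launched under the nsys command quoted in the --nsys help text above. A hedged usage sketch (dataloader, model and optimizer are placeholders):

    with get_profile_context(True, warmup_steps=2, active_steps=3,
                             save_dir="profile/llama", nsys=True) as prof:
        for step, batch in enumerate(dataloader):
            loss = model(**batch)[0]
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            prof.step()   # cudaProfilerStart() fires after the warmup steps, Stop() after the active steps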
diff --git a/extensions/pybind/flash_attention/flash_attention_dao_cuda.py b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py index a108377a8..560d952f6 100644 --- a/extensions/pybind/flash_attention/flash_attention_dao_cuda.py +++ b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py @@ -57,14 +57,14 @@ class FlashAttentionDaoCudaExtension(_Extension): q_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, ): - # [B, N, S, D] -> [B, S, N, D] + # [B, H, S, D] -> [B, S, H, D] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) b, s_q = q.shape[:2] if cu_seqlens_q is not None: # padded / padded causal - # unpad input: [B, S, N, D] -> [T, N, D] + # unpad input: [B, S, H, D] -> [T, H, D] q = _unpad_input(q, q_indices) kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) attn_output = flash_attn_varlen_kvpacked_func( @@ -78,7 +78,7 @@ class FlashAttentionDaoCudaExtension(_Extension): softmax_scale=scale, causal=is_causal, ) - # pad output: [T, N, D] -> [B, S, N, D] + # pad output: [T, H, D] -> [B, S, H, D] attn_output = pad_input(attn_output, q_indices, b, s_q) else: # causal / no attn mask @@ -90,7 +90,7 @@ class FlashAttentionDaoCudaExtension(_Extension): softmax_scale=scale, causal=is_causal, ) - # [B, S, N, D] -> [B, N, S, D] + # [B, S, H, D] -> [B, H, S, D] return attn_output.transpose(1, 2) return flash_attention diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 66c794a7d..9c1a11e7b 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -22,9 +22,9 @@ COMMON_MODELS = [ "transformers_bloom_for_causal_lm", "transformers_falcon_for_causal_lm", "transformers_chatglm_for_conditional_generation", - "transformers_llama_for_casual_lm", + "transformers_llama_for_causal_lm", "transformers_vit_for_masked_image_modeling", - "transformers_mistral_for_casual_lm", + "transformers_mistral_for_causal_lm", ] IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1" diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py index a8b8842c5..3f4ea4583 100644 --- a/tests/kit/model_zoo/transformers/command.py +++ b/tests/kit/model_zoo/transformers/command.py @@ -32,8 +32,8 @@ if HAS_COMMAND: return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -44,7 +44,7 @@ if HAS_COMMAND: # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = CohereConfig( @@ -70,10 +70,10 @@ if HAS_COMMAND: model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_command_for_casual_lm", + name="transformers_command_for_causal_lm", model_fn=lambda: transformers.CohereForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 61fa56050..05ac9d8d2 100644 --- 
a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -33,20 +33,21 @@ if HAS_LLAMA: [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], ] ).long() - - attention_mask = torch.Tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ).long() - + attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() + + # Test padded sequence + padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long) + data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1) + data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1) + + ignore_idx = -100 labels = data["input_ids"].clone() + labels[~data["attention_mask"].bool()] = ignore_idx data["labels"] = labels return data @@ -55,7 +56,7 @@ if HAS_LLAMA: # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( @@ -70,9 +71,17 @@ if HAS_LLAMA: config.pad_token_id = config.eos_token_id # register the following models - # transformers.LlamaModel, # transformers.LlamaForCausalLM, + # transformers.LlamaModel, # transformers.LlamaForSequenceClassification, + model_zoo.register( + name="transformers_llama_for_causal_lm", + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_causal_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) model_zoo.register( name="transformers_llama", model_fn=lambda: transformers.LlamaModel(config), @@ -81,14 +90,6 @@ if HAS_LLAMA: loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True), ) - model_zoo.register( - name="transformers_llama_for_casual_lm", - model_fn=lambda: transformers.LlamaForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, - model_attribute=ModelAttribute(has_control_flow=True), - ) model_zoo.register( name="transformers_llama_for_sequence_classification", model_fn=lambda: transformers.LlamaForSequenceClassification(config), diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index ae5a97002..43fc662cc 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -64,7 +64,7 @@ model_zoo.register( model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_mistral_for_casual_lm", + name="transformers_mistral_for_causal_lm", model_fn=lambda: transformers.MistralForCausalLM(config), data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, diff --git a/tests/kit/model_zoo/transformers/qwen2.py b/tests/kit/model_zoo/transformers/qwen2.py index 1c26af698..83bc9f941 100644 --- a/tests/kit/model_zoo/transformers/qwen2.py +++ b/tests/kit/model_zoo/transformers/qwen2.py @@ -33,8 +33,8 @@ if HAS_QWEN2: attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) - # 
label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -45,7 +45,7 @@ if HAS_QWEN2: # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = Qwen2Config( @@ -72,11 +72,11 @@ if HAS_QWEN2: model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_qwen2_for_casual_lm", + name="transformers_qwen2_for_causal_lm", model_fn=lambda: transformers.Qwen2ForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index e57cadfd8..3e8532955 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): # TODO(ver217): add more models for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( - "transformers_llama_for_casual_lm" + "transformers_llama_for_causal_lm" ).items(): err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 8c59f430c..c2a08a541 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): sub_model_zoo = model_zoo.get_sub_registry(model_name) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index fd13ce0bf..b133be948 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @clear_cache_before_run() @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) @parameterize("shard", [True, False]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 4897907ff..ce4d10322 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo @clear_cache_before_run() 
@parameterize("shard", [False, True]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) def exam_torch_load_from_gemini(shard: bool, model_name: str): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 4f8f26041..86d7924fb 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -39,7 +39,7 @@ else: @parameterize("shard", [True, False]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) @clear_cache_before_run() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index ab48944d4..a8e05a25a 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO( if name != "transformers_llama": continue task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index df8636141..6f8eb2ad2 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo @clear_cache_before_run() -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("plugin_type", ["ddp", "zero", "gemini"]) def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index 1ae17025d..b0ec767cc 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -91,7 +91,7 @@ def run_lora_test(): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index a626b834a..04a1296e6 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -6,6 +6,7 @@ import pytest import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close import colossalai from colossalai.cluster import ProcessGroupMesh @@ -107,13 +108,13 @@ def run_pp( # check loss if 
stage_manager.is_last_stage(ignore_chunk=True): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) # check gradients for i in range(num_model_chunk): idx = world_size * i + rank - assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) - assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) + assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() @@ -123,8 +124,8 @@ def run_pp( # check updated param for i in range(num_model_chunk): idx = world_size * i + rank - assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) - assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) + assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) # forward only with torch.no_grad(): @@ -135,14 +136,14 @@ def run_pp( sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) if stage_manager.is_last_stage(ignore_chunk=True): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) for layer in sharded_model: if layer.weight.grad is None: assert layer.weight.grad is None and layer.bias.grad is None else: - assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) @pytest.mark.dist diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index c4bfa7b69..8ae4f6daa 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -6,6 +6,7 @@ import pytest import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close import colossalai from colossalai.cluster import ProcessGroupMesh @@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int): # check loss if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) # check gradients for i in range(len(sharded_model)): idx = rank * num_local_layer + i - assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) - assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) + assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() @@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int): # check updated param for i in range(len(sharded_model)): idx = rank * num_local_layer + i - assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) - assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) + assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) # forward only with torch.no_grad(): @@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int): sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) if 
stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret["loss"]) + assert_close(torch_loss, pp_ret["loss"]) for layer in sharded_model: if layer.weight.grad is None: assert layer.weight.grad is None and layer.bias.grad is None else: - assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) def run_dist( diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py index 9aa24a166..42ca6b198 100644 --- a/tests/test_shardformer/test_flash_attention.py +++ b/tests/test_shardformer/test_flash_attention.py @@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma padding_mask = padding_mask[:, None, :, None].logical_not() ref_output = ref_output.masked_fill(padding_mask, 0) output = output.masked_fill(padding_mask, 0) + assert_close(output, ref_output, **tols) output.mean().backward() ref_output.mean().backward() @@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype): attn_kwargs, padding_mask = gen_kwargs_func(dtype) for attn_func, name, need_postprocess in attn_funcs: print(f"{dtype}, {name}, {mask_type}") + if mask_type == "padded": + pass if need_postprocess: check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask) else: diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py new file mode 100644 index 000000000..1c7647a7d --- /dev/null +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -0,0 +1,186 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import AttnMaskType +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention +from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +@parameterize("seq_len", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_ring_attn(seq_len, bs, nheads, d, dtype): + torch.cuda.manual_seed(2) + device = get_current_device() + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() + # Some outliers may seem large, but our errors are still lower than + # than Megatron-LM context parallel's + # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) + # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) + atol = rtol = 7e-3 + + # Setup inputs + qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + local_qkv = split_batch_zigzag(qkv, sp_group) + q, k, v = local_qkv.unbind(dim=-3) + q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) + q.requires_grad = k.requires_grad = v.requires_grad = True + + # Ring attention vs single GPU + ring_out, ring_lse = RingAttention.attention( + q, + k, 
+ v, + sp_group, + AttnMaskType.CAUSAL, + return_softmax=True, + inner_ring_size=max(2, sp_size // 2), + # inner_ring_size=4 + ) + ring_out = ring_out.transpose(1, 2) + out, lse, _ = flash_attn_qkvpacked_func( + qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True + ) + + # Checkout out and softmax denominator + local_out = split_batch_zigzag(out, sp_group) + local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) + local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) + assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) + assert_close(ring_out, local_out, atol=atol, rtol=rtol) + + # Check grads + ring_out.sum().backward() + out.sum().backward() + ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] + dqkv = qkv.grad + local_dqkv = split_batch_zigzag(dqkv, sp_group) + + assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) + assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) + assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) + if dist.get_rank() == 0: + print( + f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed." + ) + + +@parameterize("seqlen", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_packed_seq(seqlen, bs, nheads, d, dtype): + device = get_current_device() + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() + atol = rtol = 7e-3 + torch.cuda.manual_seed(2) + # Prepare varlen attention mask + padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) + padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 + padding_mask[:, seqlen // 2 :] = 0 + + input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + + # Forward + # out = ColoAttention.attention(q, k, v, **mask_info) + flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] + qkv = torch.stack([flat_input] * 3, dim=1) + qkv.retain_grad() + + input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + mask_info["cu_seqlens"] * sp_size, + mask_info["max_seqlen"] * sp_size, + return_attn_probs=True, + causal=True, + # deterministic=True + ) + # Test the splitting function + local_input = split_varlen_zigzag( + flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() + del local_input, flat_input + + q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + q_ring.retain_grad() + k_ring.retain_grad() + v_ring.retain_grad() + + ring_out, ring_lse = RingAttention.attention( + q_ring, + k_ring, + v_ring, + sp_group, + **mask_info, + pad_output=False, + return_softmax=True, + # deterministic=True + ) + ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) + # Check output + lse = lse.transpose(0, 1) + out, lse = split_varlen_zigzag( + [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert_close(lse, ring_lse, atol=atol, rtol=rtol) + assert_close(out, ring_out, atol=atol, rtol=rtol) + + # Check grads + labels = torch.ones(out.shape[0], dtype=dtype, device=device) + F.mse_loss(out.sum((-2, 
-1)), labels).backward() + F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() + dq, dk, dv = [ + split_varlen_zigzag( + qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + for i in range(3) + ] + dq_ring, dk_ring, dv_ring = [ + x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] + for x in (q_ring.grad, k_ring.grad, v_ring.grad) + ] + + assert_close(dq, dq_ring, atol=atol, rtol=rtol) + assert_close(dk, dk_ring, atol=atol, rtol=rtol) + assert_close(dv, dv_ring, atol=atol, rtol=rtol) + + +def launch_single_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_packed_seq() + check_ring_attn() + + +def launch_double_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_ring_attn() + + +@rerun_if_address_is_in_use() +@parameterize("world_size", [2]) +def test_ring_attn(world_size): + spawn(launch_single_ring, nprocs=world_size) + + +@rerun_if_address_is_in_use() +@parameterize("world_size", [4]) +def test_double_ring(world_size): + spawn(launch_double_ring, nprocs=world_size) + + +if __name__ == "__main__": + test_ring_attn() + test_double_ring() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 190fee129..9ad84341a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn import Module from torch.optim import Adam, Optimizer from torch.testing import assert_close +from transformers.modeling_outputs import BaseModelOutputWithPast from colossalai.accelerator import get_accelerator from colossalai.booster import Booster @@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin( org_output = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() - return org_loss, org_output, sharded_loss, sharded_output @@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin( def check_output_hidden_state( - org_output: Tensor, - sharded_output: Tensor, + org_output: BaseModelOutputWithPast, + sharded_output: BaseModelOutputWithPast, stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, rtol: float = 1e-3, + shard_config: Optional[ShardConfig] = None, ): org_hidden_state = org_output.last_hidden_state @@ -315,6 +316,14 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state + # Check if the output sequence is gathered before cross entropy + if shard_config is not None: + seq_dim = 1 + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: + org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] + assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) @@ -374,8 +383,11 @@ def get_grad_tensors_for_check( shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel - if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[: org_grad.shape[0], :] + try: + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[: org_grad.shape[0], :] + except: + pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") @@ -404,9 +416,6 @@ def check_grad( org_grad = 
getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight - # if verbose and dist.get_rank() == 0: - # print("shard_weight", shard_weight) - # print("org_grad", org_grad) if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) @@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors): "org_grad": tensor to be compared from the original model "shard_grad": tensor to be compared from the sharded model """ - for suffix, check_info in check_tensors.items(): + for idx, (suffix, check_info) in enumerate(check_tensors.items()): org_grad = check_info["org_grad"] shard_grad = check_info["shard_grad"] rtol = check_info["rtol"] diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 3281b50e1..efe5cee2a 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ], ) def run_command_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -321,7 +321,7 @@ def run_command_test(test_config): ], ) def run_command_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 88e54176b..3c66f6097 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() - for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + for (name, p1), p2 in zip( + llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] + ): working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( @@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + try: + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + except Exception as e: + 
raise RuntimeError(f"Failed to check grad for {name}") from e # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -114,89 +119,130 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "LlamaModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - try: - check_weight( - llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) - except Exception as e: - print(f"Failed config: {test_config}") - raise e + check_weight( + llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) - torch.cuda.empty_cache() @parameterize( "test_config", [ - { # Ulysess + Flash attention + # Double Ring Attention + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + "inner_ring_size": 2, + }, + # Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + # Ring Attention + TP + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + TP + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + PP "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, + "enable_all_optimization": True, "use_lazy_init": True, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, }, - { # Test ring + Flash attention - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 4, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": 
"split_gather", - "enable_flash_attention": False, + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: + continue try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: - print(f"Failed config: {test_config}") + print(f"Failed config: {test_config}, model name: {name}") raise e clear_layout_converter() Randomizer.reset_index() From 26493b97d3b9c55a27d822c63b9176ac79b9ff5e Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 16 Aug 2024 18:49:14 +0800 Subject: [PATCH 11/16] [misc] update compatibility (#6008) * [misc] update compatibility * [misc] update requirements * [devops] disable requirements cache * [test] fix torch ddp test * [test] fix rerun on address in use * [test] fix lazy init --- .compatibility | 1 + .cuda_ext.json | 4 ++-- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/build_on_schedule.yml | 2 +- colossalai/testing/utils.py | 2 +- requirements/requirements.txt | 2 +- .../test_plugin/test_torch_ddp_plugin.py | 2 +- tests/test_lazy/test_models.py | 14 +++++++++++--- 8 files changed, 19 insertions(+), 10 deletions(-) diff --git a/.compatibility b/.compatibility index 4f808740b..62d19faff 100644 --- a/.compatibility +++ b/.compatibility @@ -1,3 +1,4 @@ 2.1.0-12.1.0 2.2.2-12.1.0 2.3.0-12.1.0 +2.4.0-12.4.1 diff --git a/.cuda_ext.json b/.cuda_ext.json index 8c9d5916c..1e617755b 100644 --- a/.cuda_ext.json +++ b/.cuda_ext.json @@ -5,8 +5,8 @@ "cuda_image": "hpcaitech/cuda-conda:12.1" }, { - "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118", - "cuda_image": "hpcaitech/cuda-conda:11.8" + "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124", + "cuda_image": "hpcaitech/cuda-conda:12.4" } ] } diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 151454239..58cd88268 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -141,7 +141,7 @@ jobs: - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v -e . - pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt - name: Store Colossal-AI Cache run: | diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index fc6424503..fc688a71b 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -57,7 +57,7 @@ jobs: [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ BUILD_EXT=1 pip install -v -e . 
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ - pip install -r requirements/requirements-test.txt + pip install --no-cache-dir -r requirements/requirements-test.txt - name: Unit Testing if: steps.check-avai.outputs.avai == 'true' diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 5f6864ff0..90d35dc85 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -176,7 +176,7 @@ def rerun_if_address_is_in_use(): else: exception = Exception - func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*") + func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*(A|a)ddress already in use.*") return func_wrapper diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 651eb66e8..578122d47 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.1.0,<=2.3.0 +torch>=2.1.0,<=2.4.0 safetensors einops pydantic diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index f92b5c6e5..2a3b6e5a3 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -47,7 +47,7 @@ def check_torch_ddp_plugin(): registry = model_zoo for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items(): - if name == "dlrm_interactionarch" or name.startswith("simple_"): + if name in ("dlrm_interactionarch", "transformers_mixtral") or name.startswith("simple_"): continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index c85860a8d..0a919955f 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -18,9 +18,17 @@ def test_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith( - ("transformers_vit", "transformers_blip2", "transformers_whisper") - ): + if name in ( + "torchaudio_wav2vec2_base", + "torchaudio_hubert_base", + "timm_beit", + "timm_vision_transformer", + "timm_deit", + "timm_beitv2", + "timm_deit3", + "timm_convit", + "timm_tnt_b_patch16_224", + ) or name.startswith(("transformers_vit", "transformers_blip2", "transformers_whisper")): continue check_lazy_init(entry, verbose=True, default_device=default_device) From f1c3266a946d6fed18921c534824901f7606f778 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 19 Aug 2024 14:08:17 +0800 Subject: [PATCH 12/16] overlap kv comm with output rescale (#6017) Co-authored-by: Edenzzzz --- colossalai/shardformer/layer/attn.py | 30 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 6dab17ec0..5d1a30d8a 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -690,6 +690,13 @@ class RingAttention(torch.autograd.Function): ) return out, softmax_lse, rng_state + def _kv_comm(i): + # Avoid overwriting attn input when it shares mem with buffer + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < 
local_sp_size - 1: + local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + def _local_ring_forward(): # (Hopefully) overlap output correction with next flash attn for i in range(local_sp_size): @@ -698,12 +705,8 @@ class RingAttention(torch.autograd.Function): # NOTE: waiting outside the current stream will NOT correctly synchronize. if i > 0: local_kv_comms[(i + 1) % 2].wait() - - # Avoid overwriting attn input when it shares mem with buffer - if not RingAttention.ATTN_DONE.query(): - kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) - if i < local_sp_size - 1: - local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + if i == 0: + _kv_comm(i) if i == 0: # Compute with local KV; no mask @@ -734,6 +737,9 @@ class RingAttention(torch.autograd.Function): rng_states[i], ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) RingAttention.ATTN_DONE.record() + # Pipeline the next KV comm with output correction instead of the next flash attn + # to minimize idle time when comm takes longer than attn. + _kv_comm(i + 1) block_softmax_lse[i % 2] = ( block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() @@ -761,15 +767,13 @@ class RingAttention(torch.autograd.Function): # all new KVs from the previous inner ring for i in range(local_sp_size): with torch.cuda.stream(sp_streams[i % 2]): - if not RingAttention.ATTN_DONE.query(): - kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) - if i < local_sp_size - 1: - local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) - # Send & recv KV if i > 0: local_kv_comms[(i + 1) % 2].wait() + if i == 0: + _kv_comm(i) + if ring_num_idx > inter_ring_rank: kv_block = kv_buffers[i % 2] ( @@ -778,6 +782,8 @@ class RingAttention(torch.autograd.Function): rng_states[i + local_sp_size * ring_num_idx], ) = _forward(q1, kv_block[0], kv_block[1], causal=False) RingAttention.ATTN_DONE.record() + + _kv_comm(i + 1) block_softmax_lse[i % 2] = ( block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() ) @@ -792,6 +798,8 @@ class RingAttention(torch.autograd.Function): rng_states[i + local_sp_size * ring_num_idx], ) = _forward(q, kv_block[0], kv_block[1], causal=False) RingAttention.ATTN_DONE.record() + + _kv_comm(i + 1) block_softmax_lse[i % 2] = ( block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() ) From dcc44aab8d3ab640bfd9d740a147ad39cb21f18c Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 20 Aug 2024 10:32:41 +0800 Subject: [PATCH 13/16] [misc] Use dist logger in plugins (#6011) * use dist logger in plugins * remove trash * print on rank 0 --------- Co-authored-by: Edenzzzz --- colossalai/booster/booster.py | 18 ++++++--- colossalai/booster/plugin/gemini_plugin.py | 26 ++++++++----- .../booster/plugin/hybrid_parallel_plugin.py | 35 ++++++++++------- .../booster/plugin/low_level_zero_plugin.py | 39 +++++++++---------- .../plugin/moe_hybrid_parallel_plugin.py | 24 +++++++----- colossalai/booster/plugin/torch_ddp_plugin.py | 2 + .../booster/plugin/torch_fsdp_plugin.py | 15 +++---- colossalai/zero/gemini/gemini_optimizer.py | 12 +++--- 8 files changed, 101 insertions(+), 70 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 56d8a0935..8047d90f7 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -1,4 +1,3 @@ -import warnings from contextlib import contextmanager from typing import Any, Callable, Dict, Iterator, List, Optional, Union @@ -8,6 +7,8 @@ 
from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from colossalai.logging import get_dist_logger + SUPPORT_PEFT = False try: import peft @@ -81,12 +82,15 @@ class Booster: plugin, Plugin ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}." self.plugin = plugin + self.logger = get_dist_logger() # set accelerator if self.plugin and self.plugin.control_device(): self.accelerator = None if device is not None: - warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.") + self.logger.warning( + "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0] + ) else: device = device or "cuda" self.accelerator = Accelerator(device) @@ -94,7 +98,10 @@ class Booster: # set precision if self.plugin and self.plugin.control_precision(): if mixed_precision is not None: - warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.") + self.logger.warning( + "The plugin will control the precision," "so the mixed_precision argument will be ignored.", + ranks=[0], + ) self.mixed_precision = None elif mixed_precision is None: self.mixed_precision = None @@ -267,8 +274,9 @@ class Booster: ), "Please provide pretrained directory path if not passing in lora configuration." if quantize is True: if bnb_quantization_config is not None: - warnings.warn( - "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk." + self.logger.warning( + "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.", + ranks=[0], ) else: bnb_quantization_config = BnbQuantizationConfig( diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ad131fbe7..3754cfe60 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,5 +1,4 @@ import gc -import logging import os import random from pathlib import Path @@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import ( ) from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ @@ -118,7 +119,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ assert isinstance(model, GeminiDDP), "Please boost the model before saving!" 
if os.path.isfile(checkpoint_path): - logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0]) return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -143,10 +144,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model.unwrap(), checkpoint_path) - logging.info( + self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_model( @@ -168,7 +170,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -201,10 +203,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): @@ -214,7 +217,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" if not os.path.isfile(checkpoint_index_file): - logging.error(f"Provided path ({checkpoint_index_file}) should be a file") + self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0]) assert isinstance(optimizer, GeminiOptimizer) @@ -369,9 +372,12 @@ class GeminiPlugin(DPPluginBase): assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" + + self.logger = get_dist_logger() if enable_async_reduce and not pin_memory: - logging.warning( - f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." 
+ self.logger.warning( + f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.", + ranks=[0], ) pin_memory = True self.gemini_config = dict( diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 63427192f..b4b40020f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,6 +1,5 @@ import ctypes import random -import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy @@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager @@ -1023,6 +1023,7 @@ class HybridParallelPlugin(PipelinePluginBase): inner_ring_size: int = None, ) -> None: super().__init__() + self.logger = get_dist_logger() assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -1040,8 +1041,9 @@ class HybridParallelPlugin(PipelinePluginBase): tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -1126,7 +1128,12 @@ class HybridParallelPlugin(PipelinePluginBase): else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": - assert parallel_output, "Ring Attention doesn't support gathering output yet." + if not parallel_output: + self.logger.warning( + "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.", + ranks=[0], + ) + parallel_output = True self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) @@ -1231,7 +1238,10 @@ class HybridParallelPlugin(PipelinePluginBase): optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], + ) zero_config["partition_grad"] = False zero_stage = 0 @@ -1287,9 +1297,10 @@ class HybridParallelPlugin(PipelinePluginBase): else: is_zero = self.dp_size > 1 if self.dp_size == 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." 
+ "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." @@ -1332,7 +1343,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" if return_outputs: - warnings.warn("return_outputs may lead to significant extra memory consumption.") + self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0]) # Create a context for gradient synchronization based on the optimizer type. # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). @@ -1346,10 +1357,8 @@ class HybridParallelPlugin(PipelinePluginBase): ) # run with gradients accumulation - if ( - model.require_grad_sync == False - or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) - or not torch.is_grad_enabled() + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False ): return outputs @@ -1449,7 +1458,7 @@ class HybridParallelPlugin(PipelinePluginBase): assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." assert self.pp_size == 1 and self.tp_size == 1 self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index e4c386a22..cc3755e6f 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,7 +1,5 @@ import enum -import logging import os -import warnings from contextlib import nullcontext from functools import partial from pathlib import Path @@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import ( ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.tensor.colo_parameter import ColoParameter @@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__( - self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True - ) -> None: + def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -76,7 +73,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None and cast_inputs: + if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -140,7 +137,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): """ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" 
if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -177,10 +174,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): index_file.append_meta_data("total_size", total_size) if self.coordinator.is_master(): index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): @@ -267,7 +265,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return from peft import PeftModel @@ -336,7 +334,6 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, - cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -363,8 +360,7 @@ class LowLevelZeroPlugin(DPPluginBase): ) self.lora_enabled = False self.verbose = verbose - self.cast_inputs = cast_inputs - + self.logger = get_dist_logger() # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -400,7 +396,7 @@ class LowLevelZeroPlugin(DPPluginBase): assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) @@ -449,8 +445,9 @@ class LowLevelZeroPlugin(DPPluginBase): origin_param = name2param[origin_key] group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: - warnings.warn( - f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + self.logger.warning( + f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.", + ranks=[0], ) elif ( check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED @@ -478,10 +475,7 @@ class LowLevelZeroPlugin(DPPluginBase): if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, - self.precision, - overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], - cast_inputs=self.cast_inputs, + model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] ) # TODO: Support Galore + ZeRO @@ -493,7 +487,10 @@ class LowLevelZeroPlugin(DPPluginBase): optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], + ) zero_optim_kwargs["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b3415af0e..36973b240 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,4 +1,3 @@ -import warnings from collections import defaultdict from types import MethodType from typing import Callable, List, Optional, OrderedDict, Tuple @@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import cast_to_distributed from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule @@ -215,12 +215,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_p2p: bool = True, overlap_allgather: bool = False, ) -> None: + self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: overlap_communication = False zero_stage = 1 - warnings.warn( + self.logger.warning( f"overlap_communication and zero_stage are set to False and 1 because " - f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " + f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.", + ranks=[0], ) assert ( @@ -238,8 +240,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}," + "will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -400,8 +404,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): and self.sequence_parallelism_mode == "all_to_all" ) if use_ddp: - warnings.warn( - f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + self.logger.warning( + f"Will have to check all params are used in pytorch DDP since not all experts are always activated", + ranks=[0], ) self.ddp_config["find_unused_parameters"] = True @@ -457,9 +462,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) else: if self.dp_size <= 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
optimizer = MoeHybridParallelZeroOptimizer( diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 5116446a4..8a807970c 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.utils import get_current_device @@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): """ diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index cd2f9e840..7b67da032 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,6 +1,4 @@ -import logging import os -import warnings from pathlib import Path from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple @@ -30,6 +28,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from .dp_plugin_base import DPPluginBase @@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" @@ -88,7 +88,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): - logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -117,7 +117,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) utils.save_config_file(model.unwrap(), checkpoint_path) - logging.info( + self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." @@ -162,7 +162,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" 
if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -200,7 +200,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." @@ -311,6 +311,7 @@ class TorchFSDPPlugin(DPPluginBase): param_init_fn=param_init_fn, sync_module_states=sync_module_states, ) + self.logger = get_dist_logger() else: raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") @@ -349,7 +350,7 @@ class TorchFSDPPlugin(DPPluginBase): if optimizer is not None: if len(optimizer.param_groups) > 1: - warnings.warn( + self.logger.warning( "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." ) optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 1d755c417..fdf2a4976 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,7 +1,6 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import copy import math -import warnings from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union import torch @@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper): self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose self.param_groups_backup = list() - + self.logger = get_dist_logger() # Mapping from integer id to real/fake param tensor, used for checkpointing. self.id_to_real_params: Dict[int, Parameter] = dict() self.id_to_fake_params: Dict[int, Parameter] = dict() @@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper): for name, param in module.named_parameters(): if is_ddp_ignored(param): if param.requires_grad: - warnings.warn( + self.logger.warning( f"Parameter `{name}` is ignored by DDP but requires gradient! " - "You should handle its optimizer update by yourself!" 
+ "You should handle its optimizer update by yourself!", + ranks=[0], ) else: ddp_param_list.append(param) @@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper): *args, **kwargs, ) -> torch.Tensor: - warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm") + self.logger.warning( + f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0] + ) class GeminiAdamOptimizer(GeminiOptimizer): From 0d3b0bd86498a6de7a38ca5c5905d60a80c9f6a7 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 21 Aug 2024 10:21:26 +0800 Subject: [PATCH 14/16] [plugin] add cast inputs option for zero (#6003) (#6022) --- colossalai/booster/plugin/low_level_zero_plugin.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index cc3755e6f..185d34f12 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -61,7 +61,9 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: + def __init__( + self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True + ) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -73,7 +75,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None: + if self.dtype is not None and cast_inputs: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -334,6 +336,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, + cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -361,6 +364,8 @@ class LowLevelZeroPlugin(DPPluginBase): self.lora_enabled = False self.verbose = verbose self.logger = get_dist_logger() + self.cast_inputs = cast_inputs + # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -475,7 +480,10 @@ class LowLevelZeroPlugin(DPPluginBase): if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + model, + self.precision, + overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], + cast_inputs=self.cast_inputs, ) # TODO: Support Galore + ZeRO From 39e2597426e3c2c3a3544b400b716c935c51292b Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 21 Aug 2024 10:47:39 +0800 Subject: [PATCH 15/16] [ColossalChat] Add PP support (#6001) * support pp training * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update rm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update test case * fix * change to 4 * fix eval * test * add pp * hotfix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support pp training * 
update rm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update test case * fix * change to 4 * fix eval * test * add pp * hotfix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * skip pp eval * update all reduce * update sft * update ignore * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update no cache * add eval * remove fi * remove debug * remove parentheses to avoid warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "add eval" This reverts commit 3ab2f6fa329b6d12959fb3c668d278b4b225c5f0. * add all reduce --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/run_chatgpt_examples.yml | 8 +- applications/ColossalChat/.gitignore | 6 + .../ColossalChat/coati/trainer/base.py | 4 +- .../ColossalChat/coati/trainer/dpo.py | 7 +- .../ColossalChat/coati/trainer/kto.py | 7 +- .../ColossalChat/coati/trainer/orpo.py | 7 +- applications/ColossalChat/coati/trainer/rm.py | 7 +- .../ColossalChat/coati/trainer/sft.py | 209 ++++++++++++------ .../ColossalChat/coati/trainer/utils.py | 13 +- .../examples/training_scripts/train_dpo.py | 1 + .../examples/training_scripts/train_kto.py | 1 + .../examples/training_scripts/train_orpo.py | 1 + .../examples/training_scripts/train_rm.py | 1 + .../examples/training_scripts/train_sft.py | 4 +- applications/ColossalChat/tests/test_lora.py | 2 +- applications/ColossalChat/tests/test_train.sh | 78 ++++--- 16 files changed, 241 insertions(+), 115 deletions(-) diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index d0b5c2164..b7522ffbd 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -31,18 +31,18 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install --no-cache-dir -v -e . - name: Install ChatGPT run: | cd applications/ColossalChat - pip install -v . + pip install --no-cache-dir -v . 
export BUILD_EXT=1 - pip install -r examples/requirements.txt + pip install --no-cache-dir -r examples/requirements.txt - name: Install Transformers run: | - pip install transformers==4.36.2 + pip install --no-cache-dir transformers==4.36.2 - name: Execute Examples run: | diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore index 7b361d38e..5a4bb905f 100755 --- a/applications/ColossalChat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -161,3 +161,9 @@ applications/ColossalChat/sft_data applications/ColossalChat/prompt_data applications/ColossalChat/preference_data applications/ColossalChat/temp + +# Testing data +/kto_data/ +/preference_data/ +/prompt_data/ +/sft_data/ diff --git a/applications/ColossalChat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py index 63c903a51..bef4ccc3e 100755 --- a/applications/ColossalChat/coati/trainer/base.py +++ b/applications/ColossalChat/coati/trainer/base.py @@ -16,7 +16,7 @@ from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience from torch.optim import Optimizer -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from .utils import is_rank_0 @@ -38,6 +38,7 @@ class SLTrainer(ABC): max_epochs: int, model: nn.Module, optimizer: Optimizer, + plugin: Plugin, start_epoch: int = 0, ) -> None: super().__init__() @@ -45,6 +46,7 @@ class SLTrainer(ABC): self.max_epochs = max_epochs self.model = model self.optimizer = optimizer + self.plugin = plugin self.start_epoch = start_epoch @abstractmethod diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 24ddca654..faa7a90d9 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer): ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -63,7 +64,9 @@ class DPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 6462ba816..f0b23afb6 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer): ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: 
_LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -66,7 +67,9 @@ class KTOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index c2f75771c..761fd305a 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer): actor: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -59,7 +60,9 @@ class ORPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.odds_ratio_loss_fn = OddsRatioLoss() diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index b9e84ef55..82e4625b9 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer): model: Any, booster: Booster, optimizer: Optimizer, + plugin: Plugin, lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, loss_fn: Optional[Callable] = None, @@ -59,7 +60,9 @@ class RewardModelTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = lr_scheduler self.tokenizer = tokenizer self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta) diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index d37676ada..3aedcf7a9 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -6,14 +6,16 @@ import os from typing import Optional import torch +import torch.distributed as dist from coati.trainer.utils import all_reduce_mean from coati.utils import AccumulativeMeanMeter, 
save_checkpoint from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import trange +from tqdm import tqdm, trange from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin, Plugin from colossalai.cluster import DistCoordinator from .base import SLTrainer @@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer): optim: Optimizer, lr_scheduler: _LRScheduler, max_epochs: int = 2, + plugin: Plugin = None, accumulation_steps: int = 8, apply_loss_mask: bool = True, start_epoch=0, @@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer): save_dir: str = None, coordinator: Optional[DistCoordinator] = None, ) -> None: - super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch) self.accumulation_steps = accumulation_steps self.scheduler = lr_scheduler @@ -94,60 +97,85 @@ class SFTTrainer(SLTrainer): def _train(self, epoch: int): self.model.train() - step_bar = trange( - len(self.train_dataloader) // self.accumulation_steps, - desc=f"Epoch {epoch + 1}/{self.max_epochs}", - disable=not is_rank_0(), - ) - for i, batch in enumerate(self.train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) - batch_size = batch["input_ids"].size(0) - outputs = self.model( - batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1: + data_iter = iter(self.train_dataloader) + step_bar = tqdm( + range(len(self.train_dataloader)), + desc="Step", + disable=not (dist.get_rank() == dist.get_world_size() - 1), ) - loss = outputs.loss + for step in step_bar: + outputs = self.booster.execute_pipeline( + data_iter, + self.model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] - self.booster.backward(loss=loss, optimizer=self.optimizer) + if self.booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, self.plugin) + if dist.get_rank() == dist.get_world_size() - 1: + step_bar.set_postfix({"train/loss": global_loss.item()}) - loss_mean = all_reduce_mean(tensor=loss) - self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) - - # Gradient accumulation - if (i + 1) % self.accumulation_steps == 0: self.optimizer.step() self.optimizer.zero_grad() - self.scheduler.step() + else: + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + batch_size = batch["input_ids"].size(0) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) + loss = outputs.loss - step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) - if self.writer: - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) - self.num_train_step += 1 - self.accumulative_meter.reset() - step_bar.update() + self.booster.backward(loss=loss, optimizer=self.optimizer) - # Save checkpoint - if ( - self.save_dir is not None - and 
self.save_interval is not None - and (self.num_train_step + 1) % self.save_interval == 0 - ): - save_checkpoint( - save_dir=self.save_dir, - booster=self.booster, - model=self.model, - optimizer=self.optimizer, - lr_scheduler=self.scheduler, - epoch=epoch, - step=self.num_train_step + 1, - batch_size=batch_size, - coordinator=self.coordinator, - ) - self.coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" - ) + loss_mean = all_reduce_mean(tensor=loss) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + + # Gradient accumulation + if (i + 1) % self.accumulation_steps == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) + if self.writer: + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.num_train_step += 1 + self.accumulative_meter.reset() + step_bar.update() + + # Save checkpoint + if ( + self.save_dir is not None + and self.save_interval is not None + and (self.num_train_step + 1) % self.save_interval == 0 + ): + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.scheduler, + epoch=epoch, + step=self.num_train_step + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" + ) step_bar.close() def _eval(self, epoch: int): @@ -157,27 +185,64 @@ class SFTTrainer(SLTrainer): self.accumulative_meter.reset() self.model.eval() with torch.no_grad(): - step_bar = trange( - len(self.eval_dataloader), - desc=f"Epoch {epoch + 1}/{self.max_epochs}", - disable=not is_rank_0(), - ) - for batch in self.eval_dataloader: - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model( - batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1: + data_iter = iter(self.eval_dataloader) + step_bar = tqdm( + range(len(self.eval_dataloader)), + desc="Step", + disable=not (dist.get_rank() == dist.get_world_size() - 1), ) - loss_mean = all_reduce_mean(tensor=outputs.loss) - self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) - step_bar.update() - loss_mean = self.accumulative_meter.get("loss") - msg = "Evaluation Result:\n" - for tag in ["loss"]: - msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" - self.coordinator.print_on_master(msg) - os.makedirs(self.save_dir, exist_ok=True) - with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: - f.write(msg) - step_bar.close() + for step in step_bar: + outputs = self.booster.execute_pipeline( + data_iter, + self.model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if self.booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, self.plugin) + if dist.get_rank() == dist.get_world_size() - 1: + step_bar.set_postfix({"eval/loss": global_loss.item()}) + self.accumulative_meter.add("loss", global_loss.item()) + + if dist.get_rank() == dist.get_world_size() - 1: + 
loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + print(msg) + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() + + else: + step_bar = trange( + len(self.eval_dataloader), + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) + loss_mean = all_reduce_mean(tensor=outputs.loss) + self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) + step_bar.update() + + loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 3c836b4b4..217a87cf0 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -9,6 +9,8 @@ import torch.distributed as dist from torch.utils._pytree import tree_map from torch.utils.data import DataLoader +from colossalai.booster import Plugin + class CycledDataLoader: """ @@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any: return tree_map(_to, x) -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: +def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: """ Perform all-reduce operation on the given tensor and compute the mean across all processes. @@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The reduced tensor with mean computed across all processes. 
""" - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor.div_(dist.get_world_size()) + # All reduce mean across DP group + if plugin is not None: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group) + tensor.div_(plugin.dp_size) + else: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) return tensor diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index d88750aeb..3b324ee78 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -267,6 +267,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py index 598fd8062..931c16577 100755 --- a/applications/ColossalChat/examples/training_scripts/train_kto.py +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -286,6 +286,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py index 87860f7ea..0f2fbfa2b 100755 --- a/applications/ColossalChat/examples/training_scripts/train_orpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -250,6 +250,7 @@ def train(args): actor=model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py index 4c0a782b4..5ea1a06ac 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.py +++ b/applications/ColossalChat/examples/training_scripts/train_rm.py @@ -262,6 +262,7 @@ def train(args): model, booster, optim, + plugin, lr_scheduler, tokenizer, loss_fn=loss_fn, diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index c4ef3b783..62acad32f 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -114,7 +114,7 @@ def train(args): parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, - microbatch_size=args.batch_size, + microbatch_size=args.microbatch_size, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -269,6 +269,7 @@ def train(args): model=model, booster=booster, optim=optim, + plugin=plugin, lr_scheduler=lr_scheduler, max_epochs=args.max_epochs, accumulation_steps=args.accumulation_steps, @@ -344,6 +345,7 @@ if __name__ == "__main__": parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") + parser.add_argument("--microbatch_size", type=int, default=1) args = parser.parse_args() if args.config_file is not None: os.makedirs(os.path.dirname(args.config_file), 
exist_ok=True) diff --git a/applications/ColossalChat/tests/test_lora.py b/applications/ColossalChat/tests/test_lora.py index 778759210..a63650517 100755 --- a/applications/ColossalChat/tests/test_lora.py +++ b/applications/ColossalChat/tests/test_lora.py @@ -61,7 +61,7 @@ def test_overfit(): _, predicted = torch.max(outputs.data, 1) total = labels.size(0) correct = (predicted == Y).sum().item() - assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset") + assert correct / total > 0.95 assert (weight_to_compare - model.fc1.weight).sum() < 0.01 diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index 69036de63..2935a6369 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" } -set_n_least_used_CUDA_VISIBLE_DEVICES 2 +set_n_least_used_CUDA_VISIBLE_DEVICES 4 set -xu @@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" @@ -91,7 +91,7 @@ SKIPPED_TESTS=( llama-gemini_auto-20 # gemini_auto plugin doesn't support lora llama-gemini-20 # gemini doesn't support lora ) - +skip_eval=false GRAD_CKPTS=('--grad_checkpoint') for lora_rank in ${LORA_RANK[@]}; do for model in ${MODELS[@]}; do @@ -129,15 +129,18 @@ for lora_rank in ${LORA_RANK[@]}; do plugin='3d' fi if [[ $plugin == "tp_pp" ]]; then + echo "Here" tp='2' bs='8' pp='2' plugin='3d' + skip_eval=true fi if [[ $plugin == "pp" ]]; then bs='8' pp='2' plugin='3d' + skip_eval=true fi if [[ $plugin == "sp_split_gather" ]]; then enable_sequence_parallelism='--enable_sequence_parallelism' @@ -175,28 +178,53 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") done - colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ - --pretrain $pretrain \ - --tokenizer_dir $tokenizer_dir \ - --dataset ${dataset[@]} \ - --eval_dataset ${dataset[@]} \ - --save_path $MODEL_SAVE_PATH \ - --config_file $MODELS_DIR/config.jsonl \ - $lora_config \ - --plugin $plugin \ - --batch_size $bs \ - --max_epochs 1 \ - --accumulation_steps $grad_accu \ - --tp $tp \ - --pp $pp \ - --zero_stage $zero_stage \ - --sp $sp \ - --sp_mode $sp_mode \ - $enable_sequence_parallelism \ - --lr 2e-5 \ - $grad_ckpt \ - --max_len 400 \ - --use_flash_attn + + if [[ $skip_eval ]]; then + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --save_path $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + $lora_config \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --pp $pp \ + --zero_stage $zero_stage \ + --sp $sp 
\ + --sp_mode $sp_mode \ + $enable_sequence_parallelism \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + else + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --eval_dataset ${dataset[@]} \ + --save_path $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + $lora_config \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --pp $pp \ + --zero_stage $zero_stage \ + --sp $sp \ + --sp_mode $sp_mode \ + $enable_sequence_parallelism \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + fi passed=$? if [ $passed -eq 0 ]; then rm -rf ${MODEL_SAVE_PATH:?}/* From 7cf9df07bcb267c0839d8880b109c4c7d55e80fa Mon Sep 17 00:00:00 2001 From: Wenxuan Tan Date: Fri, 23 Aug 2024 15:44:27 +0800 Subject: [PATCH 16/16] [Hotfix] Fix llama fwd replacement bug (#6031) Co-authored-by: Edenzzzz --- colossalai/shardformer/policies/llama.py | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index f72a72df0..67e6e92d1 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -95,19 +95,20 @@ class LlamaPolicy(Policy): policy=policy, target_key=attn_cls, ) - if self.pipeline_stage_manager is None: - self.append_or_create_method_replacement( - description={ - "forward": get_llama_flash_attention_model_forward( - self.shard_config, - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, - ), - }, - policy=policy, - target_key=LlamaModel, - ) + + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_model_forward( + self.shard_config, + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_tensor_parallelism: assert (