diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index de2897383..027acc2e7 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -66,12 +66,7 @@ class BaseConsumer:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
 
-        plugin_config = dict(
-            tp_size=1,
-            pp_size=1,
-            precision="bf16",
-            zero_stage=1,
-        )
+        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
         plugin_config.update(self.plugin_config)
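
Context for the two related changes (zero_stage=2 above and the ctx guard in grpo_consumer.py below): a minimal sketch, not part of the patch, of how the sync context is chosen. It uses only what the diff itself relies on (nullcontext, booster.plugin.zero_stage, booster.no_sync); the helper name select_sync_context is made up for illustration. Per the in-code comment, ZeRO-2 must synchronize gradients on every backward pass, so no_sync cannot be used on accumulation steps.

from contextlib import nullcontext

def select_sync_context(booster, model, optimizer, need_update: bool):
    # Mirrors the guard in grpo_consumer.py: skip gradient sync only when we
    # are still accumulating AND the plugin is not ZeRO-2 (ZeRO-2 requires
    # gradient synchronization on every backward pass).
    if need_update or booster.plugin.zero_stage == 2:
        return nullcontext()
    return booster.no_sync(model, optimizer)
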
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 1c0773f4e..4174f9651 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -131,40 +132,16 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
-        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+        ctx = (
+            nullcontext()
+            if need_update or self.booster.plugin.zero_stage == 2
+            else self.booster.no_sync(self.policy_model, self.optimizer)
+        )
         with ctx:
-            policy_model_logits = self.policy_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
-            action_log_probs = calc_action_log_probs(
-                policy_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            with torch.no_grad():
-                reference_model_logits = self.reference_model(
-                    input_ids=data["input_ids"],
-                    attention_mask=data["attention_mask"],
-                )["logits"]
-                reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / self.generate_config["temperature"],
-                    data["input_ids"],
-                    num_action,
-                    self.plugin.shard_config,
-                )
-
-            per_token_kl = (
-                torch.exp(reference_action_log_probs - action_log_probs)
-                - (reference_action_log_probs - action_log_probs)
-                - 1
-            )
-            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
-
             reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
@@ -177,6 +154,11 @@ class GRPOConsumer(BaseConsumer):
 
             group_reward = reward.view(-1, self.num_generations)
             reward_mean = group_reward.mean(dim=1)
+            # [batch_size x num_generations]
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
@@ -185,35 +167,82 @@ class GRPOConsumer(BaseConsumer):
                     reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
                 ).repeat_interleave(self.num_generations, dim=0)
             )
+            mean_kl, mean_loss = [], []
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                input_ids_forward_micro_batch = data["input_ids"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                attention_mask_forward_micro_batch = data["attention_mask"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                action_mask_forward_micro_batch = action_mask[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                loss_mask_forward_micro_batch = (
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    if loss_mask is not None
+                    else None
+                )
+                advantages_forward_micro_batch = advantages[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                policy_model_logits = self.policy_model(
+                    input_ids=input_ids_forward_micro_batch,
+                    attention_mask=attention_mask_forward_micro_batch,
+                ).logits
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
-            # [batch_size x num_generations]
-            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
-            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-            # [batch_size x num_generations]
-            advantages = (reward - reward_mean) / (reward_std + 1e-4)
+                with torch.no_grad():
+                    reference_model_logits = self.reference_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                    reference_action_log_probs = calc_action_log_probs(
+                        reference_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
-            loss, skip_update, _ = self.policy_loss_fn(
-                action_log_probs,
-                old_action_log_probs,
-                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
-                per_token_kl,
-                action_mask,
-                loss_mask=loss_mask,
-            )
+                per_token_kl = (
+                    torch.exp(reference_action_log_probs - action_log_probs)
+                    - (reference_action_log_probs - action_log_probs)
+                    - 1
+                )
+                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                    action_mask_forward_micro_batch, dim=-1
+                )
+
+                loss, skip_update, _ = self.policy_loss_fn(
+                    action_log_probs,
+                    old_action_log_probs,
+                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    per_token_kl,
+                    action_mask_forward_micro_batch,
+                    loss_mask=loss_mask_forward_micro_batch,
+                )
+
+                if not skip_update:
+                    self.booster.backward(loss, self.optimizer)
+                loss = all_reduce_mean(loss, self.plugin)
+                kl = all_reduce_mean(kl.mean(), self.plugin)
+                # Calculate accumulate value.
+                mean_kl.append(kl.data)
+                mean_loss.append(loss.data)
-            if not skip_update:
-                self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
-            kl = all_reduce_mean(kl.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            # Calculate accumulate value.
-            self.accum_loss.add_(loss.data)
+            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
             self.accum_reward.add_(reward.data)
-            self.accum_kl.add_(kl.data)
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
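
To make the rewritten training step easier to follow, here is a self-contained sketch, illustrative only: it uses synthetic tensors and a placeholder loss instead of policy_loss_fn. It shows the two ideas the new step combines, group-relative advantage normalization per prompt, and splitting the mini-batch into micro-batches whose per-slice losses are averaged the way mean_loss/mean_kl are in the diff.

import torch

num_generations = 4
# 2 prompts x 4 generations = 8 samples, ordered prompt-major, matching the
# reward.view(-1, num_generations) call in the diff.
reward = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0])

group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

# Micro-batched pass: each slice would run its own forward/backward; here we
# only average a placeholder per-slice loss for logging.
forward_batch_size = 2
mean_loss = []
for start in range(0, reward.size(0), forward_batch_size):
    advantages_micro_batch = advantages[start : start + forward_batch_size]
    loss = advantages_micro_batch.pow(2).mean()  # stand-in for policy_loss_fn
    mean_loss.append(loss)
print(sum(mean_loss) / len(mean_loss))
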
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index c50db1378..ba5d3a9d4 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -34,6 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_microbatch_size: int,
+    train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
@@ -99,9 +100,13 @@ def launch_distributed(
            batch_size=train_batch_size,
            model_config=train_model_config,
            plugin_config=plugin_config,
-            microbatch_size=train_microbatch_size,
+            microbatch_size=train_minibatch_size,
            generate_config=generate_config_consumer,
-            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            training_config={
+                "filter_range": [0.05, 9.0],
+                "lr": 1e-6,
+                "train_microbatch_size": train_microbatch_size,
+            },
            num_generations=num_generations,
        )
        procs.append(consumer)
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 51a1af332..2c6a24a36 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -100,6 +100,7 @@ class BaseProducer:
                 if i >= num_valid_microbatches:
                     break
                 outputs = self.rollout(**batch)
+                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
 
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
@@ -116,16 +117,19 @@ class BaseProducer:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
+
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    del state_dict
+                    torch.cuda.empty_cache()
 
                 # linear annealing for 1 episode, temperature from initial to 0.7
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    self.model.generate_config.temperature = (
-                        ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
-                    )
+                    self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.7
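
The last hunk above flips the annealing ratio so generation starts at the configured temperature and decays toward 0.7 over the first episode (the old expression started at 0.7 and warmed up toward the configured value instead). A quick standalone check of the corrected schedule; initial_temperature and the batch count are illustrative values, not taken from the config:

initial_temperature = 0.9  # e.g. generate_config["temperature"] in rl_example.py
num_batches = 10           # stands in for len(self.dataloader)

for i in range(num_batches):
    ratio = 1 - (num_batches - i) / num_batches
    temperature = (1 - ratio) * initial_temperature + ratio * 0.7
    print(f"batch {i}: temperature={temperature:.3f}")
# batch 0 -> 0.900 (initial value), batch 9 -> 0.720 (approaching 0.7)
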
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1) + parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() + assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0" + assert ( + args.train_minibatch_size * args.num_generations >= args.train_microbatch_size + and args.train_microbatch_size > 0 + ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" + ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model) + train_model_config = dict( + path=args.model, + # use_flash_attention_2=True, + # use_cache=False + ) generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": @@ -31,13 +43,6 @@ if __name__ == "__main__": torch_dtype=torch.bfloat16, ) ) - train_model_config.update( - dict( - use_flash_attention_2=True, - torch_dtype=torch.bfloat16, - use_cache=False, - ) - ) generate_config.update( dict( max_length=1024 + 512, @@ -78,15 +83,17 @@ if __name__ == "__main__": inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, + train_minibatch_size=args.train_minibatch_size, train_microbatch_size=args.train_microbatch_size, dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, + num_generations=args.num_generations, train_model_config=train_model_config, plugin_config={}, inference_backend=args.backend, master_addr="localhost", - master_port=29503, + master_port=29505, core_algo=args.algo, )