diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index f4174261a..a282439cb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
         use_wandb=True,
         generate_config=None,
         training_config={},
+        project_name=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_count = 0
         self.generate_config = generate_config
         self.training_config = training_config
+        self.project_name = project_name

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -111,7 +113,7 @@ class GRPOConsumer(BaseConsumer):
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -239,6 +241,8 @@ class GRPOConsumer(BaseConsumer):
                     "source": self.rank,
                 }

+                kl = []
+
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
                     action_log_probs = calc_action_log_probs(
@@ -252,6 +256,10 @@ class GRPOConsumer(BaseConsumer):
                         - (inputs["reference_action_log_probs"] - action_log_probs)
                         - 1
                     )
+                    approx_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+                        inputs["action_mask"], dim=-1
+                    )
+                    kl.append(approx_kl.mean())
                     loss, skip_update, _ = self.policy_loss_fn(
                         action_log_probs,
                         action_log_probs,
@@ -273,26 +281,11 @@ class GRPOConsumer(BaseConsumer):
                 loss = policy_model_outputs["loss"]

                 if self.booster.plugin.stage_manager.is_last_stage():
-                    # calculate kl, as we cannot do this inside callback, kl needs be calculate again
-                    action_logits = policy_model_outputs["outputs"]["logits"]
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
-                    per_token_kl = (
-                        torch.exp(reference_action_log_probs - action_log_probs)
-                        - (reference_action_log_probs - action_log_probs)
-                        - 1
-                    )
-                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                        action_mask_forward_micro_batch, dim=-1
-                    )
-                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    if len(kl) > 0:
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        mean_kl.append(kl)
                     loss = all_reduce_mean(loss, self.plugin)
                     mean_loss.append(loss.data)
-                    mean_kl.append(kl)

             else:
                 policy_model_logits = self.policy_model(
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index ba5d3a9d4..699d90a8c 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bf7a657e5..f87f12ed2 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -11,32 +11,41 @@ if __name__ == "__main__":
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument(
-        "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step."
+        "-ibs",
+        "--inference-batch-size",
+        type=int,
+        default=64,
+        help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
         "-imbs",
         "--inference-microbatch-size",
         type=int,
         default=8,
-        help="Number of prompts to send from the producer to the consumer.",
+        help="Effective batch size for the inference backend to run generation. Please select based on memory constraints.",
     )
     parser.add_argument(
-        "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model."
+        "-tbs",
+        "--train-batch-size",
+        type=int,
+        default=32,
+        help="Number of unique prompts used to update the policy model per step per dp group. Gradients are accumulated over tbs * dp_size unique prompts, i.e. tbs * g * dp_size samples.",
     )
     parser.add_argument(
         "-tMbs",
         "--train-minibatch-size",
         type=int,
         default=1,
-        help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
+        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before the training forward pass starts. tMbs * g must be at least tmbs.",
     )
     parser.add_argument(
         "-tmbs",
         "--train-microbatch-size",
         type=int,
         default=2,
-        help="Number of samples per device. PP micro batchsize when PP is activated.",
+        help="Effective batch size per dp group for the forward and backward passes. Please select based on the available memory.",
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
@@ -119,6 +128,7 @@ if __name__ == "__main__":
         },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29505,
+        master_port=29506,
         core_algo=args.algo,
+        project_name=args.project,
     )
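For context, below is a minimal, self-contained sketch of the approximate-KL bookkeeping that the updated _criterion callback performs under pipeline parallelism: a k3-style per-token KL is averaged over masked action tokens, one scalar is appended per micro-batch, and the last stage averages the accumulated list before all-reducing it. The helper name masked_mean_approx_kl and the toy tensors are hypothetical; only the formula and the accumulation pattern mirror the diff above.

import torch


def masked_mean_approx_kl(ref_logprobs, logprobs, action_mask):
    # k3-style estimator used in the diff: exp(ref - cur) - (ref - cur) - 1 (non-negative).
    per_token_kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    # Average over valid action tokens only, then over the micro-batch.
    per_sample_kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
    return per_sample_kl.mean()


kl = []  # one scalar appended per micro-batch, as in _criterion
for _ in range(4):  # stand-in for pipeline micro-batches
    ref = torch.randn(2, 8)          # reference-model action log-probs (toy values)
    cur = ref + 0.05 * torch.randn(2, 8)  # policy-model action log-probs (toy values)
    mask = torch.ones(2, 8)          # action mask over response tokens
    kl.append(masked_mean_approx_kl(ref, cur, mask))

# The last pipeline stage stacks and averages the accumulated scalars,
# then all-reduces the result across the data-parallel group for logging.
mean_kl = torch.mean(torch.stack(kl))
print(mean_kl)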