diff --git a/.gitignore b/.gitignore
index 8bc74b4c8..16f764c1b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,4 @@ coverage.xml
 # log, test files - ColossalChat
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
+applications/ColossalChat/wandb
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 8711d0b8c..58414b29f 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -61,12 +61,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
+        self.num_generations = 8
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+        gt_answer = None
+        if "gt_answer" in kwargs:
+            gt_answer = kwargs.pop("gt_answer")
+        if self.num_generations > 1:
+            input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
+            attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        out = self.model.generate(
+            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+        )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]
         # get log probs
@@ -76,10 +86,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
 
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
@@ -91,7 +104,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
+
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+        if gt_answer is not None:
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
 
     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 30d56c90b..1de8b649d 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -10,9 +10,9 @@ if __name__ == "__main__":
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=32)
-    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
-    parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
+    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
+    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
+    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
@@ -24,29 +24,31 @@ if __name__ == "__main__":
     train_model_config = dict(path=args.model)
     generate_config = dict(
         top_k=50,
-        top_p=0.8,
+        top_p=0.9,
+        temperature=1.0,
     )
 
     if args.backend == "transformers":
         inference_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
             )
         )
         train_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
                 use_cache=False,
             )
         )
         generate_config.update(
             dict(
-                max_length=512,
+                max_length=1024 + 512,
                 do_sample=True,
                 max_new_tokens=None,
                 early_stopping=False,
+                stop_strings=[""],
             )
         )
     elif args.backend == "vllm":
@@ -82,12 +84,12 @@ if __name__ == "__main__":
         num_producers=args.num_inferencer,
         num_proc_per_producer=1,
         num_consumer_procs=args.num_trainers,
-        num_episodes=1,
+        num_episodes=10,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_microbatch_size=args.train_microbatch_size,
-        dataset_config={"path": args.dataset, "max_length": 256},
+        dataset_config={"path": args.dataset, "max_length": 300},
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,