diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 03971e255..7c7eeb0bb 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -69,8 +69,8 @@ class GRPOConsumer(BaseConsumer): enable_profiling=enable_profiling, n_behind=n_behind, ) - path = model_config.pop("path") - self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.path = model_config.pop("path") + self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) @@ -98,12 +98,7 @@ class GRPOConsumer(BaseConsumer): loss_variation=grpo_config.get("loss_variation", "sample_level"), ) - # Reference model is initialized from policy model. - if self.policy_loss_fn.beta > 0: - self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) - self.reference_model.eval() - - self.tokenizer = AutoTokenizer.from_pretrained(path) + self.tokenizer = AutoTokenizer.from_pretrained(self.path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations self.filter_range = grpo_config.get("filter_range", None) @@ -148,7 +143,10 @@ class GRPOConsumer(BaseConsumer): self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler ) + # Reference model is initialized from policy model. if self.policy_loss_fn.beta > 0: + self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config) + self.reference_model.eval() self.reference_model, *_ = self.booster.boost(self.reference_model) self.plugin.logger.set_level("ERROR") diff --git a/examples/language/qwen2/benchmark.py b/examples/language/qwen2/benchmark.py index d37132fd2..d446188b4 100644 --- a/examples/language/qwen2/benchmark.py +++ b/examples/language/qwen2/benchmark.py @@ -53,7 +53,7 @@ def main(): # ============================== parser = argparse.ArgumentParser() parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") - parser.add_argument("-model", "--model_path", type=str, help="Model path") + parser.add_argument("--model_path", type=str, help="Model path") parser.add_argument( "-p", "--plugin", @@ -85,6 +85,7 @@ def main(): parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument("--cpu_offload", action="store_true", help="Cpu offload") parser.add_argument( "--nsys", action="store_true", @@ -142,6 +143,7 @@ def main(): pp_style=args.pp_style, num_model_chunks=args.n_chunks, zero_stage=args.zero, + cpu_offload=args.cpu_offload, sp_size=args.sp, sequence_parallelism_mode=args.sp_mode, enable_sequence_parallelism=args.sp > 1, @@ -204,7 +206,11 @@ def main(): ) model = Qwen2ForCausalLM.from_pretrained( - MODEL_PATH, trust_remote_code=True, use_flash_attention_2=False, use_cache=False, attn_implementation="eager" + args.model_path, + trust_remote_code=True, + use_flash_attention_2=False, + use_cache=False, + attn_implementation="eager", ) if args.grad_checkpoint: model.gradient_checkpointing_enable() diff --git a/examples/language/qwen2/hybrid_test_N1C8.sh b/examples/language/qwen2/hybrid_test_N1C8.sh index 36919901d..7e579ed6d 100644 --- a/examples/language/qwen2/hybrid_test_N1C8.sh +++ b/examples/language/qwen2/hybrid_test_N1C8.sh @@ -6,5 +6,14 @@ export OMP_NUM_THREADS=8 -#hybird: zero2+flash_atten+grad_ckpt+bs4 -colossalai run --nproc_per_node 8 benchmark.py -m "/home/grpo/models/Qwen2.5-7B/" -p "3d" -x -g --zero 1 -b 32 --mbs 1 --tp 2 --pp 2 -l 4096 +colossalai run --nproc_per_node 8 benchmark.py \ + --model_path "/home/grpo/models/DeepSeek-R1-Distill-Qwen-7B/" \ + -p "3d" \ + -x -g \ + --zero 1 \ + --cpu_offload \ + -b 16 --mbs 1 \ + --tp 4 --pp 2 \ + -l 4096 \ + -s 3 \ + &>qwen2_7b.log &