[feat] Update consumer init to run 32B, update qwen benchmark.

This commit is contained in:
xysheng-colossal 2025-07-21 11:26:18 +08:00
parent ad1ceb0424
commit 40310bdd5e
3 changed files with 25 additions and 12 deletions

View File

@ -69,8 +69,8 @@ class GRPOConsumer(BaseConsumer):
enable_profiling=enable_profiling,
n_behind=n_behind,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
@ -98,12 +98,7 @@ class GRPOConsumer(BaseConsumer):
loss_variation=grpo_config.get("loss_variation", "sample_level"),
)
# Reference model is initialized from policy model.
if self.policy_loss_fn.beta > 0:
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.tokenizer = AutoTokenizer.from_pretrained(self.path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
self.filter_range = grpo_config.get("filter_range", None)
@ -148,7 +143,10 @@ class GRPOConsumer(BaseConsumer):
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
# Reference model is initialized from policy model.
if self.policy_loss_fn.beta > 0:
self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config)
self.reference_model.eval()
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")

View File

@ -53,7 +53,7 @@ def main():
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
parser.add_argument("-model", "--model_path", type=str, help="Model path")
parser.add_argument("--model_path", type=str, help="Model path")
parser.add_argument(
"-p",
"--plugin",
@ -85,6 +85,7 @@ def main():
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument("--cpu_offload", action="store_true", help="Cpu offload")
parser.add_argument(
"--nsys",
action="store_true",
@ -142,6 +143,7 @@ def main():
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
cpu_offload=args.cpu_offload,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
@ -204,7 +206,11 @@ def main():
)
model = Qwen2ForCausalLM.from_pretrained(
MODEL_PATH, trust_remote_code=True, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
args.model_path,
trust_remote_code=True,
use_flash_attention_2=False,
use_cache=False,
attn_implementation="eager",
)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()

View File

@ -6,5 +6,14 @@
export OMP_NUM_THREADS=8
#hybrid: zero2+flash_attn+grad_ckpt+bs4
colossalai run --nproc_per_node 8 benchmark.py -m "/home/grpo/models/Qwen2.5-7B/" -p "3d" -x -g --zero 1 -b 32 --mbs 1 --tp 2 --pp 2 -l 4096
colossalai run --nproc_per_node 8 benchmark.py \
--model_path "/home/grpo/models/DeepSeek-R1-Distill-Qwen-7B/" \
-p "3d" \
-x -g \
--zero 1 \
--cpu_offload \
-b 16 --mbs 1 \
--tp 4 --pp 2 \
-l 4096 \
-s 3 \
&>qwen2_7b.log &