[feat] Update consumer init to run 32B; update qwen benchmark.
parent ad1ceb0424
commit 40310bdd5e
@@ -69,8 +69,8 @@ class GRPOConsumer(BaseConsumer):
             enable_profiling=enable_profiling,
             n_behind=n_behind,
         )
-        path = model_config.pop("path")
-        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(self.path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
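The point of the two changed lines is that "path" is not a valid from_pretrained keyword, so it has to be popped out of model_config before the remaining entries are **-forwarded, and stashing it on self lets later init steps (the tokenizer and the now-deferred reference model) reuse it. A minimal standalone sketch of that pattern, with placeholder values rather than the repository's config:

    # Sketch only: placeholder path/config values, not taken from this commit.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_config = {"path": "/models/some-checkpoint", "trust_remote_code": True}

    path = model_config.pop("path")  # strip the non-HF key before **-forwarding
    model = AutoModelForCausalLM.from_pretrained(path, **model_config)
    tokenizer = AutoTokenizer.from_pretrained(path)  # same path reused later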
@@ -98,12 +98,7 @@ class GRPOConsumer(BaseConsumer):
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
         )

-        # Reference model is initialized from policy model.
-        if self.policy_loss_fn.beta > 0:
-            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-            self.reference_model.eval()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
         self.filter_range = grpo_config.get("filter_range", None)
@@ -148,7 +143,10 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
+        # Reference model is initialized from policy model.
+        if self.policy_loss_fn.beta > 0:
+            self.reference_model = AutoModelForCausalLM.from_pretrained(self.path, **self.model_config)
+            self.reference_model.eval()
+            self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")
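Taken together with the removal in the previous hunk, this moves reference-model construction to after the policy model has been boosted, and boosts the reference model as well. Presumably this is what lets a 32B run initialize: the frozen reference weights are loaded only when the KL term is active (beta > 0) and are handed to the booster right away, so two unwrapped full-size copies never have to coexist on a rank. A toy ordering sketch in plain PyTorch; load_model and wrap are stand-ins for AutoModelForCausalLM.from_pretrained(self.path, ...) and self.booster.boost(...), not ColossalAI APIs:

    # Toy sketch of the init ordering, not the repository's code.
    import torch.nn as nn

    def load_model() -> nn.Module:
        return nn.Linear(16, 16)        # stand-in for a large checkpoint load

    def wrap(model: nn.Module) -> nn.Module:
        return model                    # stand-in for booster.boost / sharding

    beta = 0.04                         # KL coefficient; 0 disables the reference model

    policy = wrap(load_model())         # 1) policy is loaded and wrapped first

    reference = None
    if beta > 0:                        # 2) only then is the frozen reference created
        reference = load_model().eval()
        reference = wrap(reference)     #    and immediately wrapped/sharded itself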
@@ -53,7 +53,7 @@ def main():
     # ==============================
     parser = argparse.ArgumentParser()
     parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
-    parser.add_argument("-model", "--model_path", type=str, help="Model path")
+    parser.add_argument("--model_path", type=str, help="Model path")
     parser.add_argument(
         "-p",
         "--plugin",
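Minor CLI cleanup in the benchmark: the non-standard single-dash alias -model is dropped, leaving only the conventional double-dash form. A tiny self-contained check of the surviving option; the example path is the one used in the launch script further below:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Model path")
    args = parser.parse_args(["--model_path", "/home/grpo/models/DeepSeek-R1-Distill-Qwen-7B/"])
    print(args.model_path)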
@@ -85,6 +85,7 @@ def main():
     parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
+    parser.add_argument("--cpu_offload", action="store_true", help="Cpu offload")
     parser.add_argument(
         "--nsys",
         action="store_true",
@@ -142,6 +143,7 @@ def main():
         pp_style=args.pp_style,
         num_model_chunks=args.n_chunks,
         zero_stage=args.zero,
+        cpu_offload=args.cpu_offload,
         sp_size=args.sp,
         sequence_parallelism_mode=args.sp_mode,
         enable_sequence_parallelism=args.sp > 1,
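The two hunks above are the full wiring for the new option: a store_true flag, so it defaults to False and existing runs are unaffected, forwarded verbatim as the cpu_offload keyword of the plugin configuration the benchmark builds. A self-contained sketch of that flow, with a plain dict standing in for the plugin constructor; exact offload semantics depend on the plugin and zero stage:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--zero", type=int, default=1)
    parser.add_argument("--cpu_offload", action="store_true", help="Cpu offload")

    args = parser.parse_args(["--zero", "1", "--cpu_offload"])

    # Stand-in for the real plugin constructor; only the keywords touched by
    # this commit are shown.
    plugin_kwargs = dict(
        zero_stage=args.zero,
        cpu_offload=args.cpu_offload,
    )
    print(plugin_kwargs)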
@@ -204,7 +206,11 @@ def main():
     )

     model = Qwen2ForCausalLM.from_pretrained(
-        MODEL_PATH, trust_remote_code=True, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
+        args.model_path,
+        trust_remote_code=True,
+        use_flash_attention_2=False,
+        use_cache=False,
+        attn_implementation="eager",
     )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
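The reshaped call reads the checkpoint location from --model_path instead of a hard-coded MODEL_PATH constant, one keyword per line. A minimal loading sketch under the same settings: the path is illustrative, and the legacy use_flash_attention_2 kwarg is left out here since attn_implementation already selects the attention kernel. use_cache=False drops the KV cache, which is unused in training forward passes and pairs with gradient checkpointing, while attn_implementation="eager" forces the plain attention path rather than flash/SDPA kernels.

    from transformers import Qwen2ForCausalLM

    # Illustrative path; any local Qwen2-family checkpoint works the same way.
    model = Qwen2ForCausalLM.from_pretrained(
        "/home/grpo/models/DeepSeek-R1-Distill-Qwen-7B/",
        trust_remote_code=True,
        use_cache=False,                 # no KV cache during training forward passes
        attn_implementation="eager",     # plain attention, no flash/SDPA kernels
    )
    model.gradient_checkpointing_enable()  # pairs with use_cache=False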
@@ -6,5 +6,14 @@

 export OMP_NUM_THREADS=8

 #hybird: zero2+flash_atten+grad_ckpt+bs4
-colossalai run --nproc_per_node 8 benchmark.py -m "/home/grpo/models/Qwen2.5-7B/" -p "3d" -x -g --zero 1 -b 32 --mbs 1 --tp 2 --pp 2 -l 4096
+colossalai run --nproc_per_node 8 benchmark.py \
+    --model_path "/home/grpo/models/DeepSeek-R1-Distill-Qwen-7B/" \
+    -p "3d" \
+    -x -g \
+    --zero 1 \
+    --cpu_offload \
+    -b 16 --mbs 1 \
+    --tp 4 --pp 2 \
+    -l 4096 \
+    -s 3 \
+    &>qwen2_7b.log &