Support resuming training from a checkpoint (still buggy)

YeAnbang 2025-07-30 18:13:53 +08:00
parent 5c5cb1863b
commit 65f5289e35
16 changed files with 368 additions and 124 deletions

View File

@ -55,6 +55,7 @@ class BaseConsumer:
self.enable_profiling = enable_profiling
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // minibatch_size
self.checkpoint_path = model_config.pop("checkpoint_path", None)
self.model_config = model_config
self.plugin_config = plugin_config
@ -143,6 +144,26 @@ class BaseConsumer:
return effective_group_to_raw_group_mapping
def loop(self) -> None:
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
@ -286,7 +307,7 @@ class BaseConsumer:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
- self.booster.save_model(self.policy_model, save_path, shard=True)
+ self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
@ -365,7 +386,7 @@ class SimpleConsumer(BaseConsumer):
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.model.train()
self.model.gradient_checkpointing_enable()
- self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
+ self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
self.accum_loss = torch.zeros(1, device=self.device)
def setup(self):

View File

@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
- self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
+ self.optimizer = HybridAdam(
self.policy_model.parameters(),
lr=grpo_config.get("lr", 1e-6),
weight_decay=grpo_config.get("weight_decay", 0.01),
)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_entropy = torch.zeros(1, device=self.device)
@ -153,6 +157,8 @@ class GRPOConsumer(BaseConsumer):
)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
if self.checkpoint_path is not None:
self.booster.load_model(self.policy_model, self.checkpoint_path)
self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

View File

@ -55,7 +55,6 @@ def launch_distributed(
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_save_dir: str = "./rollout",
enable_profiling: bool = False,
@ -139,7 +138,6 @@ def launch_distributed(
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,

View File

@ -203,6 +203,28 @@ class BaseProducer:
raise NotImplementedError
def loop(self) -> None:
torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
del state_dict
torch.cuda.empty_cache()
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches

View File

@ -25,7 +25,12 @@ from latex2sympy2_extended import NormalizationConfig
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from .code_reward.utils import check_correctness_code_api as check_correctness_code
- from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
+ from .reward_utils import (
extract_boxed_solution,
extract_solution,
find_infinite_loop_start,
validate_response_structure,
)
CANNOT_PARSE_GT_ANSWER = -1
CANNOT_PARSE_PREDICTION = -2
@ -122,6 +127,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
repetition_reward = 1.0 if detect_repetition(decoded_final_answer) == [] else 0.0
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
@ -137,6 +144,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
if format_valid:
format_acc += 1
# Add repetition reward
if not eval_mode:
reward += repetition_reward
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
@ -182,6 +193,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
raise ValueError("no gt_answer is provided, please check your training dataset.") raise ValueError("no gt_answer is provided, please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
print(f"Decoded final answer: {decoded_final_answer[-500:]}")
repetition_score = find_infinite_loop_start(input_ids[s : e + 1], min_repeats=2, distance=False)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
@ -202,6 +215,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
if format_valid:
format_acc += 1
if not repetition_score > 0 and not eval_mode:
# award for non-repetition
reward += 2
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
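Taken together, the two changes above make repetition part of the training signal: find_infinite_loop_start returns the fraction of the response trapped in a trailing loop, and a response only earns the +2 bonus when that fraction is zero and eval mode is off. A rough illustration of the combined logic follows; the 1.0/9.0 base values for format and accuracy are assumptions for the sketch, since those terms come from parts of the function not shown in this diff.

def sketch_boxed_reward(format_valid, answer_correct, repetition_score, res_length, max_new_tokens, eval_mode=False):
    # Illustrative recombination of the reward terms visible in this diff; only the
    # repetition bonus and the over-length cutoff mirror the added lines exactly.
    reward = 0.0
    if format_valid:
        reward += 1.0  # assumed format reward
        if answer_correct:
            reward += 9.0  # assumed accuracy reward
    if not repetition_score > 0 and not eval_mode:
        reward += 2.0  # non-repetition bonus added by this commit
    if not eval_mode and res_length >= max_new_tokens:
        reward *= 0.0  # over-length responses are zeroed out during training
    return reward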

View File

@ -14,7 +14,9 @@
# limitations under the License.
import re
- from typing import Dict, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple
import torch
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
@ -122,3 +124,51 @@ def extract_boxed_solution(text: str) -> Optional[str]:
except Exception:
# Any other unexpected error
return None
import Levenshtein
def is_similar(seq1: torch.Tensor, seq2: torch.Tensor, threshold: float = 0.9) -> bool:
# Convert token-id tensors to plain Python lists so Levenshtein.ratio can compare them
ratio = Levenshtein.ratio(seq1.tolist(), seq2.tolist())
return ratio >= threshold
def find_infinite_loop_start(token_ids: torch.Tensor, min_repeats: int = 2, distance: bool = False) -> float:
# Detect a phrase repeated at least `min_repeats` times at the end of `token_ids` and
# return the fraction of the sequence covered by that trailing loop (0.0 if none is found)
n = len(token_ids)
# Step 1: Detect the repeating segment at the end using two pointers
longest_valid_length = 0
start_of_loop = n
for length in range(1, n // min_repeats + 1): # Try different phrase lengths
count = 1 # Reset repetition counter
right = n - length # Start comparing from the second last occurrence
while right - length >= 0:
# Check if the current phrase matches the previous phrase
if distance:
if is_similar(token_ids[right - length : right], token_ids[right : right + length]):
count += 1
else:
break # Stop if repetition is broken
else:
# Use torch.equal() for tensor comparison
if torch.equal(token_ids[right - length : right], token_ids[right : right + length]):
count += 1
else:
break # Stop if repetition is broken
right -= length # Move left to check further
if count >= min_repeats: # Found a valid repeating phrase
longest_valid_length = length
start_of_loop = right # This is where the first cycle of the repetition begins
if longest_valid_length == 0:
return 0.0 # No infinite loop found, return repetition ratio as 0
# Step 2: Compute the repetition ratio
repetition_ratio = (n - start_of_loop) / n
return repetition_ratio
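As a quick sanity check of the detector above, the snippet below (illustrative only, not part of the commit; the import path is hypothetical) builds a token sequence that ends in a repeated 3-token phrase and reports how much of it the loop covers, matching how the reward functions call it with a tensor of token ids and distance=False.

import torch
from reward_utils import find_infinite_loop_start  # hypothetical import path, for illustration only

# 10 ordinary tokens followed by the phrase (7, 8, 9) repeated 4 times
token_ids = torch.tensor([1, 2, 3, 4, 5, 6, 1, 2, 3, 4] + [7, 8, 9] * 4)

ratio = find_infinite_loop_start(token_ids, min_repeats=2, distance=False)
print(f"repetition ratio: {ratio:.2f}")  # the loop covers 12 of 22 tokens, so roughly 0.55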

View File

@ -55,7 +55,7 @@ class BaseConsumer:
self.num_microbatches = batch_size // minibatch_size
self.data_uid = 0
self.sync_model_thread_started = False
self.checkpoint_path = model_config.pop("checkpoint_path", None)
self.model_config = model_config
self.plugin_config = plugin_config
@ -79,6 +79,7 @@ class BaseConsumer:
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.pp_rank = dist.get_rank(self.plugin.pp_group)
@ -165,8 +166,93 @@ class BaseConsumer:
desc=f"Episode {episode} with rollout step(s)", desc=f"Episode {episode} with rollout step(s)",
disable=self.rank != 0, disable=self.rank != 0,
) as pbar: ) as pbar:
need_sync_model = True
while self.received_prompts < self.train_dataset_size:
torch.cuda.reset_peak_memory_stats()
if need_sync_model and (
(self.global_step + 1) % self.save_interval == 0
or self.received_prompts >= self.train_dataset_size
):
if self.rank == 0:
print(f"Start saving policy model at step {self.global_step + 1}.")
save_path = os.path.join(
self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}"
)
self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}")
if need_sync_model and (
episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
):
def sync_model_thread():
# sync model weights to all producers, if no model update or it is the last training step, skip syncing
if self.pp_size > 1:
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
torch.cuda.empty_cache()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
self.profiler.enter("sync_model")
ray.get(
self.shared_signal_actor.set_signal.remote(
f"consumer_pp_{self.pp_rank}", "ready_sync_model"
)
)
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
)
ray_broadcast_tensor_dict(
self.state_dict_cpu,
src=0,
device=torch.device("cpu"),
group_name=f"sync_model_consumer_pp_{self.pp_rank}",
backend="gloo",
)
self.profiler.exit("sync_model")
else:
if self.rank == 0:
self.profiler.enter("sync_model")
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
ray_broadcast_tensor_dict(
self.state_dict_cpu,
src=0,
device=torch.device("cpu"),
group_name="sync_model_consumer",
backend="gloo",
)
self.profiler.exit("sync_model")
if not self.sync_model_thread_started:
# only sync model when the thread is not started and no other thread is broadcasting
self.sync_model_thread_started = True
state_dict_ = self.state_dict()
if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
self.pp_size == 1 and self.rank == 0
):
if len(self.state_dict_cpu) == 0:
# use pinned memory to speed up the transfer
self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
torch.cuda.synchronize()
for k, v in state_dict_.items():
self.state_dict_cpu[k].copy_(v, non_blocking=True)
torch.cuda.synchronize()
cc.barrier(
group_name="consumer_pg"
) # to make sure all ranks have state dict offloaded to CPU before starting the thread
time_before_starting_thread = time.time()
threading.Thread(target=sync_model_thread).start()
# sync_model_thread()
self.profiler.log(
f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
)
self.sync_model_thread_started = False
# ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
effective_group_to_raw_group_mapping = {}
self.profiler.enter(f"recv_data")
while len(effective_group_to_raw_group_mapping) < self.dp_size * self.minibatch_size:
@ -254,90 +340,6 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
need_sync_model = True
ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
if need_sync_model and (
(self.global_step + 1) % self.save_interval == 0
or self.received_prompts >= self.train_dataset_size
):
if self.rank == 0:
print(f"Start saving policy model at step {self.global_step + 1}.")
save_path = os.path.join(
self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}"
)
self.booster.save_model(self.policy_model, save_path, shard=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}")
if need_sync_model and (
episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
):
def sync_model_thread():
# sync model weights to all producers, if no model update or it is the last training step, skip syncing
if self.pp_size > 1:
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
torch.cuda.empty_cache()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
self.profiler.enter("sync_model")
ray.get(
self.shared_signal_actor.set_signal.remote(
f"consumer_pp_{self.pp_rank}", "ready_sync_model"
)
)
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
)
ray_broadcast_tensor_dict(
self.state_dict_cpu,
src=0,
device=torch.device("cpu"),
group_name=f"sync_model_consumer_pp_{self.pp_rank}",
backend="gloo",
)
self.profiler.exit("sync_model")
else:
if self.rank == 0:
self.profiler.enter("sync_model")
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
ray_broadcast_tensor_dict(
self.state_dict_cpu,
src=0,
device=torch.device("cpu"),
group_name="sync_model_consumer",
backend="gloo",
)
self.profiler.exit("sync_model")
if not self.sync_model_thread_started:
# only sync model when the thread is not started and no other thread is broadcasting
self.sync_model_thread_started = True
state_dict_ = self.state_dict()
if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
self.pp_size == 1 and self.rank == 0
):
if len(self.state_dict_cpu) == 0:
# use pinned memory to speed up the transfer
self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
torch.cuda.synchronize()
for k, v in state_dict_.items():
self.state_dict_cpu[k].copy_(v, non_blocking=True)
torch.cuda.synchronize()
cc.barrier(
group_name="consumer_pg"
) # to make sure all ranks have state dict offloaded to CPU before starting the thread
time_before_starting_thread = time.time()
threading.Thread(target=sync_model_thread).start()
# sync_model_thread()
self.profiler.log(
f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
)
self.sync_model_thread_started = False
# ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
self.received_prompts = 0 self.received_prompts = 0
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate")) ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))

View File

@ -20,7 +20,7 @@ class Distributor:
enable_profiling: bool = True,
):
self.distributor_id = distributor_id
- self.weight_version = [0] * consumer_pp_size
+ self.weight_version = [-1] * consumer_pp_size
self.consumer_pp_size = consumer_pp_size
self.state_dict_cpu = {}
self.num_producers = num_producers

View File

@ -75,7 +75,11 @@ class GRPOConsumer(BaseConsumer):
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.vocab_size = self.policy_model.config.vocab_size
- self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
+ self.optimizer = HybridAdam(
self.policy_model.parameters(),
lr=grpo_config.get("lr", 1e-6),
weight_decay=grpo_config.get("weight_decay", 0.01),
)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_entropy = torch.zeros(1, device=self.device)
@ -157,6 +161,9 @@ class GRPOConsumer(BaseConsumer):
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
if self.checkpoint_path is not None:
print(f"Resume training from checkpoint: {self.checkpoint_path}")
self.booster.load_model(self.policy_model, self.checkpoint_path)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
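For readers following the resume feature end to end: --checkpoint-path is stashed into model_config by rl_example.py, popped back out in BaseConsumer.__init__, and applied here after booster.boost(). Below is a minimal sketch of that flow, assuming the ColossalAI Booster/HybridAdam APIs used elsewhere in this diff; the helper function itself is illustrative and not part of the repository.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import AutoModelForCausalLM

def build_policy_model(model_config: dict, lr: float = 1e-6, weight_decay: float = 0.01):
    # checkpoint_path is carried inside model_config so it never reaches from_pretrained()
    checkpoint_path = model_config.pop("checkpoint_path", None)
    model = AutoModelForCausalLM.from_pretrained(model_config.pop("path"), **model_config)
    optimizer = HybridAdam(model.parameters(), lr=lr, weight_decay=weight_decay)

    booster = Booster(plugin=HybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1))
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Only the model weights are restored; optimizer, scheduler, and dataloader state
    # are not reloaded by this commit.
    if checkpoint_path is not None:
        print(f"Resume training from checkpoint: {checkpoint_path}")
        booster.load_model(model, checkpoint_path)
    return model, optimizer, booster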

View File

@ -82,7 +82,7 @@ class BaseProducer:
self.eval_interval = eval_interval
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
- self.producer_weight_version = 0
+ self.producer_weight_version = -1
self.eval_mode = False
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
@ -148,19 +148,38 @@ class BaseProducer:
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None:
self.eval_dataloaders = {}
self.eval_sample_params = {}
for eval_task_name in self.eval_dataset_config:
eval_sampling_params = copy.deepcopy(generate_config)
eval_sampling_params["n"] = eval_dataset_config[eval_task_name].pop(
"n", 1
) # use 1 generation for evaluation
if "max_tokens" in eval_sampling_params:
eval_sampling_params["max_tokens"] = eval_dataset_config[eval_task_name].pop(
"max_tokens", eval_sampling_params["max_tokens"]
)
elif "max_new_tokens" in eval_sampling_params:
eval_sampling_params["max_new_tokens"] = eval_dataset_config[eval_task_name].pop(
"max_new_tokens", eval_sampling_params["max_new_tokens"]
)
eval_sampling_params["temperature"] = eval_dataset_config[eval_task_name].pop("temperature", 0.6)
eval_sampling_params["logprobs"] = 0 # force parameter
self.eval_sample_params[eval_task_name] = SamplingParams(**eval_sampling_params)
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
eval_inference_batch_size = eval_dataset_config[eval_task_name].pop("batch_size", microbatch_size)
eval_dataset = RawConversationDataset(
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
)
- print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
+ print(
f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}, eval_batch_size: {eval_inference_batch_size} num_producers:{num_producers}"
)
self.eval_dataloaders[eval_task_name] = DataLoader(
eval_dataset,
- batch_size=microbatch_size,
+ batch_size=eval_inference_batch_size,
sampler=DistributedSampler(
eval_dataset,
num_replicas=num_producers,
@ -280,7 +299,17 @@ class BaseProducer:
self.sync_model_thread_started = False
distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
- f"distributor_weight_version", 0
+ f"distributor_weight_version", -1
)
print(f"[P{self.producer_idx}] Distributor weight version: {distributor_weight_version}")
is_initial_sync = False
while distributor_weight_version < 0:
# wait for distributor to set weight version
is_initial_sync = True
print(f"[P{self.producer_idx}] Waiting for intial model sync before training")
time.sleep(1)
distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
f"distributor_weight_version", -1
)
if (
not self.sync_model_thread_started
@ -291,6 +320,16 @@ class BaseProducer:
self.sync_model_thread = threading.Thread(target=sync_model_thread)
self.producer_weight_version = distributor_weight_version
self.sync_model_thread.start()
if is_initial_sync:
# wait till initial model sync is done and load the model state dict
while self.sync_model_thread_started:
print(f"[P{self.producer_idx}] Waiting for model sync to finish before evaluation")
time.sleep(1)
for pp_idx in range(self.consumer_pp_size):
if self.state_dict_cpu[pp_idx] is not None and self.state_dict_cpu[pp_idx] != {}:
self.load_state_dict(self.state_dict_cpu[pp_idx])
print(f"[P{self.producer_idx}] loaded initial model state dict")
torch.cuda.empty_cache()
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
@ -314,7 +353,9 @@ class BaseProducer:
for eval_batch in tqdm.tqdm(
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
):
- eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+ eval_outputs = self.rollout(
**eval_batch, sample_params=copy.deepcopy(self.eval_sample_params[eval_task_name])
)
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
@ -343,6 +384,7 @@ class BaseProducer:
print(
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
)
# save eval results
safe_append_to_jsonl_file(
os.path.join(
@ -351,7 +393,7 @@ class BaseProducer:
),
eval_results,
)
raise ValueError()
if self.producer_idx == 0:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.eval_mode = False
@ -427,15 +469,15 @@ class BaseProducer:
self.load_state_dict(self.state_dict_cpu[pp_idx])
# linear annealing for 1 episode, temperature from initial to 0.9 (disabled by this commit)
# if episode <= 0:
# ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
# self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
# "temperature"
# ] + ratio * 0.9
# if isinstance(self.model, BACKEND_MAP["vllm"]):
# self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
# "temperature"
# ] + ratio * 0.9
def __del__(self):
self.profiler.close()
@ -464,7 +506,6 @@ class SimpleProducer(BaseProducer):
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
@ -500,10 +541,6 @@ class SimpleProducer(BaseProducer):
)
print("tokenizer_config", tokenizer_config)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):

View File

@ -11,3 +11,9 @@ rm -rf *.prof
MAX_NEW_TOKENS=$((4096-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png
python rl_example_zero_bubble.py --dataset /home/litong/workspace/data/rl-math-train-final.jsonl --model /home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B -t 4 -i 4 -b vllm -a DAPO -imbs 4 -ibs 8 -tbs 32 -e 2 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 2 -tmbs 2 -p DAPO-Experiment-Baseline -ei 15 -ed '{"Math_500_level_1": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_1.jsonl", "Math_500_level_2": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_2.jsonl", "Math_500_level_3": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_3.jsonl", "Math_500_level_4": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_4.jsonl", "Math_500_level_5": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_5.jsonl"}' -zero 1 -tp 2 -pp 2 -mpt 512 -mnt 15872 --data_actor_buffer_size_limit 48
python rl_example_zero_bubble.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B -t 4 -i 4 -b vllm -a DAPO -imbs 4 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 2 -tmbs 2 -p DAPO-Experiment-Baseline-V3 -ei 15 -ed '{"AIME2024": "/mnt/nfs/yeanbang/data/RLHF_data/AIME/2024/AIME2024.jsonl"}' -zero 1 -tp 2 -pp 2 -mpt 1024 -mnt 16384 --data_actor_buffer_size_limit 64 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_pp_2_16K_DAPO_profile_v3.txt

View File

@ -0,0 +1,9 @@
export https_proxy=http://vpn.luchentech.com:32171
export http_proxy=http://vpn.luchentech.com:32171
export all_proxy=socks5://vpn.luchentech.com:32170
# output is normal
# python rl_example.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ --checkpoint-path /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ -t 4 -i 4 -b vllm -a DAPO -imbs 8 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 8 -p DAPO-Experiment-Baseline-V3_1 -zero 1 -pp 2 -tp 2 -mpt 1024 -mnt 4096 -nb 1 --enable_profiling
# output is garbled
python rl_example.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ --checkpoint-path /mnt/nfs/yeanbang/experiments/rlhf/dapo/v3/model/DAPO-Experiment-Baseline-V3/modeling-episode-0-step-200 -t 4 -i 4 -b vllm -a DAPO -imbs 8 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 8 -p DAPO-Experiment-Baseline-V3_1 -zero 1 -pp 2 -tp 2 -mpt 1024 -mnt 4096 -nb 1 --enable_profiling

View File

@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
"-cp",
"--checkpoint-path",
type=str,
default=None,
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
)
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument( parser.add_argument(
"-ed", "-ed",
@ -227,7 +234,9 @@ if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
inference_model_config = dict(path=args.model) inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) train_model_config = dict(
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers":

View File

@ -261,7 +261,6 @@ if __name__ == "__main__":
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None, stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
) )
) )
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
else:
raise ValueError(f"Unsupported backend: {args.backend}")
@ -290,7 +289,7 @@ if __name__ == "__main__":
elif args.algo == "DAPO": elif args.algo == "DAPO":
# DAPO variant settings # DAPO variant settings
grpo_config = { grpo_config = {
"filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch "filter_range": [0.01, 0.8], # only filter out all zero batch and all one batch
"lr": args.learning_rate, "lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size, "train_microbatch_size": args.train_microbatch_size,
"dynamic_batching": True, "dynamic_batching": True,
@ -302,8 +301,8 @@ if __name__ == "__main__":
"soft_over_length_punishment": True, "soft_over_length_punishment": True,
"max_length": args.max_new_tokens + args.max_prompt_tokens, "max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens, "max_new_tokens": args.max_new_tokens,
"cache_length": min(1024, int(args.max_new_tokens / 4)), "cache_length": max(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True, "filter_truncated_response": False,
"reward_fn_type": args.reward_type, "reward_fn_type": args.reward_type,
"response_format_tags": ( "response_format_tags": (
{ {
@ -362,15 +361,22 @@ if __name__ == "__main__":
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
eval_dataset_config=(
{
- k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+ k: {
"path": v,
"max_length": args.max_prompt_tokens,
"system_prompt": args.system_prompt,
"batch_size": 1,
"n": 16,
"max_tokens": 32 * 1024 - 512,
"temperature": 0.6,
}
for k, v in json.loads(args.eval_dataset).items()
}
if args.eval_dataset
else None
- ),
+ ), # support using different evaluation prompts/parameters for each task
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_save_dir=args.rollout_save_dir,
enable_profiling=args.enable_profiling,
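For reference, with the hard-coded defaults above, a run launched with -ed '{"AIME2024": "/path/to/AIME2024.jsonl"}' hands BaseProducer a per-task config roughly like the sketch below (the dataset path and the max_length value are placeholders); the producer then pops n, max_tokens, temperature, and batch_size for its eval SamplingParams and dataloader, and forwards the remaining keys to RawConversationDataset.

eval_dataset_config = {
    "AIME2024": {
        "path": "/path/to/AIME2024.jsonl",  # taken from the -ed JSON mapping (placeholder)
        "max_length": 512,  # args.max_prompt_tokens (placeholder value)
        "system_prompt": "Please reason step by step, and put your final answer within \\boxed{}.",
        "batch_size": 1,  # per-producer eval inference batch size
        "n": 16,  # generations per eval prompt
        "max_tokens": 32 * 1024 - 512,  # eval generation budget
        "temperature": 0.6,
    }
}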

View File

@ -0,0 +1,54 @@
#!/bin/bash
# Run eval for every saved checkpoint
# Set the base directory where models are saved
BASE_DIR="/mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline"
BASE_DIR="/mnt/nfs/yeanbang/experiments/rlhf/dapo/v3/model/DAPO-Experiment-Baseline-V3"
DATASET="/home/litong/workspace/data/rl-math-train-final.jsonl"
EVAL_SCRIPT="rl_example_zero_bubble.py"
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-105
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-120
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-15
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-30
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-45
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-60
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-75
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-90
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-1-step-150
# Loop through each model directory matching the pattern
# 64K 65024
# 32K 32256
for model_path in "$BASE_DIR"/modeling-episode-0-step-150; do
if [ -d "$model_path" ]; then
echo "Evaluating model: $model_path"
python "$EVAL_SCRIPT" \
--dataset "$DATASET" \
--tokenizer-path "/home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B" \
--model "$model_path" \
-t 1 \
-i 6 \
-b vllm \
-a GRPO \
-imbs 1 \
-ibs 8 \
-tbs 8 \
-e 2 \
-g 16 \
-rt boxed \
-si 25 \
-s "Please reason step by step, and put your final answer within \\boxed{}." \
-tMbs 2 \
-tmbs 1 \
-p DAPO-Experiment-Baseline-Eval \
-ei 15 \
-ed '{"AIME2024": "/mnt/nfs/yeanbang/data/RLHF_data/AIME/2024/AIME2024.jsonl"}' \
-zero 2 \
-mpt 512 \
-mnt 65024 \
--data_actor_buffer_size_limit 4 | grep "Accuracy on"
echo "Finished evaluating: $model_path"
fi
done

View File

@ -824,7 +824,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- force_sp_output_gather=False,
+ # force_sp_output_gather=False,
)
hidden_states = outputs[0]