diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index e360392e7..5afc60070 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -55,6 +55,7 @@ class BaseConsumer: self.enable_profiling = enable_profiling assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size" self.num_microbatches = batch_size // minibatch_size + self.checkpoint_path = model_config.pop("checkpoint_path", None) self.model_config = model_config self.plugin_config = plugin_config @@ -143,6 +144,26 @@ class BaseConsumer: return effective_group_to_raw_group_mapping def loop(self) -> None: + self.profiler.enter("sync_model") + torch.cuda.empty_cache() + state_dict = self.state_dict() + if self.pp_size > 1: + if self.tp_rank == 0 and self.dp_rank == 0: + ray_broadcast_tensor_dict( + state_dict, + src=self.num_producers, + device=self.device, + group_name=f"sync_model_{self.pp_rank}", + ) + else: + if self.rank == 0: + ray_broadcast_tensor_dict( + state_dict, src=self.num_producers, device=self.device, group_name="sync_model" + ) + del state_dict + torch.cuda.empty_cache() + self.profiler.exit("sync_model") + print( f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}" ) @@ -286,7 +307,7 @@ class BaseConsumer: if self.rank == 0: print(f"Start saving policy model at step {step + 1}.") save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}") - self.booster.save_model(self.policy_model, save_path, shard=True) + self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True) if self.rank == 0: print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") @@ -365,7 +386,7 @@ class SimpleConsumer(BaseConsumer): self.model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.model.train() self.model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3) + self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01) self.accum_loss = torch.zeros(1, device=self.device) def setup(self): diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index a3f1a1cbb..c4cae3acd 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer): self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) + self.optimizer = HybridAdam( + self.policy_model.parameters(), + lr=grpo_config.get("lr", 1e-6), + weight_decay=grpo_config.get("weight_decay", 0.01), + ) self.accum_loss = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) self.accum_entropy = torch.zeros(1, device=self.device) @@ -153,6 +157,8 @@ class GRPOConsumer(BaseConsumer): ) if self.policy_loss_fn.beta > 0: self.reference_model, *_ = self.booster.boost(self.reference_model) + if self.checkpoint_path is not None: + self.booster.load_model(self.policy_model, self.checkpoint_path) self.plugin.logger.set_level("ERROR") def step(self, step_idx: int, pbar: Any, 
**kwargs) -> Optional[float]: diff --git a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py index de5b61353..ffea6f6bb 100644 --- a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py +++ b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py @@ -55,7 +55,6 @@ def launch_distributed( eval_dataset_config: Optional[Dict[str, Any]] = None, eval_interval: int = 100, eval_save_dir: Optional[str] = None, - eval_generation_config: Optional[Dict[str, Any]] = None, log_rollout_interval: int = 20, rollout_save_dir: str = "./rollout", enable_profiling: bool = False, @@ -139,7 +138,6 @@ def launch_distributed( eval_interval=eval_interval, grpo_config=grpo_config, eval_save_dir=eval_save_dir, - eval_generation_config=eval_generation_config, project_name=project_name, run_name=run_name, wandb_group_name=wandb_group_name, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 11fb5d3aa..a5f3bda04 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -203,6 +203,28 @@ class BaseProducer: raise NotImplementedError def loop(self) -> None: + + torch.cuda.empty_cache() + self.profiler.enter("sync_model") + if self.consumer_pp_size > 1: + for pp_idx in range(self.consumer_pp_size): + state_dict = ray_broadcast_tensor_dict( + None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}" + ) + if "consumer_global_step" in state_dict: + self.consumer_global_step = state_dict.pop("consumer_global_step").item() + self.load_state_dict(state_dict) + else: + state_dict = ray_broadcast_tensor_dict( + None, self.num_producers, device=self.device, group_name="sync_model" + ) + if "consumer_global_step" in state_dict: + self.consumer_global_step = state_dict.pop("consumer_global_step").item() + self.load_state_dict(state_dict) + self.profiler.exit("sync_model") + del state_dict + torch.cuda.empty_cache() + num_update_per_episode = len(self.train_dataloader) // self.num_microbatches num_valid_microbatches = num_update_per_episode * self.num_microbatches diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index f7a2fb89c..1b04c976b 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -25,7 +25,12 @@ from latex2sympy2_extended import NormalizationConfig from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify from .code_reward.utils import check_correctness_code_api as check_correctness_code -from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure +from .reward_utils import ( + extract_boxed_solution, + extract_solution, + find_infinite_loop_start, + validate_response_structure, +) CANNOT_PARSE_GT_ANSWER = -1 CANNOT_PARSE_PREDICTION = -2 @@ -122,6 +127,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + repetition_reward = 1.0 if detect_repetition(decoded_final_answer) == [] else 0.0 + final_answer, processed_str = extract_solution(decoded_final_answer) format_valid = validate_response_structure(processed_str, kwargs["tags"]) @@ -137,6 +144,10 @@ def math_reward_fn(input_ids, gt_answer, 
response_idx, **kwargs): if format_valid: format_acc += 1 + # Add repetition reward + if not eval_mode: + reward += repetition_reward + # Check if the sequence is over length if not eval_mode and res_length >= max_new_tokens: reward *= 0.0 @@ -182,6 +193,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): raise ValueError("no gt_answer is provided, please check your training dataset.") decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + print(f"Decoded final answer: {decoded_final_answer[-500:]}") + repetition_score = find_infinite_loop_start(input_ids[s : e + 1], min_repeats=2, distance=False) final_answer = extract_boxed_solution(decoded_final_answer) format_valid = final_answer is not None @@ -202,6 +215,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): if format_valid: format_acc += 1 + if repetition_score == 0 and not eval_mode: + # reward for non-repetition + reward += 2 + # Check if the sequence is over length if not eval_mode and res_length >= max_new_tokens: reward *= 0.0 diff --git a/applications/ColossalChat/coati/distributed/reward/reward_utils.py b/applications/ColossalChat/coati/distributed/reward/reward_utils.py index ffc220846..45a9e4828 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_utils.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_utils.py @@ -14,7 +14,9 @@ # limitations under the License. import re -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple + +import torch def validate_response_structure(processed_str: str, tags: Dict = None) -> bool: @@ -122,3 +124,51 @@ def extract_boxed_solution(text: str) -> Optional[str]: except Exception: # Any other unexpected error return None + + +import Levenshtein + + +def is_similar(seq1: List[int], seq2: List[int], threshold: float = 0.9) -> bool: + ratio = Levenshtein.ratio(seq1, seq2) + return ratio >= threshold + + +def find_infinite_loop_start(token_ids: torch.Tensor, min_repeats: int = 2, distance: bool = False) -> float: + n = len(token_ids) + + # Step 1: Detect the repeating segment at the end using two pointers + longest_valid_length = 0 + start_of_loop = n + + for length in range(1, n // min_repeats + 1): # Try different phrase lengths + count = 1 # Reset repetition counter + right = n - length # Start comparing from the second last occurrence + + while right - length >= 0: + # Check if the current phrase matches the previous phrase + if distance: + if is_similar(token_ids[right - length : right], token_ids[right : right + length]): + count += 1 + else: + break # Stop if repetition is broken + else: + # Use torch.equal() for tensor comparison + if torch.equal(token_ids[right - length : right], token_ids[right : right + length]): + count += 1 + else: + break # Stop if repetition is broken + + right -= length # Move left to check further + + if count >= min_repeats: # Found a valid repeating phrase + longest_valid_length = length + start_of_loop = right # This is where the first cycle of the repetition begins + + if longest_valid_length == 0: + return 0.0 # No infinite loop found, return repetition ratio as 0 + + # Step 2: Compute the repetition ratio + repetition_ratio = (n - start_of_loop) / n + + return repetition_ratio diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py index 2b4790884..7c23f59b6 100644 ---
a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py @@ -55,7 +55,7 @@ class BaseConsumer: self.num_microbatches = batch_size // minibatch_size self.data_uid = 0 self.sync_model_thread_started = False - + self.checkpoint_path = model_config.pop("checkpoint_path", None) self.model_config = model_config self.plugin_config = plugin_config @@ -79,6 +79,7 @@ class BaseConsumer: plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) + self.dp_rank = dist.get_rank(self.plugin.dp_group) self.tp_rank = dist.get_rank(self.plugin.tp_group) self.pp_rank = dist.get_rank(self.plugin.pp_group) @@ -165,8 +166,93 @@ class BaseConsumer: desc=f"Episode {episode} with rollout step(s)", disable=self.rank != 0, ) as pbar: + need_sync_model = True while self.received_prompts < self.train_dataset_size: torch.cuda.reset_peak_memory_stats() + if need_sync_model and ( + (self.global_step + 1) % self.save_interval == 0 + or self.received_prompts >= self.train_dataset_size + ): + if self.rank == 0: + print(f"Start saving policy model at step {self.global_step + 1}.") + save_path = os.path.join( + self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}" + ) + self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True) + if self.rank == 0: + print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}") + + if need_sync_model and ( + episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size + ): + + def sync_model_thread(): + # sync model weights to all producers, if no model update or it is the last training step, skip syncing + if self.pp_size > 1: + print( + f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}" + ) + else: + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}") + torch.cuda.empty_cache() + if self.pp_size > 1: + if self.tp_rank == 0 and self.dp_rank == 0: + self.profiler.enter("sync_model") + ray.get( + self.shared_signal_actor.set_signal.remote( + f"consumer_pp_{self.pp_rank}", "ready_sync_model" + ) + ) + print( + f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}" + ) + ray_broadcast_tensor_dict( + self.state_dict_cpu, + src=0, + device=torch.device("cpu"), + group_name=f"sync_model_consumer_pp_{self.pp_rank}", + backend="gloo", + ) + self.profiler.exit("sync_model") + else: + if self.rank == 0: + self.profiler.enter("sync_model") + ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model")) + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}") + ray_broadcast_tensor_dict( + self.state_dict_cpu, + src=0, + device=torch.device("cpu"), + group_name="sync_model_consumer", + backend="gloo", + ) + self.profiler.exit("sync_model") + + if not self.sync_model_thread_started: + # only sync model when the thread is not started and no other thread is broadcasting + self.sync_model_thread_started = True + state_dict_ = self.state_dict() + if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or ( + self.pp_size == 1 and self.rank == 0 + ): + if len(self.state_dict_cpu) == 0: + # use pinned memory to speed up the transfer + self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()} + torch.cuda.synchronize() + for k, v in 
state_dict_.items(): + self.state_dict_cpu[k].copy_(v, non_blocking=True) + torch.cuda.synchronize() + cc.barrier( + group_name="consumer_pg" + ) # to make sure all ranks have state dict offloaded to CPU before starting the thread + time_before_starting_thread = time.time() + threading.Thread(target=sync_model_thread).start() + # sync_model_thread() + self.profiler.log( + f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds" + ) + self.sync_model_thread_started = False + # ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock")) effective_group_to_raw_group_mapping = {} self.profiler.enter(f"recv_data") while len(effective_group_to_raw_group_mapping) < self.dp_size * self.minibatch_size: @@ -254,90 +340,6 @@ class BaseConsumer: pbar.set_postfix({"loss": loss}) need_sync_model = True ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1)) - if need_sync_model and ( - (self.global_step + 1) % self.save_interval == 0 - or self.received_prompts >= self.train_dataset_size - ): - if self.rank == 0: - print(f"Start saving policy model at step {self.global_step + 1}.") - save_path = os.path.join( - self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}" - ) - self.booster.save_model(self.policy_model, save_path, shard=True) - if self.rank == 0: - print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}") - - if need_sync_model and ( - episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size - ): - - def sync_model_thread(): - # sync model weights to all producers, if no model update or it is the last training step, skip syncing - if self.pp_size > 1: - print( - f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}" - ) - else: - print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}") - torch.cuda.empty_cache() - if self.pp_size > 1: - if self.tp_rank == 0 and self.dp_rank == 0: - self.profiler.enter("sync_model") - ray.get( - self.shared_signal_actor.set_signal.remote( - f"consumer_pp_{self.pp_rank}", "ready_sync_model" - ) - ) - print( - f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}" - ) - ray_broadcast_tensor_dict( - self.state_dict_cpu, - src=0, - device=torch.device("cpu"), - group_name=f"sync_model_consumer_pp_{self.pp_rank}", - backend="gloo", - ) - self.profiler.exit("sync_model") - else: - if self.rank == 0: - self.profiler.enter("sync_model") - ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model")) - print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}") - ray_broadcast_tensor_dict( - self.state_dict_cpu, - src=0, - device=torch.device("cpu"), - group_name="sync_model_consumer", - backend="gloo", - ) - self.profiler.exit("sync_model") - - if not self.sync_model_thread_started: - # only sync model when the thread is not started and no other thread is broadcasting - self.sync_model_thread_started = True - state_dict_ = self.state_dict() - if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or ( - self.pp_size == 1 and self.rank == 0 - ): - if len(self.state_dict_cpu) == 0: - # use pinned memory to speed up the transfer - self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()} - torch.cuda.synchronize() - for k, v in state_dict_.items(): - self.state_dict_cpu[k].copy_(v, non_blocking=True) - torch.cuda.synchronize() - 
cc.barrier( - group_name="consumer_pg" - ) # to make sure all ranks have state dict offloaded to CPU before starting the thread - time_before_starting_thread = time.time() - threading.Thread(target=sync_model_thread).start() - # sync_model_thread() - self.profiler.log( - f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds" - ) - self.sync_model_thread_started = False - # ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock")) self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") self.received_prompts = 0 ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate")) diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py index ea04ae13c..250b388d1 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py @@ -20,7 +20,7 @@ class Distributor: enable_profiling: bool = True, ): self.distributor_id = distributor_id - self.weight_version = [0] * consumer_pp_size + self.weight_version = [-1] * consumer_pp_size self.consumer_pp_size = consumer_pp_size self.state_dict_cpu = {} self.num_producers = num_producers diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py index f04785271..0993687ac 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py @@ -75,7 +75,11 @@ class GRPOConsumer(BaseConsumer): self.policy_model.train() self.policy_model.gradient_checkpointing_enable() self.vocab_size = self.policy_model.config.vocab_size - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) + self.optimizer = HybridAdam( + self.policy_model.parameters(), + lr=grpo_config.get("lr", 1e-6), + weight_decay=grpo_config.get("weight_decay", 0.01), + ) self.accum_loss = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) self.accum_entropy = torch.zeros(1, device=self.device) @@ -157,6 +161,9 @@ class GRPOConsumer(BaseConsumer): self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler ) + if self.checkpoint_path is not None: + print(f"Resume training from checkpoint: {self.checkpoint_path}") + self.booster.load_model(self.policy_model, self.checkpoint_path) if self.policy_loss_fn.beta > 0: self.reference_model, *_ = self.booster.boost(self.reference_model) self.plugin.logger.set_level("ERROR") diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py index 31c314dd5..c33471992 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py @@ -82,7 +82,7 @@ class BaseProducer: self.eval_interval = eval_interval self.eval_save_dir = eval_save_dir self.consumer_global_step = 0 - self.producer_weight_version = 0 + self.producer_weight_version = -1 self.eval_mode = False self.log_rollout_interval = log_rollout_interval self.latest_rollout_log_step = -1 @@ -148,19 +148,38 @@ class BaseProducer: self.evaluation_function = code_reward_fn else: raise 
ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}") - self.eval_dataset_config = eval_dataset_config if self.eval_dataset_config is not None: self.eval_dataloaders = {} + self.eval_sample_params = {} for eval_task_name in self.eval_dataset_config: + eval_sampling_params = copy.deepcopy(generate_config) + eval_sampling_params["n"] = eval_dataset_config[eval_task_name].pop( + "n", 1 + ) # default to 1 generation for evaluation + if "max_tokens" in eval_sampling_params: + eval_sampling_params["max_tokens"] = eval_dataset_config[eval_task_name].pop( + "max_tokens", eval_sampling_params["max_tokens"] + ) + elif "max_new_tokens" in eval_sampling_params: + eval_sampling_params["max_new_tokens"] = eval_dataset_config[eval_task_name].pop( + "max_new_tokens", eval_sampling_params["max_new_tokens"] + ) + eval_sampling_params["temperature"] = eval_dataset_config[eval_task_name].pop("temperature", 0.6) + eval_sampling_params["logprobs"] = 0 # force parameter + self.eval_sample_params[eval_task_name] = SamplingParams(**eval_sampling_params) + eval_dataset_path = eval_dataset_config[eval_task_name].pop("path") + eval_inference_batch_size = eval_dataset_config[eval_task_name].pop("batch_size", microbatch_size) eval_dataset = RawConversationDataset( self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name] ) - print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}") + print( + f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}, eval_batch_size: {eval_inference_batch_size}, num_producers: {num_producers}" + ) self.eval_dataloaders[eval_task_name] = DataLoader( eval_dataset, - batch_size=microbatch_size, + batch_size=eval_inference_batch_size, sampler=DistributedSampler( eval_dataset, num_replicas=num_producers, @@ -280,8 +299,18 @@ class BaseProducer: self.sync_model_thread_started = False distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get( - f"distributor_weight_version", 0 + f"distributor_weight_version", -1 ) + print(f"[P{self.producer_idx}] Distributor weight version: {distributor_weight_version}") + is_initial_sync = False + while distributor_weight_version < 0: + # wait for distributor to set weight version + is_initial_sync = True + print(f"[P{self.producer_idx}] Waiting for initial model sync before training") + time.sleep(1) + distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get( + f"distributor_weight_version", -1 + ) if ( not self.sync_model_thread_started and distributor_weight_version != self.producer_weight_version @@ -291,6 +320,16 @@ class BaseProducer: self.sync_model_thread = threading.Thread(target=sync_model_thread) self.producer_weight_version = distributor_weight_version self.sync_model_thread.start() + if is_initial_sync: + # wait until the initial model sync is done and load the model state dict + while self.sync_model_thread_started: + print(f"[P{self.producer_idx}] Waiting for model sync to finish before evaluation") + time.sleep(1) + for pp_idx in range(self.consumer_pp_size): + if self.state_dict_cpu[pp_idx] is not None and self.state_dict_cpu[pp_idx] != {}: + self.load_state_dict(self.state_dict_cpu[pp_idx]) + print(f"[P{self.producer_idx}] loaded initial model state dict") + torch.cuda.empty_cache() if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( "enable_sleep_mode", False @@ -314,7 +353,9 @@ class BaseProducer: for eval_batch in tqdm.tqdm( self.eval_dataloaders[eval_task_name],
disable=self.producer_idx != 0 ): - eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params) + eval_outputs = self.rollout( + **eval_batch, sample_params=copy.deepcopy(self.eval_sample_params[eval_task_name]) + ) eval_results = eval_results + [ self.evaluation_function( eval_outputs["input_ids"][m][n], @@ -343,6 +384,7 @@ class BaseProducer: print( f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}" ) + # save eval results safe_append_to_jsonl_file( os.path.join( @@ -351,7 +393,7 @@ class BaseProducer: ), eval_results, ) - + raise ValueError() if self.producer_idx == 0: self.wandb_run.log(to_log_msg, step=self.consumer_global_step) self.eval_mode = False @@ -427,15 +469,15 @@ class BaseProducer: self.load_state_dict(self.state_dict_cpu[pp_idx]) # linear annealing for 1 episode, temperature from initial to 0.9 - if episode <= 0: - ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) - self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ - "temperature" - ] + ratio * 0.9 - if isinstance(self.model, BACKEND_MAP["vllm"]): - self.model.sample_params.temperature = (1 - ratio) * self.generate_config[ - "temperature" - ] + ratio * 0.9 + # if episode <= 0: + # ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) + # self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ + # "temperature" + # ] + ratio * 0.9 + # if isinstance(self.model, BACKEND_MAP["vllm"]): + # self.model.sample_params.temperature = (1 - ratio) * self.generate_config[ + # "temperature" + # ] + ratio * 0.9 def __del__(self): self.profiler.close() @@ -464,7 +506,6 @@ class SimpleProducer(BaseProducer): eval_interval=-1, # disable evaluation grpo_config: Dict[str, Any] = None, eval_save_dir: str = "./eval", - eval_generation_config={}, project_name: str = None, run_name: str = None, wandb_group_name: str = None, @@ -500,10 +541,6 @@ class SimpleProducer(BaseProducer): ) print("tokenizer_config", tokenizer_config) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config) - self.eval_generation_config = copy.deepcopy(self.model.generate_config) - self.eval_generation_config["n"] = 1 # use 1 generation for evaluation - self.eval_generation_config.update(eval_generation_config) - self.eval_sample_params = SamplingParams(**self.eval_generation_config) @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): diff --git a/applications/ColossalChat/profiling.sh b/applications/ColossalChat/profiling.sh index d9f3d9a93..91f1594cd 100755 --- a/applications/ColossalChat/profiling.sh +++ b/applications/ColossalChat/profiling.sh @@ -11,3 +11,9 @@ rm -rf *.prof MAX_NEW_TOKENS=$((4096-512)) python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." 
-tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png + + +python rl_example_zero_bubble.py --dataset /home/litong/workspace/data/rl-math-train-final.jsonl --model /home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B -t 4 -i 4 -b vllm -a DAPO -imbs 4 -ibs 8 -tbs 32 -e 2 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 2 -tmbs 2 -p DAPO-Experiment-Baseline -ei 15 -ed '{"Math_500_level_1": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_1.jsonl", "Math_500_level_2": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_2.jsonl", "Math_500_level_3": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_3.jsonl", "Math_500_level_4": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_4.jsonl", "Math_500_level_5": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_5.jsonl"}' -zero 1 -tp 2 -pp 2 -mpt 512 -mnt 15872 --data_actor_buffer_size_limit 48 + + +python rl_example_zero_bubble.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B -t 4 -i 4 -b vllm -a DAPO -imbs 4 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 2 -tmbs 2 -p DAPO-Experiment-Baseline-V3 -ei 15 -ed '{"AIME2024": "/mnt/nfs/yeanbang/data/RLHF_data/AIME/2024/AIME2024.jsonl"}' -zero 1 -tp 2 -pp 2 -mpt 1024 -mnt 16384 --data_actor_buffer_size_limit 64 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_pp_2_16K_DAPO_profile_v3.txt diff --git a/applications/ColossalChat/reproduce.sh b/applications/ColossalChat/reproduce.sh new file mode 100644 index 000000000..de2cc5c2d --- /dev/null +++ b/applications/ColossalChat/reproduce.sh @@ -0,0 +1,9 @@ +export https_proxy=http://vpn.luchentech.com:32171 +export http_proxy=http://vpn.luchentech.com:32171 +export all_proxy=socks5://vpn.luchentech.com:32170 + +# output is normal +# python rl_example.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ --checkpoint-path /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ -t 4 -i 4 -b vllm -a DAPO -imbs 8 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 8 -p DAPO-Experiment-Baseline-V3_1 -zero 1 -pp 2 -tp 2 -mpt 1024 -mnt 4096 -nb 1 --enable_profiling + +# output is garbled +python rl_example.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ --checkpoint-path /mnt/nfs/yeanbang/experiments/rlhf/dapo/v3/model/DAPO-Experiment-Baseline-V3/modeling-episode-0-step-200 -t 4 -i 4 -b vllm -a DAPO -imbs 8 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}."
-tMbs 8 -tmbs 8 -p DAPO-Experiment-Baseline-V3_1 -zero 1 -pp 2 -tp 2 -mpt 1024 -mnt 4096 -nb 1 --enable_profiling diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 08814f9f1..69bbb64dd 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost" if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") + parser.add_argument( + "-cp", + "--checkpoint-path", + type=str, + default=None, + help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.", + ) parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument( "-ed", @@ -227,7 +234,9 @@ if __name__ == "__main__": os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) + train_model_config = dict( + path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path + ) generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) if args.backend == "transformers": diff --git a/applications/ColossalChat/rl_example_zero_bubble.py b/applications/ColossalChat/rl_example_zero_bubble.py index 89270d753..fb6e5b973 100644 --- a/applications/ColossalChat/rl_example_zero_bubble.py +++ b/applications/ColossalChat/rl_example_zero_bubble.py @@ -261,7 +261,6 @@ if __name__ == "__main__": stop=[""] if args.reward_type == "think_answer_tags" else None, ) ) - eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation else: raise ValueError(f"Unsupported backend: {args.backend}") @@ -290,7 +289,7 @@ if __name__ == "__main__": elif args.algo == "DAPO": # DAPO variant settings grpo_config = { - "filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch + "filter_range": [0.01, 0.8], # only filter out all zero batch and all one batch "lr": args.learning_rate, "train_microbatch_size": args.train_microbatch_size, "dynamic_batching": True, @@ -302,8 +301,8 @@ if __name__ == "__main__": "soft_over_length_punishment": True, "max_length": args.max_new_tokens + args.max_prompt_tokens, "max_new_tokens": args.max_new_tokens, - "cache_length": min(1024, int(args.max_new_tokens / 4)), - "filter_truncated_response": True, + "cache_length": max(1024, int(args.max_new_tokens / 4)), + "filter_truncated_response": False, "reward_fn_type": args.reward_type, "response_format_tags": ( { @@ -362,15 +361,22 @@ if __name__ == "__main__": save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")), eval_dataset_config=( { - k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt} + k: { + "path": v, + "max_length": args.max_prompt_tokens, + "system_prompt": args.system_prompt, + "batch_size": 1, + "n": 16, + "max_tokens": 32 * 1024 - 512, + "temperature": 0.6, + } for k, v in json.loads(args.eval_dataset).items() } if args.eval_dataset else None - ), + ), # support using different evaluation prompts/ parameters for each task eval_interval=args.eval_interval, eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), - eval_generation_config=eval_generation_config, log_rollout_interval=20, 
rollout_save_dir=args.rollout_save_dir, enable_profiling=args.enable_profiling, diff --git a/applications/ColossalChat/run_eval.sh b/applications/ColossalChat/run_eval.sh new file mode 100755 index 000000000..9a8f1900d --- /dev/null +++ b/applications/ColossalChat/run_eval.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +# run eval for every saved checkpoint + +# Set the base directory where models are saved +BASE_DIR="/mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline" +BASE_DIR="/mnt/nfs/yeanbang/experiments/rlhf/dapo/v3/model/DAPO-Experiment-Baseline-V3" +DATASET="/home/litong/workspace/data/rl-math-train-final.jsonl" +EVAL_SCRIPT="rl_example_zero_bubble.py" +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-105 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-120 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-15 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-30 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-45 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-60 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-75 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-90 +# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-1-step-150 +# Loop through each model directory matching the pattern +# 64K 65024 +# 32K 32256 +for model_path in "$BASE_DIR"/modeling-episode-0-step-150; do + if [ -d "$model_path" ]; then + echo "Evaluating model: $model_path" + + python "$EVAL_SCRIPT" \ + --dataset "$DATASET" \ + --tokenizer-path "/home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B" \ + --model "$model_path" \ + -t 1 \ + -i 6 \ + -b vllm \ + -a GRPO \ + -imbs 1 \ + -ibs 8 \ + -tbs 8 \ + -e 2 \ + -g 16 \ + -rt boxed \ + -si 25 \ + -s "Please reason step by step, and put your final answer within \\boxed{}." \ + -tMbs 2 \ + -tmbs 1 \ + -p DAPO-Experiment-Baseline-Eval \ + -ei 15 \ + -ed '{"AIME2024": "/mnt/nfs/yeanbang/data/RLHF_data/AIME/2024/AIME2024.jsonl"}' \ + -zero 2 \ + -mpt 512 \ + -mnt 65024 \ + --data_actor_buffer_size_limit 4 | grep "Accuracy on" + + echo "Finished evaluating: $model_path" + fi +done diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 27571309e..30b3ee032 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -824,7 +824,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - force_sp_output_gather=False, + # force_sp_output_gather=False, ) hidden_states = outputs[0]