From e774edeb8073bd6be0070ecfff0a317e572dea3e Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 21 Jul 2025 17:21:07 +0800
Subject: [PATCH] fix race condition

---
 .../ColossalChat/coati/distributed/comm.py         |  6 ++-
 .../coati/distributed/grpo_consumer.py             | 18 +++++--
 .../coati/distributed/inference_backend.py         |  6 ++-
 .../coati/distributed/launch_zero_bubble.py        |  5 +-
 .../ColossalChat/coati/distributed/loss.py         |  4 +-
 .../coati/distributed/zero_bubble/consumer.py      |  8 +--
 .../distributed/zero_bubble/distributor.py         | 21 ++++++--
 .../distributed/zero_bubble/grpo_consumer.py       | 49 ++++++++++++++++---
 .../coati/distributed/zero_bubble/producer.py      | 22 ++++++---
 .../ColossalChat/rl_example_zero_bubble.py         | 11 ++++-
 10 files changed, 113 insertions(+), 37 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/comm.py b/applications/ColossalChat/coati/distributed/comm.py
index 0a724d53b..21e6c7d90 100644
--- a/applications/ColossalChat/coati/distributed/comm.py
+++ b/applications/ColossalChat/coati/distributed/comm.py
@@ -1,5 +1,6 @@
-from typing import Any, Dict
 import copy
+from typing import Any, Dict
+
 import ray
 import ray.util.collective as cc
 import torch
@@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
         obj = c10d._tensor_to_object(obj, size_tensor.item())
     return obj
 
+
 def ray_broadcast_tensor_dict(
     tensor_dict: Dict[str, torch.Tensor],
     src: int = 0,
@@ -98,7 +100,7 @@ class SharedVariableActor:
         queue length as data may still be generating
         """
         ret = False
-        if self.queue_size < self.buffer_size_limit:
+        if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
             ret = True
         self.queue_size += num_tasks
         return ret
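
Note on the comm.py hunk above: the producer admission check now divides the queue limit by the consumer's most recent sample utilization (clamped to at least 0.1), so when GRPO filtering discards most rollouts the producers are allowed to keep proportionally more work in flight. A minimal standalone sketch of that rule (toy function; `utilization` stands in for the "sample_utilization" signal that the consumer publishes later in this patch):

    def may_enqueue(queue_size: int, buffer_size_limit: int, utilization: float) -> bool:
        # Clamping utilization at 0.1 bounds the effective limit at 10x the
        # configured buffer size, even if almost every sample gets filtered out.
        effective_limit = buffer_size_limit / max(0.1, utilization)
        return queue_size < effective_limit

    assert may_enqueue(100, 64, 0.5)      # half the samples survive: limit stretches to 128
    assert not may_enqueue(100, 64, 1.0)  # full utilization: the original limit of 64 holds
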
inputs["old_action_log_probs"], inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, inputs["action_mask"], @@ -415,7 +423,7 @@ class GRPOConsumer(BaseConsumer): loss, _ = self.policy_loss_fn( action_log_probs, - old_action_log_probs, + old_action_log_probs_micro_batch, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask_forward_micro_batch, @@ -455,7 +463,7 @@ class GRPOConsumer(BaseConsumer): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) - entropy = torch.cat(mini_batch_entropies, dim=0).mean() + entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 34827e4e2..331f8d7b6 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -59,6 +59,7 @@ class TransformersInferenceBackend(BaseInferenceBackend): generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + tokenizer_config: Dict[str, Any] = None, ): model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) model_config.update(self.FORCE_MODEL_CONFIG) @@ -132,6 +133,7 @@ class SGLangInferenceBackend(BaseInferenceBackend): generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + tokenizer_config: Dict[str, Any] = None, ): if sgl is None: raise ImportError("sglang is not installed") @@ -196,12 +198,14 @@ class VLLMInferenceBackend(BaseInferenceBackend): generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + tokenizer_config: Dict[str, Any] = None, ): if LLM is None: raise ImportError("vllm is not installed") model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) path = model_config.pop("path") - self.llm = LLM(model=path, **model_config) + tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None + self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config) generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) generate_config.update({"n": num_generations}) diff --git a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py index 635493b7d..de5b61353 100644 --- a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py +++ b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py @@ -130,7 +130,7 @@ def launch_distributed( train_dataset_config=train_dataset_config, model_config=inference_model_config, generate_config=generate_config, - tokenizer_config=tokenizer_config, + tokenizer_config=copy.deepcopy(tokenizer_config), microbatch_size=inference_microbatch_size, backend=inference_backend, num_generations=num_generations, @@ -158,8 +158,6 @@ def launch_distributed( consumer_master_ip_address = gpu_to_ip_address[0] print(f"Use {consumer_master_ip_address} as master address for torch DDP.") consumer_procs = [] - if num_consumer_procs <= 1: - raise ValueError("Number of consumer processes should be 
diff --git a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py
index 635493b7d..de5b61353 100644
--- a/applications/ColossalChat/coati/distributed/launch_zero_bubble.py
+++ b/applications/ColossalChat/coati/distributed/launch_zero_bubble.py
@@ -130,7 +130,7 @@ def launch_distributed(
         train_dataset_config=train_dataset_config,
         model_config=inference_model_config,
         generate_config=generate_config,
-        tokenizer_config=tokenizer_config,
+        tokenizer_config=copy.deepcopy(tokenizer_config),
         microbatch_size=inference_microbatch_size,
         backend=inference_backend,
         num_generations=num_generations,
@@ -158,8 +158,6 @@ def launch_distributed(
     consumer_master_ip_address = gpu_to_ip_address[0]
     print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
     consumer_procs = []
-    if num_consumer_procs <= 1:
-        raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
     for i in range(num_consumer_procs):
         node_id = gpu_to_node_id[0]
         consumer_ip_address = gpu_to_ip_address[0]
@@ -180,6 +178,7 @@ def launch_distributed(
             model_config=train_model_config,
             plugin_config=plugin_config,
             minibatch_size=train_minibatch_size,
+            tokenizer_config=copy.deepcopy(tokenizer_config),
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index 36057b24f..ea4d0dd11 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -35,9 +35,9 @@ class PolicyLoss(nn.Module):
         total_effective_tokens_in_batch: torch.Tensor = None,
     ) -> torch.Tensor:
         if action_mask is None:
-            ratio = (log_probs - log_probs.detach()).exp()
+            ratio = (log_probs - old_log_probs.detach()).exp()
         else:
-            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+            ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
 
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
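
Two separate bugs are repaired above. First, launch_zero_bubble.py now hands each actor its own deepcopy of tokenizer_config, because GRPOConsumer (later in this patch) pops "path" out of the dict it receives; sharing one dict object across constructors would let the first consumer silently corrupt the config seen by the next. Second, and more consequential for training, loss.py previously computed the importance ratio as (log_probs - log_probs.detach()).exp(), which is identically 1, so PPO clipping never fired; with asynchronously generated rollouts the ratio must be taken against the behavior policy's log-probs. A minimal sketch of the corrected clipped surrogate, assuming the same clamp bounds as PolicyLoss (eps values illustrative):

    import torch

    def clipped_surrogate(log_probs, old_log_probs, advantages, eps_low=0.2, eps_high=0.2):
        # exp(new - old) is the token-level importance ratio; when old == new it
        # is exactly 1 and the clamp below can never activate -- the fixed bug.
        ratio = (log_probs - old_log_probs.detach()).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - eps_low, 1 + eps_high) * advantages
        return torch.min(surr1, surr2)
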
diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py
index 82242e874..2b4790884 100644
--- a/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py
+++ b/applications/ColossalChat/coati/distributed/zero_bubble/consumer.py
@@ -7,7 +7,9 @@ import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.distributed.utils import bind_batch, post_recv, unbind_batch
 from tqdm import tqdm
 
 from colossalai.booster import Booster
@@ -15,9 +17,6 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.utils import get_current_device
 
-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-from coati.distributed.utils import bind_batch, post_recv, unbind_batch
-
 
 class BaseConsumer:
     def __init__(
@@ -175,14 +174,15 @@ class BaseConsumer:
                 raw_batch = ray.get(
                     self.shared_sync_data_actor.get_data.remote(self.data_uid)
                 )  # get the first queued data
+                self.profiler.log(f"enter sleep")
                 while raw_batch is None:
-                    self.profiler.log(f"No data received by consumer {self.rank}, skipping")
                     print(
                         f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
                     )
                     time.sleep(1)
                     raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
                     continue
+                self.profiler.log(f"exit sleep")
                 self.data_uid += 1
                 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
                 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
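
The consumer-side change above is mostly log hygiene: the per-second message inside the wait loop moved out of the profiler log, which now records only the enter/exit of the wait. The wait itself is a plain poll of the shared data actor; reduced to a sketch (fetch is a hypothetical stand-in for shared_sync_data_actor.get_data.remote):

    import time

    def wait_for_batch(fetch, poll_interval: float = 1.0):
        # Poll until the data actor has a queued batch; None means "not ready yet".
        batch = fetch()
        while batch is None:
            time.sleep(poll_interval)
            batch = fetch()
        return batch
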
diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py
index b16f4b67e..ea04ae13c 100644
--- a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py
+++ b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py
@@ -3,12 +3,11 @@ import time
 import ray
 import ray.util.collective as cc
 import torch
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler
 
 from colossalai.utils import get_current_device
 
-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-
 
 @ray.remote
 class Distributor:
@@ -21,6 +20,7 @@ class Distributor:
         enable_profiling: bool = True,
     ):
         self.distributor_id = distributor_id
+        self.weight_version = [0] * consumer_pp_size
         self.consumer_pp_size = consumer_pp_size
         self.state_dict_cpu = {}
         self.num_producers = num_producers
@@ -42,14 +42,17 @@ class Distributor:
             print(f"[D] Initialized {group_name} collective group", flush=True)
 
     def loop(self):
+        last_weight_version = self.get_weight_version()
         while True:
             time.sleep(1)
             signal = ray.get(self.shared_signal_actor.get_signal.remote())
             if self.consumer_pp_size > 1:
-                for i in range(self.consumer_pp_size):
-                    if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
+                if all(
+                    [signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
+                ):
+                    cc.barrier(group_name="distributor_pg")
+                    for i in range(self.consumer_pp_size):
                         self.profiler.enter(f"sync_model_consumer_pp_{i}")
-                        cc.barrier(group_name="distributor_pg")
                         ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
                         # Broadcast the model state dict from consumer to shared variable actor
                         self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
@@ -60,6 +63,7 @@ class Distributor:
                             backend="gloo",
                         )
                         self.profiler.exit(f"sync_model_consumer_pp_{i}")
+                        self.weight_version[i] += 1
                 for i in range(self.consumer_pp_size):
                     if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
                         self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
@@ -87,6 +91,7 @@ class Distributor:
                         None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
                     )
                     self.profiler.exit("sync_model_consumer")
+                    self.weight_version[0] += 1
                 if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
                     self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
                     # Broadcast the model state dict to all producers
@@ -106,3 +111,9 @@ class Distributor:
             if signal.get("consumer", None) == "terminate":
                 self.profiler.log("terminate sync model worker")
                 break
+            if last_weight_version != self.get_weight_version():
+                last_weight_version = self.get_weight_version()
+                ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
+
+    def get_weight_version(self):
+        return min(self.weight_version)
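
This distributor rewrite is the heart of the race fix. Previously each pipeline stage's weights were pulled as soon as that single stage signalled readiness, and producers re-synced whenever the consumer's global step moved, so a producer could observe a checkpoint stitched together from different steps. Now all pp stages must be ready before any broadcast starts, a per-stage counter records completed broadcasts, and only min(weight_version) is published. A compact sketch of the invariant (standalone toy; the real signal travels through the shared signal actor):

    class VersionTracker:
        def __init__(self, num_pp_stages: int):
            self.weight_version = [0] * num_pp_stages

        def on_stage_synced(self, stage: int) -> None:
            self.weight_version[stage] += 1

        def published_version(self) -> int:
            # min() advances only once every stage has finished its broadcast,
            # so readers can never act on a torn, partially updated checkpoint.
            return min(self.weight_version)

    tracker = VersionTracker(2)
    tracker.on_stage_synced(0)
    assert tracker.published_version() == 0  # stage 1 still pending: hold the old version
    tracker.on_stage_synced(1)
    assert tracker.published_version() == 1  # all stages done: producers may re-sync
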
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, inputs["action_mask"], @@ -396,7 +413,7 @@ class GRPOConsumer(BaseConsumer): loss, _ = self.policy_loss_fn( action_log_probs, - old_action_log_probs, + old_action_log_probs_micro_batch, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask_forward_micro_batch, @@ -411,6 +428,20 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(kl.mean(), self.plugin) mean_kl.append(kl.data) mean_loss.append(loss.data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() @@ -422,7 +453,9 @@ class GRPOConsumer(BaseConsumer): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) + entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_entropy.add_(entropy.data) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_advantages.add_(advantages.data) @@ -465,6 +498,7 @@ class GRPOConsumer(BaseConsumer): f"Response Length: {raw_batch_response_len_mean:.4f}", f"Sample_utilization: {sample_utilization:.4f}", f"Overlength samples ratio: {overlength_samples_ratio:.4f}", + f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { @@ -476,6 +510,7 @@ class GRPOConsumer(BaseConsumer): "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], "train/sample_utilization": sample_utilization, + "train/entropy": self.accum_entropy.item() / self.accum_count, "train/overlength_samples_ratio": overlength_samples_ratio, "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } @@ -483,8 +518,10 @@ class GRPOConsumer(BaseConsumer): metrics["train/kl"] = self.accum_kl.item() / self.accum_count if self.wandb_run is not None: self.wandb_run.log(metrics) + ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization)) self.accum_loss.zero_() self.accum_kl.zero_() + self.accum_entropy.zero_() self.accum_advantages.zero_() self.accum_count = 0 return loss_scalar diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py index 9e57914c4..31c314dd5 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py @@ -11,9 +11,12 @@ import torch import tqdm import wandb from coati.dataset.loader import RawConversationDataset, collate_fn_grpo +from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict +from coati.distributed.inference_backend import BACKEND_MAP from coati.distributed.profiling_utils import CustomProfiler from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn from 
diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py
index 9e57914c4..31c314dd5 100644
--- a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py
+++ b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py
@@ -11,9 +11,12 @@ import torch
 import tqdm
 import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
+from coati.distributed.inference_backend import BACKEND_MAP
 from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
 from ray.util.collective import allreduce
 from ray.util.collective.types import ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
@@ -21,10 +24,6 @@ from transformers import AutoTokenizer
 
 from colossalai.utils import get_current_device
 
-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-from coati.distributed.inference_backend import BACKEND_MAP
-from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
-
 try:
     from vllm import SamplingParams
 except ImportError:
@@ -280,11 +279,17 @@ class BaseProducer:
                 self.profiler.exit("sync_model")
                 self.sync_model_thread_started = False
 
-            if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version:
+            distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
+                f"distributor_weight_version", 0
+            )
+            if (
+                not self.sync_model_thread_started
+                and distributor_weight_version != self.producer_weight_version
+            ):
                 # only sync model when the thread is not started and global step is changed
                 self.sync_model_thread_started = True
                 self.sync_model_thread = threading.Thread(target=sync_model_thread)
-                self.producer_weight_version = self.consumer_global_step
+                self.producer_weight_version = distributor_weight_version
                 self.sync_model_thread.start()
                 torch.cuda.empty_cache()
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
@@ -478,7 +483,7 @@ class SimpleProducer(BaseProducer):
             train_dataset_config,
             model_config,
             generate_config,
-            tokenizer_config,
+            copy.deepcopy(tokenizer_config),
             microbatch_size,
             backend,
             consumer_plugin_config,
@@ -493,7 +498,8 @@ class SimpleProducer(BaseProducer):
             rollout_log_file=rollout_log_file,
             enable_profiling=enable_profiling,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+        print("tokenizer_config", tokenizer_config)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
         self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
         self.eval_generation_config.update(eval_generation_config)
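
In producer.py, the second half of the race fix: a producer no longer treats a change of consumer_global_step as permission to pull weights (the step can advance while the distributor is still mid-broadcast). It instead polls distributor_weight_version, which the distributor bumps only after a complete copy is assembled. The gating condition, reduced to a sketch (hypothetical free function; the real check reads Ray signals):

    def should_start_sync(thread_running: bool, published_version: int, local_version: int) -> bool:
        # Kick off a background weight sync only when no sync thread is already
        # in flight and the distributor has published a fully assembled
        # checkpoint version different from the one this producer last loaded.
        return (not thread_running) and (published_version != local_version)
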
diff --git a/applications/ColossalChat/rl_example_zero_bubble.py b/applications/ColossalChat/rl_example_zero_bubble.py
index e4f149653..89270d753 100644
--- a/applications/ColossalChat/rl_example_zero_bubble.py
+++ b/applications/ColossalChat/rl_example_zero_bubble.py
@@ -15,6 +15,12 @@ DEFAUT_SYSTEM_PROMPT = {
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument(
+        "--tokenizer-path",
+        type=str,
+        default=None,
+        help="Path to the tokenizer. If not provided, will use the model path.",
+    )
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument(
         "-ed",
@@ -166,6 +172,7 @@ if __name__ == "__main__":
         "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
     args = parser.parse_args()
+    print(args)
 
     if args.train_minibatch_size is None:
         # Default settings: Using train batch size as mini batch size
@@ -283,7 +290,7 @@ if __name__ == "__main__":
     elif args.algo == "DAPO":
         # DAPO variant settings
         grpo_config = {
-            "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
+            "filter_range": [0.01, 0.7],  # filter out groups with mean accuracy below 0.01 or above 0.7
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
             "dynamic_batching": True,
@@ -343,7 +350,9 @@ if __name__ == "__main__":
             ),  # microbatch size should be set to train_microbatch_size // pp_size
             "zero_stage": args.zero_stage,
             "max_norm": 1.0,
+            "num_layers_per_stage": [18, 10],
         },  # for pp, tp
+        tokenizer_config={"path": args.tokenizer_path} if args.tokenizer_path else {"path": args.model},
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
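
Beyond the new flag, the example script tightens DAPO's dynamic-sampling range from [0.01, 0.99] to [0.01, 0.7] and pins num_layers_per_stage for the two pipeline stages. If the filter keeps a prompt group only when its mean accuracy lies inside the range -- the reading suggested by the surrounding config -- the behavior is (a hedged sketch; keep_group is hypothetical):

    def keep_group(mean_acc: float, low: float = 0.01, high: float = 0.7) -> bool:
        # Drop groups that are almost all wrong (no learning signal) or solved
        # more than 70% of the time (too easy); 0.7 is stricter than the old 0.99.
        return low <= mean_acc <= high

A typical invocation with the new tokenizer flag, using only arguments visible in this diff, would be: python rl_example_zero_bubble.py --model Qwen/Qwen2.5-7B --tokenizer-path /path/to/tokenizer --dataset data.jsonl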