fix racing condition

This commit is contained in:
YeAnbang 2025-07-21 17:21:07 +08:00
parent f54ae56f12
commit e774edeb80
10 changed files with 113 additions and 37 deletions

View File

@ -1,5 +1,6 @@
from typing import Any, Dict
import copy
from typing import Any, Dict
import ray
import ray.util.collective as cc
import torch
@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
obj = c10d._tensor_to_object(obj, size_tensor.item())
return obj
def ray_broadcast_tensor_dict(
tensor_dict: Dict[str, torch.Tensor],
src: int = 0,
@ -98,7 +100,7 @@ class SharedVariableActor:
queue length as data may still be generating
"""
ret = False
if self.queue_size < self.buffer_size_limit:
if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
ret = True
self.queue_size += num_tasks
return ret

View File

@ -250,6 +250,9 @@ class GRPOConsumer(BaseConsumer):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
old_action_log_probs_micro_batch = old_action_log_probs[
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
@ -306,17 +309,22 @@ class GRPOConsumer(BaseConsumer):
"action_mask": action_mask_forward_micro_batch,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"old_action_log_probs": old_action_log_probs_micro_batch,
"source": self.rank,
}
if reference_action_log_probs is not None:
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
kl = []
policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
def _criterion(outputs, inputs):
action_logits = outputs.logits
policy_model_logits.copy_(action_logits)
mini_batch_entropies.append(
(
((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
/ inputs["action_mask"].sum(-1)
).detach()
)
action_log_probs = memory_efficient_logprob(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
@ -339,7 +347,7 @@ class GRPOConsumer(BaseConsumer):
loss, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["old_action_log_probs"],
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
@ -415,7 +423,7 @@ class GRPOConsumer(BaseConsumer):
loss, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
old_action_log_probs_micro_batch,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
@ -455,7 +463,7 @@ class GRPOConsumer(BaseConsumer):
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
entropy = torch.cat(mini_batch_entropies, dim=0).mean()
entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_entropy.add_(entropy.data)
if self.policy_loss_fn.beta > 0:

View File

@ -59,6 +59,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
tokenizer_config: Dict[str, Any] = None,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
@ -132,6 +133,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
tokenizer_config: Dict[str, Any] = None,
):
if sgl is None:
raise ImportError("sglang is not installed")
@ -196,12 +198,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
tokenizer_config: Dict[str, Any] = None,
):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
path = model_config.pop("path")
self.llm = LLM(model=path, **model_config)
tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})

View File

@ -130,7 +130,7 @@ def launch_distributed(
train_dataset_config=train_dataset_config,
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
tokenizer_config=copy.deepcopy(tokenizer_config),
microbatch_size=inference_microbatch_size,
backend=inference_backend,
num_generations=num_generations,
@ -158,8 +158,6 @@ def launch_distributed(
consumer_master_ip_address = gpu_to_ip_address[0]
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
consumer_procs = []
if num_consumer_procs <= 1:
raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
for i in range(num_consumer_procs):
node_id = gpu_to_node_id[0]
consumer_ip_address = gpu_to_ip_address[0]
@ -180,6 +178,7 @@ def launch_distributed(
model_config=train_model_config,
plugin_config=plugin_config,
minibatch_size=train_minibatch_size,
tokenizer_config=copy.deepcopy(tokenizer_config),
generate_config=generate_config_consumer,
grpo_config=grpo_config,
num_generations=num_generations,

View File

@ -35,9 +35,9 @@ class PolicyLoss(nn.Module):
total_effective_tokens_in_batch: torch.Tensor = None,
) -> torch.Tensor:
if action_mask is None:
ratio = (log_probs - log_probs.detach()).exp()
ratio = (log_probs - old_log_probs.detach()).exp()
else:
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages

View File

@ -7,7 +7,9 @@ import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.utils import bind_batch, post_recv, unbind_batch
from tqdm import tqdm
from colossalai.booster import Booster
@ -15,9 +17,6 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.utils import get_current_device
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.utils import bind_batch, post_recv, unbind_batch
class BaseConsumer:
def __init__(
@ -175,14 +174,15 @@ class BaseConsumer:
raw_batch = ray.get(
self.shared_sync_data_actor.get_data.remote(self.data_uid)
) # get the first queued data
self.profiler.log(f"enter sleep")
while raw_batch is None:
self.profiler.log(f"No data received by consumer {self.rank}, skipping")
print(
f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
)
time.sleep(1)
raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
continue
self.profiler.log(f"exit sleep")
self.data_uid += 1
raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),

View File

@ -3,12 +3,11 @@ import time
import ray
import ray.util.collective as cc
import torch
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.profiling_utils import CustomProfiler
from colossalai.utils import get_current_device
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
@ray.remote
class Distributor:
@ -21,6 +20,7 @@ class Distributor:
enable_profiling: bool = True,
):
self.distributor_id = distributor_id
self.weight_version = [0] * consumer_pp_size
self.consumer_pp_size = consumer_pp_size
self.state_dict_cpu = {}
self.num_producers = num_producers
@ -42,14 +42,17 @@ class Distributor:
print(f"[D] Initialized {group_name} collective group", flush=True)
def loop(self):
last_weight_version = self.get_weight_version()
while True:
time.sleep(1)
signal = ray.get(self.shared_signal_actor.get_signal.remote())
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_consumer_pp_{i}")
if all(
[signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
):
cc.barrier(group_name="distributor_pg")
for i in range(self.consumer_pp_size):
self.profiler.enter(f"sync_model_consumer_pp_{i}")
ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
# Broadcast the model state dict from consumer to shared variable actor
self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
@ -60,6 +63,7 @@ class Distributor:
backend="gloo",
)
self.profiler.exit(f"sync_model_consumer_pp_{i}")
self.weight_version[i] += 1
for i in range(self.consumer_pp_size):
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
@ -87,6 +91,7 @@ class Distributor:
None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
)
self.profiler.exit("sync_model_consumer")
self.weight_version[0] += 1
if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
# Broadcast the model state dict to all producers
@ -106,3 +111,9 @@ class Distributor:
if signal.get("consumer", None) == "terminate":
self.profiler.log("terminate sync model worker")
break
if last_weight_version != self.get_weight_version():
last_weight_version = self.get_weight_version()
ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
def get_weight_version(self):
return min(self.weight_version)

View File

@ -5,9 +5,9 @@ import ray
import torch
import wandb
from coati.distributed.comm import SharedVariableActor
from coati.distributed.zero_bubble.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import memory_efficient_logprob
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
from coati.distributed.zero_bubble.consumer import BaseConsumer
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -33,6 +33,7 @@ class GRPOConsumer(BaseConsumer):
plugin_config,
minibatch_size=1,
num_generations=8,
tokenizer_config=None,
generate_config=None,
grpo_config={},
save_interval: int = 100,
@ -73,9 +74,11 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.vocab_size = self.policy_model.config.vocab_size
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_entropy = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
@ -102,7 +105,10 @@ class GRPOConsumer(BaseConsumer):
if self.policy_loss_fn.beta > 0:
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
if tokenizer_config is not None:
path = tokenizer_config.pop("path", None)
self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_config)
else:
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
@ -243,10 +249,14 @@ class GRPOConsumer(BaseConsumer):
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
mini_batch_entropies = []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
old_action_log_probs_micro_batch = old_action_log_probs[
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
@ -303,6 +313,7 @@ class GRPOConsumer(BaseConsumer):
"action_mask": action_mask_forward_micro_batch,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"old_action_log_probs": old_action_log_probs_micro_batch,
"source": self.rank,
}
if reference_action_log_probs is not None:
@ -312,6 +323,12 @@ class GRPOConsumer(BaseConsumer):
def _criterion(outputs, inputs):
action_logits = outputs.logits
mini_batch_entropies.append(
(
((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
/ inputs["action_mask"].sum(-1)
).detach()
)
action_log_probs = memory_efficient_logprob(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
@ -334,7 +351,7 @@ class GRPOConsumer(BaseConsumer):
loss, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["old_action_log_probs"],
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
@ -396,7 +413,7 @@ class GRPOConsumer(BaseConsumer):
loss, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
old_action_log_probs_micro_batch,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
@ -411,6 +428,20 @@ class GRPOConsumer(BaseConsumer):
kl = all_reduce_mean(kl.mean(), self.plugin)
mean_kl.append(kl.data)
mean_loss.append(loss.data)
mini_batch_entropies.append(
all_reduce_mean(
(
(
(
entropy_from_logits(policy_model_logits[:, -num_action:])
* action_mask_forward_micro_batch
).sum(-1)
)
/ action_mask_forward_micro_batch.sum(-1)
).detach(),
self.plugin,
)
)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
@ -422,7 +453,9 @@ class GRPOConsumer(BaseConsumer):
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_entropy.add_(entropy.data)
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_advantages.add_(advantages.data)
@ -465,6 +498,7 @@ class GRPOConsumer(BaseConsumer):
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
@ -476,6 +510,7 @@ class GRPOConsumer(BaseConsumer):
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"train/entropy": self.accum_entropy.item() / self.accum_count,
"train/overlength_samples_ratio": overlength_samples_ratio,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
@ -483,8 +518,10 @@ class GRPOConsumer(BaseConsumer):
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
if self.wandb_run is not None:
self.wandb_run.log(metrics)
ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization))
self.accum_loss.zero_()
self.accum_kl.zero_()
self.accum_entropy.zero_()
self.accum_advantages.zero_()
self.accum_count = 0
return loss_scalar

View File

@ -11,9 +11,12 @@ import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.inference_backend import BACKEND_MAP
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
from ray.util.collective import allreduce
from ray.util.collective.types import ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
@ -21,10 +24,6 @@ from transformers import AutoTokenizer
from colossalai.utils import get_current_device
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
from coati.distributed.inference_backend import BACKEND_MAP
from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
try:
from vllm import SamplingParams
except ImportError:
@ -280,11 +279,17 @@ class BaseProducer:
self.profiler.exit("sync_model")
self.sync_model_thread_started = False
if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version:
distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
f"distributor_weight_version", 0
)
if (
not self.sync_model_thread_started
and distributor_weight_version != self.producer_weight_version
):
# only sync model when the thread is not started and global step is changed
self.sync_model_thread_started = True
self.sync_model_thread = threading.Thread(target=sync_model_thread)
self.producer_weight_version = self.consumer_global_step
self.producer_weight_version = distributor_weight_version
self.sync_model_thread.start()
torch.cuda.empty_cache()
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
@ -478,7 +483,7 @@ class SimpleProducer(BaseProducer):
train_dataset_config,
model_config,
generate_config,
tokenizer_config,
copy.deepcopy(tokenizer_config),
microbatch_size,
backend,
consumer_plugin_config,
@ -493,7 +498,8 @@ class SimpleProducer(BaseProducer):
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
print("tokenizer_config", tokenizer_config)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)

View File

@ -15,6 +15,12 @@ DEFAUT_SYSTEM_PROMPT = {
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
"--tokenizer-path",
type=str,
default=None,
help="Path to the tokenizer. If not provided, will use the model path.",
)
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument(
"-ed",
@ -166,6 +172,7 @@ if __name__ == "__main__":
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
)
args = parser.parse_args()
print(args)
if args.train_minibatch_size is None:
# Default settings: Using train batch size as mini batch size
@ -283,7 +290,7 @@ if __name__ == "__main__":
elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"dynamic_batching": True,
@ -343,7 +350,9 @@ if __name__ == "__main__":
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": args.zero_stage,
"max_norm": 1.0,
"num_layers_per_stage": [18, 10],
}, # for pp, tp
tokenizer_config={"path": args.tokenizer_path} if args.tokenizer_path else {"path": args.model},
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,