Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-24 03:59:33 +00:00

fix racing condition

This commit is contained in:
parent f54ae56f12
commit e774edeb80
@@ -1,5 +1,6 @@
-from typing import Any, Dict
 import copy
+from typing import Any, Dict

 import ray
 import ray.util.collective as cc
 import torch
@@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
 obj = c10d._tensor_to_object(obj, size_tensor.item())
 return obj

+
 def ray_broadcast_tensor_dict(
 tensor_dict: Dict[str, torch.Tensor],
 src: int = 0,
@@ -98,7 +100,7 @@ class SharedVariableActor:
 queue length as data may still be generating
 """
 ret = False
-if self.queue_size < self.buffer_size_limit:
+if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
 ret = True
 self.queue_size += num_tasks
 return ret
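Note: the hunk above scales the data actor's effective queue capacity by the consumer's reported sample utilization, so that when only a fraction of generated samples survives filtering, producers are allowed to queue proportionally more. The following is an illustrative Python sketch, not the actual SharedVariableActor; the class, method names, and example values are assumptions.

    class QueueAdmission:
        """Minimal sketch of the utilization-scaled admission check."""

        def __init__(self, buffer_size_limit: int):
            self.buffer_size_limit = buffer_size_limit
            self.queue_size = 0
            self.signals = {}  # e.g. {"sample_utilization": 0.5}, published by the consumer

        def acquire(self, num_tasks: int) -> bool:
            # Effective cap = base limit / utilization, with utilization floored at 0.1
            # so a near-zero utilization cannot inflate the cap by more than 10x.
            utilization = max(0.1, self.signals.get("sample_utilization", 1.0))
            if self.queue_size < self.buffer_size_limit / utilization:
                self.queue_size += num_tasks
                return True
            return False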
@@ -250,6 +250,9 @@ class GRPOConsumer(BaseConsumer):
 input_ids_forward_micro_batch = data["input_ids"][
 forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
 ]
+old_action_log_probs_micro_batch = old_action_log_probs[
+forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+]
 attention_mask_forward_micro_batch = data["attention_mask"][
 forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
 ]
@@ -306,17 +309,22 @@ class GRPOConsumer(BaseConsumer):
 "action_mask": action_mask_forward_micro_batch,
 "advantages": advantages_forward_micro_batch,
 "loss_mask": loss_mask_forward_micro_batch,
+"old_action_log_probs": old_action_log_probs_micro_batch,
 "source": self.rank,
 }
 if reference_action_log_probs is not None:
 data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

 kl = []
-policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)

 def _criterion(outputs, inputs):
 action_logits = outputs.logits
-policy_model_logits.copy_(action_logits)
+mini_batch_entropies.append(
+(
+((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+/ inputs["action_mask"].sum(-1)
+).detach()
+)
 action_log_probs = memory_efficient_logprob(
 action_logits / self.generate_config["temperature"],
 inputs["input_ids"],
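Note: the hunk above drops the preallocated logits buffer and instead computes a masked, per-sequence token entropy inside the criterion. The sketch below illustrates what an entropy_from_logits-style helper computes and how the masked mean in the diff is formed; it is an assumption for illustration, not the exact coati.distributed.utils implementation.

    import torch

    def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
        # Per-token entropy of the categorical distribution defined by each logit vector.
        log_probs = torch.log_softmax(logits, dim=-1)   # (..., vocab)
        probs = log_probs.exp()
        return -(probs * log_probs).sum(dim=-1)         # (...,) per-token entropy

    def masked_response_entropy(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
        # Mirrors the pattern added in the diff:
        # (entropy * action_mask).sum(-1) / action_mask.sum(-1), detached so it is
        # logged without contributing gradients.
        ent = entropy_from_logits_sketch(logits)
        return ((ent * action_mask).sum(-1) / action_mask.sum(-1)).detach()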
@@ -339,7 +347,7 @@ class GRPOConsumer(BaseConsumer):

 loss, _ = self.policy_loss_fn(
 action_log_probs,
-action_log_probs,
+inputs["old_action_log_probs"],
 inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
 per_token_kl,
 inputs["action_mask"],
@@ -415,7 +423,7 @@ class GRPOConsumer(BaseConsumer):

 loss, _ = self.policy_loss_fn(
 action_log_probs,
-old_action_log_probs,
+old_action_log_probs_micro_batch,
 advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
 per_token_kl,
 action_mask_forward_micro_batch,
@@ -455,7 +463,7 @@ class GRPOConsumer(BaseConsumer):
 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
 advantages = all_reduce_mean(advantages.mean(), self.plugin)
 response_length = all_reduce_mean(response_length.mean(), self.plugin)
-entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
 self.accum_entropy.add_(entropy.data)
 if self.policy_loss_fn.beta > 0:
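Note: the logged entropy is now averaged across data-parallel ranks with the same all_reduce_mean helper used for the other metrics. The snippet below is a rough sketch of what such a helper does for a scalar metric, written against plain torch.distributed as an assumption; the real code uses coati's all_reduce_mean together with the booster plugin.

    import torch
    import torch.distributed as dist

    def all_reduce_mean_sketch(value: torch.Tensor) -> torch.Tensor:
        # Assumes the default process group has already been initialized.
        value = value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        return value / dist.get_world_size()

    # Usage mirroring the diff:
    # entropy = all_reduce_mean_sketch(torch.cat(mini_batch_entropies, dim=0).mean())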
@@ -59,6 +59,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
 generate_config: Dict[str, Any],
 tokenizer: PreTrainedTokenizer,
 num_generations: int = 8,
+tokenizer_config: Dict[str, Any] = None,
 ):
 model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
 model_config.update(self.FORCE_MODEL_CONFIG)
@@ -132,6 +133,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
 generate_config: Dict[str, Any],
 tokenizer: PreTrainedTokenizer,
 num_generations: int = 8,
+tokenizer_config: Dict[str, Any] = None,
 ):
 if sgl is None:
 raise ImportError("sglang is not installed")
@@ -196,12 +198,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
 generate_config: Dict[str, Any],
 tokenizer: PreTrainedTokenizer,
 num_generations: int = 8,
+tokenizer_config: Dict[str, Any] = None,
 ):
 if LLM is None:
 raise ImportError("vllm is not installed")
 model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
 path = model_config.pop("path")
-self.llm = LLM(model=path, **model_config)
+tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
+self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
 generate_config = generate_config.copy()
 generate_config.update(self.FORCE_GENERATE_CONFIG)
 generate_config.update({"n": num_generations})
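Note: with this change the vLLM backend can load a tokenizer from a path separate from the model checkpoint; passing tokenizer=None keeps vLLM's default behavior of using the model's own tokenizer. The sketch below shows the wiring in isolation; the paths are placeholders.

    from vllm import LLM

    model_config = {"path": "Qwen/Qwen2.5-7B"}        # placeholder
    tokenizer_config = {"path": "Qwen/Qwen2.5-7B"}    # placeholder, may be None

    path = model_config.pop("path")
    tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
    llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)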
@@ -130,7 +130,7 @@ def launch_distributed(
 train_dataset_config=train_dataset_config,
 model_config=inference_model_config,
 generate_config=generate_config,
-tokenizer_config=tokenizer_config,
+tokenizer_config=copy.deepcopy(tokenizer_config),
 microbatch_size=inference_microbatch_size,
 backend=inference_backend,
 num_generations=num_generations,
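Note: the same tokenizer_config dict used to be handed to every producer and consumer, and the backends pop entries such as "path" out of it, so one actor's mutation changed what the others saw. Handing each recipient a deepcopy removes that interference. The sketch below illustrates the failure mode; the names are placeholders, not the launch_distributed signature.

    import copy

    def init_backend(cfg):
        return cfg.pop("path", None)  # mutates cfg, like the backends above

    shared = {"path": "Qwen/Qwen2.5-7B"}
    a = init_backend(shared)   # "Qwen/Qwen2.5-7B"
    b = init_backend(shared)   # None -- the shared dict was already consumed

    original = {"path": "Qwen/Qwen2.5-7B"}
    c = init_backend(copy.deepcopy(original))  # each actor gets its own copy
    d = init_backend(copy.deepcopy(original))  # still sees "Qwen/Qwen2.5-7B"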
@@ -158,8 +158,6 @@ def launch_distributed(
 consumer_master_ip_address = gpu_to_ip_address[0]
 print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
 consumer_procs = []
-if num_consumer_procs <= 1:
-raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
 for i in range(num_consumer_procs):
 node_id = gpu_to_node_id[0]
 consumer_ip_address = gpu_to_ip_address[0]
@@ -180,6 +178,7 @@ def launch_distributed(
 model_config=train_model_config,
 plugin_config=plugin_config,
 minibatch_size=train_minibatch_size,
+tokenizer_config=copy.deepcopy(tokenizer_config),
 generate_config=generate_config_consumer,
 grpo_config=grpo_config,
 num_generations=num_generations,
@@ -35,9 +35,9 @@ class PolicyLoss(nn.Module):
 total_effective_tokens_in_batch: torch.Tensor = None,
 ) -> torch.Tensor:
 if action_mask is None:
-ratio = (log_probs - log_probs.detach()).exp()
+ratio = (log_probs - old_log_probs.detach()).exp()
 else:
-ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()

 surr1 = ratio * advantages
 surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
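Note: before this fix the importance ratio was exp(log_probs - log_probs.detach()), which always evaluates to 1 in the forward pass, so the clipped surrogate never actually clipped and no off-policy correction was applied. The fix computes the ratio against the behavior policy's old_log_probs, which is exactly what the new old_action_log_probs plumbing elsewhere in this commit feeds in. The sketch below is a minimal illustration of the clipped surrogate the fixed ratio participates in; shapes and the clip values are assumptions, and the real PolicyLoss also handles KL, loss masks, and token-count normalization.

    import torch

    def clipped_policy_loss(log_probs, old_log_probs, advantages, action_mask,
                            clip_eps_low=0.2, clip_eps_high=0.2):
        # New-policy vs. behavior-policy importance ratio. With the old code the
        # second term was log_probs.detach(), so the ratio was identically 1 in the
        # forward pass and the clamp below could never take effect.
        ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
        per_token = -torch.min(surr1, surr2)
        return (per_token * action_mask).sum() / action_mask.sum().clamp(min=1)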
@@ -7,7 +7,9 @@ import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.distributed.utils import bind_batch, post_recv, unbind_batch
 from tqdm import tqdm

 from colossalai.booster import Booster
@@ -15,9 +17,6 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.utils import get_current_device

-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-from coati.distributed.utils import bind_batch, post_recv, unbind_batch
-

 class BaseConsumer:
 def __init__(
@@ -175,14 +174,15 @@ class BaseConsumer:
 raw_batch = ray.get(
 self.shared_sync_data_actor.get_data.remote(self.data_uid)
 ) # get the first queued data
+self.profiler.log(f"enter sleep")
 while raw_batch is None:
-self.profiler.log(f"No data received by consumer {self.rank}, skipping")
 print(
 f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
 )
 time.sleep(1)
 raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
 continue
+self.profiler.log(f"exit sleep")
 self.data_uid += 1
 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
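Note: the consumer now brackets its wait on the data actor with "enter sleep" / "exit sleep" profiler events instead of logging on every iteration, which makes time spent starved for data visible as one interval. A small sketch of the polling pattern follows; get_data, profiler, and data_uid stand in for the Ray actor handle, CustomProfiler, and the consumer's counter, and are placeholders.

    import time

    def wait_for_batch(get_data, profiler, data_uid):
        raw_batch = get_data(data_uid)
        profiler.log("enter sleep")        # one event when starvation begins
        while raw_batch is None:
            time.sleep(1)                  # back off instead of spinning
            raw_batch = get_data(data_uid)
        profiler.log("exit sleep")         # one event when data finally arrives
        return raw_batch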
@@ -3,12 +3,11 @@ import time
 import ray
 import ray.util.collective as cc
 import torch
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler

 from colossalai.utils import get_current_device

-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-

 @ray.remote
 class Distributor:
@@ -21,6 +20,7 @@ class Distributor:
 enable_profiling: bool = True,
 ):
 self.distributor_id = distributor_id
+self.weight_version = [0] * consumer_pp_size
 self.consumer_pp_size = consumer_pp_size
 self.state_dict_cpu = {}
 self.num_producers = num_producers
@@ -42,14 +42,17 @@ class Distributor:
 print(f"[D] Initialized {group_name} collective group", flush=True)

 def loop(self):
+last_weight_version = self.get_weight_version()
 while True:
 time.sleep(1)
 signal = ray.get(self.shared_signal_actor.get_signal.remote())
 if self.consumer_pp_size > 1:
-for i in range(self.consumer_pp_size):
-if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
+if all(
+[signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
+):
+cc.barrier(group_name="distributor_pg")
+for i in range(self.consumer_pp_size):
 self.profiler.enter(f"sync_model_consumer_pp_{i}")
-cc.barrier(group_name="distributor_pg")
 ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
 # Broadcast the model state dict from consumer to shared variable actor
 self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
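Note: previously each pipeline stage triggered its own barrier and broadcast as soon as it alone signalled ready, so the distributor could enter the collective while other stages were still training. The rewritten loop waits until every consumer_pp_{i} signal reads ready, issues one barrier, and only then broadcasts stage by stage. The sketch below shows the "wait for all, then act" gating in isolation; signal is the dict returned by the shared signal actor, and sync_stage is a placeholder for the per-stage barrier plus broadcast.

    def maybe_sync_all_stages(signal: dict, consumer_pp_size: int, sync_stage) -> bool:
        ready = [
            signal.get(f"consumer_pp_{i}", None) == "ready_sync_model"
            for i in range(consumer_pp_size)
        ]
        if not all(ready):
            # Syncing only the stages that happen to be ready while the rest are
            # still training is exactly the racy behaviour the fix removes.
            return False
        for i in range(consumer_pp_size):
            sync_stage(i)
        return True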
@@ -60,6 +63,7 @@ class Distributor:
 backend="gloo",
 )
 self.profiler.exit(f"sync_model_consumer_pp_{i}")
+self.weight_version[i] += 1
 for i in range(self.consumer_pp_size):
 if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
 self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
@@ -87,6 +91,7 @@ class Distributor:
 None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
 )
 self.profiler.exit("sync_model_consumer")
+self.weight_version[0] += 1
 if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
 self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
 # Broadcast the model state dict to all producers
@@ -106,3 +111,9 @@ class Distributor:
 if signal.get("consumer", None) == "terminate":
 self.profiler.log("terminate sync model worker")
 break
+if last_weight_version != self.get_weight_version():
+last_weight_version = self.get_weight_version()
+ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
+
+def get_weight_version(self):
+return min(self.weight_version)
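Note: the distributor now keeps one counter per pipeline stage, bumps it after each completed broadcast, and publishes distributor_weight_version only when the minimum over stages advances, so a published version always describes a fully synced checkpoint. The following sketch restates that bookkeeping outside of Ray; publish stands in for set_signal("distributor_weight_version", ...) and is a placeholder.

    class WeightVersions:
        def __init__(self, consumer_pp_size: int, publish):
            self.weight_version = [0] * consumer_pp_size
            self.last_published = 0
            self.publish = publish

        def on_stage_synced(self, stage: int) -> None:
            self.weight_version[stage] += 1

        def get_weight_version(self) -> int:
            # The minimum only advances once *every* stage has received the new
            # weights, so a published version never refers to a half-synced model.
            return min(self.weight_version)

        def maybe_publish(self) -> None:
            current = self.get_weight_version()
            if current != self.last_published:
                self.last_published = current
                self.publish(current)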
@@ -5,9 +5,9 @@ import ray
 import torch
 import wandb
 from coati.distributed.comm import SharedVariableActor
-from coati.distributed.zero_bubble.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
+from coati.distributed.zero_bubble.consumer import BaseConsumer
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -33,6 +33,7 @@ class GRPOConsumer(BaseConsumer):
 plugin_config,
 minibatch_size=1,
 num_generations=8,
+tokenizer_config=None,
 generate_config=None,
 grpo_config={},
 save_interval: int = 100,
@@ -73,9 +74,11 @@ class GRPOConsumer(BaseConsumer):
 self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
 self.policy_model.train()
 self.policy_model.gradient_checkpointing_enable()
+self.vocab_size = self.policy_model.config.vocab_size
 self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
 self.accum_loss = torch.zeros(1, device=self.device)
 self.accum_kl = torch.zeros(1, device=self.device)
+self.accum_entropy = torch.zeros(1, device=self.device)
 self.accum_advantages = torch.zeros(1, device=self.device)
 self.raw_train_batch_reward = []
 self.raw_train_batch_format_acc = []
@@ -102,8 +105,11 @@ class GRPOConsumer(BaseConsumer):
 if self.policy_loss_fn.beta > 0:
 self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
 self.reference_model.eval()
-self.tokenizer = AutoTokenizer.from_pretrained(path)
+if tokenizer_config is not None:
+path = tokenizer_config.pop("path", None)
+self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_config)
+else:
+self.tokenizer = AutoTokenizer.from_pretrained(path)
 self.pad_token_id = self.tokenizer.pad_token_id
 self.num_generations = num_generations
 self.filter_range = grpo_config.get("filter_range", None)
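Note: the consumer now honors an optional tokenizer_config by popping the tokenizer path out of the dict and forwarding the remaining keys to AutoTokenizer.from_pretrained; the pop is also why callers hand over a deepcopy. The sketch below is an illustrative standalone version; model_path and the fallback to it when the config carries no path are assumptions of the sketch, not the exact behaviour of the committed code.

    from typing import Optional

    from transformers import AutoTokenizer

    def load_tokenizer(model_path: str, tokenizer_config: Optional[dict]):
        if tokenizer_config is not None:
            tokenizer_config = dict(tokenizer_config)        # work on a copy before popping
            path = tokenizer_config.pop("path", None) or model_path
            return AutoTokenizer.from_pretrained(path, **tokenizer_config)
        return AutoTokenizer.from_pretrained(model_path)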
@@ -243,10 +249,14 @@ class GRPOConsumer(BaseConsumer):
 else self.booster.no_sync(self.policy_model, self.optimizer)
 )
 with ctx:
+mini_batch_entropies = []
 for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
 input_ids_forward_micro_batch = data["input_ids"][
 forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
 ]
+old_action_log_probs_micro_batch = old_action_log_probs[
+forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+]
 attention_mask_forward_micro_batch = data["attention_mask"][
 forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
 ]
@@ -303,6 +313,7 @@ class GRPOConsumer(BaseConsumer):
 "action_mask": action_mask_forward_micro_batch,
 "advantages": advantages_forward_micro_batch,
 "loss_mask": loss_mask_forward_micro_batch,
+"old_action_log_probs": old_action_log_probs_micro_batch,
 "source": self.rank,
 }
 if reference_action_log_probs is not None:
@@ -312,6 +323,12 @@ class GRPOConsumer(BaseConsumer):

 def _criterion(outputs, inputs):
 action_logits = outputs.logits
+mini_batch_entropies.append(
+(
+((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+/ inputs["action_mask"].sum(-1)
+).detach()
+)
 action_log_probs = memory_efficient_logprob(
 action_logits / self.generate_config["temperature"],
 inputs["input_ids"],
@@ -334,7 +351,7 @@ class GRPOConsumer(BaseConsumer):

 loss, _ = self.policy_loss_fn(
 action_log_probs,
-action_log_probs,
+inputs["old_action_log_probs"],
 inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
 per_token_kl,
 inputs["action_mask"],
@@ -396,7 +413,7 @@ class GRPOConsumer(BaseConsumer):

 loss, _ = self.policy_loss_fn(
 action_log_probs,
-old_action_log_probs,
+old_action_log_probs_micro_batch,
 advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
 per_token_kl,
 action_mask_forward_micro_batch,
@@ -411,6 +428,20 @@ class GRPOConsumer(BaseConsumer):
 kl = all_reduce_mean(kl.mean(), self.plugin)
 mean_kl.append(kl.data)
 mean_loss.append(loss.data)
+mini_batch_entropies.append(
+all_reduce_mean(
+(
+(
+(
+entropy_from_logits(policy_model_logits[:, -num_action:])
+* action_mask_forward_micro_batch
+).sum(-1)
+)
+/ action_mask_forward_micro_batch.sum(-1)
+).detach(),
+self.plugin,
+)
+)
 if not self.plugin.pp_size > 1 or (
 self.plugin.pp_size > 1
 and self.booster.plugin.stage_manager.is_last_stage()
@@ -422,7 +453,9 @@ class GRPOConsumer(BaseConsumer):
 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
 advantages = all_reduce_mean(advantages.mean(), self.plugin)
 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+self.accum_entropy.add_(entropy.data)
 if self.policy_loss_fn.beta > 0:
 self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
 self.accum_advantages.add_(advantages.data)
@@ -465,6 +498,7 @@ class GRPOConsumer(BaseConsumer):
 f"Response Length: {raw_batch_response_len_mean:.4f}",
 f"Sample_utilization: {sample_utilization:.4f}",
 f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
 print("\n".join(to_log_msg))
 metrics = {
@@ -476,6 +510,7 @@ class GRPOConsumer(BaseConsumer):
 "train/advantages": self.accum_advantages.item() / self.accum_count,
 "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
 "train/sample_utilization": sample_utilization,
+"train/entropy": self.accum_entropy.item() / self.accum_count,
 "train/overlength_samples_ratio": overlength_samples_ratio,
 "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
 }
@@ -483,8 +518,10 @@ class GRPOConsumer(BaseConsumer):
 metrics["train/kl"] = self.accum_kl.item() / self.accum_count
 if self.wandb_run is not None:
 self.wandb_run.log(metrics)
+ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization))
 self.accum_loss.zero_()
 self.accum_kl.zero_()
+self.accum_entropy.zero_()
 self.accum_advantages.zero_()
 self.accum_count = 0
 return loss_scalar
@@ -11,9 +11,12 @@ import torch
 import tqdm
 import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
+from coati.distributed.inference_backend import BACKEND_MAP
 from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
 from ray.util.collective import allreduce
 from ray.util.collective.types import ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
@@ -21,10 +24,6 @@ from transformers import AutoTokenizer

 from colossalai.utils import get_current_device

-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-from coati.distributed.inference_backend import BACKEND_MAP
-from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
-
 try:
 from vllm import SamplingParams
 except ImportError:
@@ -280,11 +279,17 @@ class BaseProducer:
 self.profiler.exit("sync_model")
 self.sync_model_thread_started = False

-if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version:
+distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
+f"distributor_weight_version", 0
+)
+if (
+not self.sync_model_thread_started
+and distributor_weight_version != self.producer_weight_version
+):
 # only sync model when the thread is not started and global step is changed
 self.sync_model_thread_started = True
 self.sync_model_thread = threading.Thread(target=sync_model_thread)
-self.producer_weight_version = self.consumer_global_step
+self.producer_weight_version = distributor_weight_version
 self.sync_model_thread.start()
 torch.cuda.empty_cache()
 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
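Note: this is the producer side of the race fix. Instead of comparing its own weight version with the consumer's global step, which advances before the distributor has actually finished broadcasting new weights, the producer now compares against the distributor_weight_version signal published by the Distributor, so a sync thread is only started once a fully distributed checkpoint exists. The sketch below condenses that gating; get_signal, sync_model_thread, and the state dict are placeholders for the shared signal actor lookup, the real sync worker, and the producer's attributes.

    import threading

    def maybe_start_sync(state, get_signal, sync_model_thread):
        distributor_weight_version = get_signal().get("distributor_weight_version", 0)
        if not state["sync_model_thread_started"] and distributor_weight_version != state["producer_weight_version"]:
            # Only sync when the distributor has published a *complete* new version,
            # not merely when the consumer's global step has moved on.
            state["sync_model_thread_started"] = True
            state["producer_weight_version"] = distributor_weight_version
            thread = threading.Thread(target=sync_model_thread)
            thread.start()
            return thread
        return None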
@@ -478,7 +483,7 @@ class SimpleProducer(BaseProducer):
 train_dataset_config,
 model_config,
 generate_config,
-tokenizer_config,
+copy.deepcopy(tokenizer_config),
 microbatch_size,
 backend,
 consumer_plugin_config,
@@ -493,7 +498,8 @@ class SimpleProducer(BaseProducer):
 rollout_log_file=rollout_log_file,
 enable_profiling=enable_profiling,
 )
-self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+print("tokenizer_config", tokenizer_config)
+self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config)
 self.eval_generation_config = copy.deepcopy(self.model.generate_config)
 self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
 self.eval_generation_config.update(eval_generation_config)
@@ -15,6 +15,12 @@ DEFAUT_SYSTEM_PROMPT = {
 if __name__ == "__main__":
 parser = argparse.ArgumentParser()
 parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+parser.add_argument(
+"--tokenizer-path",
+type=str,
+default=None,
+help="Path to the tokenizer. If not provided, will use the model path.",
+)
 parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
 parser.add_argument(
 "-ed",
@@ -166,6 +172,7 @@ if __name__ == "__main__":
 "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
 )
 args = parser.parse_args()
+print(args)

 if args.train_minibatch_size is None:
 # Default settings: Using train batch size as mini batch size
@@ -283,7 +290,7 @@ if __name__ == "__main__":
 elif args.algo == "DAPO":
 # DAPO variant settings
 grpo_config = {
-"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
+"filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch
 "lr": args.learning_rate,
 "train_microbatch_size": args.train_microbatch_size,
 "dynamic_batching": True,
@@ -343,7 +350,9 @@ if __name__ == "__main__":
 ), # microbatch size should be set to train_microbatch_size // pp_size
 "zero_stage": args.zero_stage,
 "max_norm": 1.0,
+"num_layers_per_stage": [18, 10],
 }, # for pp, tp
+tokenizer_config={"path": args.tokenizer_path} if args.tokenizer_path else {"path": args.model},
 inference_backend=args.backend,
 master_addr="localhost",
 master_port=args.master_port,
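Note: the new --tokenizer-path flag flows into a tokenizer_config dict that launch_distributed then deep-copies for every producer and consumer. The sketch below shows how the argument becomes the config, mirroring the launch call above; the model name and tokenizer path are placeholders.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
    parser.add_argument("--tokenizer-path", type=str, default=None,
                        help="Path to the tokenizer. If not provided, will use the model path.")
    args = parser.parse_args([])  # e.g. ["--tokenizer-path", "/path/to/tokenizer"]

    tokenizer_config = (
        {"path": args.tokenizer_path} if args.tokenizer_path else {"path": args.model}
    )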