mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-22 19:23:16 +00:00
fix racing condition
This commit is contained in:
parent
f54ae56f12
commit
e774edeb80
@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
import copy
|
||||
from typing import Any, Dict
|
||||
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
|
||||
obj = c10d._tensor_to_object(obj, size_tensor.item())
|
||||
return obj
|
||||
|
||||
|
||||
def ray_broadcast_tensor_dict(
|
||||
tensor_dict: Dict[str, torch.Tensor],
|
||||
src: int = 0,
|
||||
@ -98,7 +100,7 @@ class SharedVariableActor:
|
||||
queue length as data may still be generating
|
||||
"""
|
||||
ret = False
|
||||
if self.queue_size < self.buffer_size_limit:
|
||||
if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
|
||||
ret = True
|
||||
self.queue_size += num_tasks
|
||||
return ret
|
||||
|
@ -250,6 +250,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
old_action_log_probs_micro_batch = old_action_log_probs[
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
@ -306,17 +309,22 @@ class GRPOConsumer(BaseConsumer):
|
||||
"action_mask": action_mask_forward_micro_batch,
|
||||
"advantages": advantages_forward_micro_batch,
|
||||
"loss_mask": loss_mask_forward_micro_batch,
|
||||
"old_action_log_probs": old_action_log_probs_micro_batch,
|
||||
"source": self.rank,
|
||||
}
|
||||
if reference_action_log_probs is not None:
|
||||
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
|
||||
|
||||
kl = []
|
||||
policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
policy_model_logits.copy_(action_logits)
|
||||
mini_batch_entropies.append(
|
||||
(
|
||||
((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
|
||||
/ inputs["action_mask"].sum(-1)
|
||||
).detach()
|
||||
)
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
@ -339,7 +347,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
action_log_probs,
|
||||
inputs["old_action_log_probs"],
|
||||
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
inputs["action_mask"],
|
||||
@ -415,7 +423,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
old_action_log_probs_micro_batch,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
@ -455,7 +463,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
entropy = torch.cat(mini_batch_entropies, dim=0).mean()
|
||||
entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
self.accum_entropy.add_(entropy.data)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
|
@ -59,6 +59,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
tokenizer_config: Dict[str, Any] = None,
|
||||
):
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
model_config.update(self.FORCE_MODEL_CONFIG)
|
||||
@ -132,6 +133,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
tokenizer_config: Dict[str, Any] = None,
|
||||
):
|
||||
if sgl is None:
|
||||
raise ImportError("sglang is not installed")
|
||||
@ -196,12 +198,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
tokenizer_config: Dict[str, Any] = None,
|
||||
):
|
||||
if LLM is None:
|
||||
raise ImportError("vllm is not installed")
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
path = model_config.pop("path")
|
||||
self.llm = LLM(model=path, **model_config)
|
||||
tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
|
||||
self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
|
||||
generate_config = generate_config.copy()
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
generate_config.update({"n": num_generations})
|
||||
|
@ -130,7 +130,7 @@ def launch_distributed(
|
||||
train_dataset_config=train_dataset_config,
|
||||
model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
tokenizer_config=tokenizer_config,
|
||||
tokenizer_config=copy.deepcopy(tokenizer_config),
|
||||
microbatch_size=inference_microbatch_size,
|
||||
backend=inference_backend,
|
||||
num_generations=num_generations,
|
||||
@ -158,8 +158,6 @@ def launch_distributed(
|
||||
consumer_master_ip_address = gpu_to_ip_address[0]
|
||||
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
|
||||
consumer_procs = []
|
||||
if num_consumer_procs <= 1:
|
||||
raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
|
||||
for i in range(num_consumer_procs):
|
||||
node_id = gpu_to_node_id[0]
|
||||
consumer_ip_address = gpu_to_ip_address[0]
|
||||
@ -180,6 +178,7 @@ def launch_distributed(
|
||||
model_config=train_model_config,
|
||||
plugin_config=plugin_config,
|
||||
minibatch_size=train_minibatch_size,
|
||||
tokenizer_config=copy.deepcopy(tokenizer_config),
|
||||
generate_config=generate_config_consumer,
|
||||
grpo_config=grpo_config,
|
||||
num_generations=num_generations,
|
||||
|
@ -35,9 +35,9 @@ class PolicyLoss(nn.Module):
|
||||
total_effective_tokens_in_batch: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if action_mask is None:
|
||||
ratio = (log_probs - log_probs.detach()).exp()
|
||||
ratio = (log_probs - old_log_probs.detach()).exp()
|
||||
else:
|
||||
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
|
||||
ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
|
||||
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
|
||||
|
@ -7,7 +7,9 @@ import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.profiling_utils import CustomProfiler
|
||||
from coati.distributed.utils import bind_batch, post_recv, unbind_batch
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.booster import Booster
|
||||
@ -15,9 +17,6 @@ from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.utils import bind_batch, post_recv, unbind_batch
|
||||
|
||||
|
||||
class BaseConsumer:
|
||||
def __init__(
|
||||
@ -175,14 +174,15 @@ class BaseConsumer:
|
||||
raw_batch = ray.get(
|
||||
self.shared_sync_data_actor.get_data.remote(self.data_uid)
|
||||
) # get the first queued data
|
||||
self.profiler.log(f"enter sleep")
|
||||
while raw_batch is None:
|
||||
self.profiler.log(f"No data received by consumer {self.rank}, skipping")
|
||||
print(
|
||||
f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
|
||||
)
|
||||
time.sleep(1)
|
||||
raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
|
||||
continue
|
||||
self.profiler.log(f"exit sleep")
|
||||
self.data_uid += 1
|
||||
raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
|
||||
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
|
||||
|
@ -3,12 +3,11 @@ import time
|
||||
import ray
|
||||
import ray.util.collective as cc
|
||||
import torch
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.profiling_utils import CustomProfiler
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Distributor:
|
||||
@ -21,6 +20,7 @@ class Distributor:
|
||||
enable_profiling: bool = True,
|
||||
):
|
||||
self.distributor_id = distributor_id
|
||||
self.weight_version = [0] * consumer_pp_size
|
||||
self.consumer_pp_size = consumer_pp_size
|
||||
self.state_dict_cpu = {}
|
||||
self.num_producers = num_producers
|
||||
@ -42,14 +42,17 @@ class Distributor:
|
||||
print(f"[D] Initialized {group_name} collective group", flush=True)
|
||||
|
||||
def loop(self):
|
||||
last_weight_version = self.get_weight_version()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
signal = ray.get(self.shared_signal_actor.get_signal.remote())
|
||||
if self.consumer_pp_size > 1:
|
||||
for i in range(self.consumer_pp_size):
|
||||
if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
|
||||
if all(
|
||||
[signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
|
||||
):
|
||||
cc.barrier(group_name="distributor_pg")
|
||||
for i in range(self.consumer_pp_size):
|
||||
self.profiler.enter(f"sync_model_consumer_pp_{i}")
|
||||
cc.barrier(group_name="distributor_pg")
|
||||
ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
|
||||
# Broadcast the model state dict from consumer to shared variable actor
|
||||
self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
|
||||
@ -60,6 +63,7 @@ class Distributor:
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit(f"sync_model_consumer_pp_{i}")
|
||||
self.weight_version[i] += 1
|
||||
for i in range(self.consumer_pp_size):
|
||||
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
|
||||
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
|
||||
@ -87,6 +91,7 @@ class Distributor:
|
||||
None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
|
||||
)
|
||||
self.profiler.exit("sync_model_consumer")
|
||||
self.weight_version[0] += 1
|
||||
if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
|
||||
self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
|
||||
# Broadcast the model state dict to all producers
|
||||
@ -106,3 +111,9 @@ class Distributor:
|
||||
if signal.get("consumer", None) == "terminate":
|
||||
self.profiler.log("terminate sync model worker")
|
||||
break
|
||||
if last_weight_version != self.get_weight_version():
|
||||
last_weight_version = self.get_weight_version()
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
|
||||
|
||||
def get_weight_version(self):
|
||||
return min(self.weight_version)
|
||||
|
@ -5,9 +5,9 @@ import ray
|
||||
import torch
|
||||
import wandb
|
||||
from coati.distributed.comm import SharedVariableActor
|
||||
from coati.distributed.zero_bubble.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.utils import memory_efficient_logprob
|
||||
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
|
||||
from coati.distributed.zero_bubble.consumer import BaseConsumer
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
@ -33,6 +33,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
plugin_config,
|
||||
minibatch_size=1,
|
||||
num_generations=8,
|
||||
tokenizer_config=None,
|
||||
generate_config=None,
|
||||
grpo_config={},
|
||||
save_interval: int = 100,
|
||||
@ -73,9 +74,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.policy_model.gradient_checkpointing_enable()
|
||||
self.vocab_size = self.policy_model.config.vocab_size
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_entropy = torch.zeros(1, device=self.device)
|
||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||
self.raw_train_batch_reward = []
|
||||
self.raw_train_batch_format_acc = []
|
||||
@ -102,8 +105,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.reference_model.eval()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
if tokenizer_config is not None:
|
||||
path = tokenizer_config.pop("path", None)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_config)
|
||||
else:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.num_generations = num_generations
|
||||
self.filter_range = grpo_config.get("filter_range", None)
|
||||
@ -243,10 +249,14 @@ class GRPOConsumer(BaseConsumer):
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
mini_batch_entropies = []
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
old_action_log_probs_micro_batch = old_action_log_probs[
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
@ -303,6 +313,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
"action_mask": action_mask_forward_micro_batch,
|
||||
"advantages": advantages_forward_micro_batch,
|
||||
"loss_mask": loss_mask_forward_micro_batch,
|
||||
"old_action_log_probs": old_action_log_probs_micro_batch,
|
||||
"source": self.rank,
|
||||
}
|
||||
if reference_action_log_probs is not None:
|
||||
@ -312,6 +323,12 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
mini_batch_entropies.append(
|
||||
(
|
||||
((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
|
||||
/ inputs["action_mask"].sum(-1)
|
||||
).detach()
|
||||
)
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
@ -334,7 +351,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
action_log_probs,
|
||||
inputs["old_action_log_probs"],
|
||||
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
inputs["action_mask"],
|
||||
@ -396,7 +413,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
old_action_log_probs_micro_batch,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
@ -411,6 +428,20 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
mini_batch_entropies.append(
|
||||
all_reduce_mean(
|
||||
(
|
||||
(
|
||||
(
|
||||
entropy_from_logits(policy_model_logits[:, -num_action:])
|
||||
* action_mask_forward_micro_batch
|
||||
).sum(-1)
|
||||
)
|
||||
/ action_mask_forward_micro_batch.sum(-1)
|
||||
).detach(),
|
||||
self.plugin,
|
||||
)
|
||||
)
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1
|
||||
and self.booster.plugin.stage_manager.is_last_stage()
|
||||
@ -422,7 +453,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
self.accum_entropy.add_(entropy.data)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
@ -465,6 +498,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
||||
f"Sample_utilization: {sample_utilization:.4f}",
|
||||
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
|
||||
f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
|
||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||
print("\n".join(to_log_msg))
|
||||
metrics = {
|
||||
@ -476,6 +510,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"train/sample_utilization": sample_utilization,
|
||||
"train/entropy": self.accum_entropy.item() / self.accum_count,
|
||||
"train/overlength_samples_ratio": overlength_samples_ratio,
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
@ -483,8 +518,10 @@ class GRPOConsumer(BaseConsumer):
|
||||
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
|
||||
if self.wandb_run is not None:
|
||||
self.wandb_run.log(metrics)
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization))
|
||||
self.accum_loss.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_entropy.zero_()
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_count = 0
|
||||
return loss_scalar
|
||||
|
@ -11,9 +11,12 @@ import torch
|
||||
import tqdm
|
||||
import wandb
|
||||
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.inference_backend import BACKEND_MAP
|
||||
from coati.distributed.profiling_utils import CustomProfiler
|
||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
|
||||
from ray.util.collective import allreduce
|
||||
from ray.util.collective.types import ReduceOp
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
@ -21,10 +24,6 @@ from transformers import AutoTokenizer
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
|
||||
from coati.distributed.inference_backend import BACKEND_MAP
|
||||
from coati.distributed.utils import pre_send, safe_append_to_jsonl_file
|
||||
|
||||
try:
|
||||
from vllm import SamplingParams
|
||||
except ImportError:
|
||||
@ -280,11 +279,17 @@ class BaseProducer:
|
||||
self.profiler.exit("sync_model")
|
||||
self.sync_model_thread_started = False
|
||||
|
||||
if not self.sync_model_thread_started and self.consumer_global_step != self.producer_weight_version:
|
||||
distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
|
||||
f"distributor_weight_version", 0
|
||||
)
|
||||
if (
|
||||
not self.sync_model_thread_started
|
||||
and distributor_weight_version != self.producer_weight_version
|
||||
):
|
||||
# only sync model when the thread is not started and global step is changed
|
||||
self.sync_model_thread_started = True
|
||||
self.sync_model_thread = threading.Thread(target=sync_model_thread)
|
||||
self.producer_weight_version = self.consumer_global_step
|
||||
self.producer_weight_version = distributor_weight_version
|
||||
self.sync_model_thread.start()
|
||||
torch.cuda.empty_cache()
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
@ -478,7 +483,7 @@ class SimpleProducer(BaseProducer):
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
tokenizer_config,
|
||||
copy.deepcopy(tokenizer_config),
|
||||
microbatch_size,
|
||||
backend,
|
||||
consumer_plugin_config,
|
||||
@ -493,7 +498,8 @@ class SimpleProducer(BaseProducer):
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
print("tokenizer_config", tokenizer_config)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
self.eval_generation_config.update(eval_generation_config)
|
||||
|
@ -15,6 +15,12 @@ DEFAUT_SYSTEM_PROMPT = {
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
parser.add_argument(
|
||||
"--tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the tokenizer. If not provided, will use the model path.",
|
||||
)
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument(
|
||||
"-ed",
|
||||
@ -166,6 +172,7 @@ if __name__ == "__main__":
|
||||
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.train_minibatch_size is None:
|
||||
# Default settings: Using train batch size as mini batch size
|
||||
@ -283,7 +290,7 @@ if __name__ == "__main__":
|
||||
elif args.algo == "DAPO":
|
||||
# DAPO variant settings
|
||||
grpo_config = {
|
||||
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
|
||||
"filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch
|
||||
"lr": args.learning_rate,
|
||||
"train_microbatch_size": args.train_microbatch_size,
|
||||
"dynamic_batching": True,
|
||||
@ -343,7 +350,9 @@ if __name__ == "__main__":
|
||||
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
"zero_stage": args.zero_stage,
|
||||
"max_norm": 1.0,
|
||||
"num_layers_per_stage": [18, 10],
|
||||
}, # for pp, tp
|
||||
tokenizer_config={"path": args.tokenizer_path} if args.tokenizer_path else {"path": args.model},
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=args.master_port,
|
||||
|
Loading…
Reference in New Issue
Block a user