from contextlib import nullcontext
from typing import Any, Optional

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam


@ray.remote  # (runtime_env={ "nsight": "default"})
class GRPOConsumer(BaseConsumer):
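    """Ray actor that consumes rollout batches and performs GRPO updates on the policy model.

    The consumer keeps the policy model (and, when the KL coefficient ``beta`` is positive,
    a frozen reference model), computes group-normalized advantages, and accumulates
    gradients until enough effective prompts have been collected for an optimizer step.
    """
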
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        minibatch_size=1,
        num_generations=8,
        generate_config=None,
        grpo_config={},
        save_interval: int = 100,
        save_dir="./model",
        project_name: Optional[str] = None,
        run_name: Optional[str] = None,
        wandb_group_name: Optional[str] = None,
    ):
        print(f"Using GRPO config: {grpo_config}")
        if (
            plugin_config.get("pp_size", 1) > 1
            and "num_microbatches" not in plugin_config
            and "microbatch_size" not in plugin_config
        ):
            plugin_config["microbatch_size"] = max(
                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
            )
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            minibatch_size,
            save_interval=save_interval,
            save_dir=save_dir,
        )
        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
        self.accum_advantages = torch.zeros(1, device=self.device)
        self.raw_train_batch_reward = []
        self.raw_train_batch_format_acc = []
        self.raw_train_batch_ans_acc = []
        self.raw_train_batch_response_len = []
        self.accum_count = 0
        self.generate_config = generate_config
        self.grpo_config = grpo_config
        self.project_name = project_name
        self.effective_sample_count = 0
        self.effective_prompt_count = 0
        self.run_name = run_name
        self.wandb_group_name = wandb_group_name

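        # Clipped-surrogate GRPO loss; the KL penalty against the reference model is
        # only active when beta > 0.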
        self.policy_loss_fn = PolicyLoss(
            clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
            clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
            beta=grpo_config.get("beta", 0.01),
            loss_variation=grpo_config.get("loss_variation", "sample_level"),
        )

        # Reference model is initialized from the policy model.
        if self.policy_loss_fn.beta > 0:
            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
            self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations
        self.filter_range = grpo_config.get("filter_range", None)
        if self.filter_range is not None:
            assert len(self.filter_range) == 2, "Filter range should have 2 values."

        self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
        if self.filter_truncated_response:
            self.max_length = 0
            if "max_tokens" in self.generate_config:
                self.max_length = self.generate_config["max_tokens"]
            elif "max_new_tokens" in self.generate_config:
                self.max_length = self.generate_config["max_new_tokens"]
            else:
                raise ValueError(
                    "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
                )
        # Verifiable-reward setup: the response format tags are read from the config but not stored here.
        grpo_config.get("response_format_tags", None)
        self.global_step = 0

        self.lr_scheduler = CosineAnnealingWarmupLR(
            optimizer=self.optimizer,
            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
            warmup_steps=0,
            eta_min=0.1 * grpo_config.get("lr", 1e-6),
        )

    def setup(self):
        super().setup()
        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
        ):
            self.wandb_run = wandb.init(
                project=self.project_name,
                sync_tensorboard=False,
                dir="./wandb",
                name=self.run_name,
                group=self.wandb_group_name,
            )

        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
        )
        if self.policy_loss_fn.beta > 0:
            self.reference_model, *_ = self.booster.boost(self.reference_model)
        self.plugin.logger.set_level("ERROR")

    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
        """
        # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
        self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
        self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
        self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
        self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
        action_mask = data["action_mask"]
        num_action = action_mask.shape[1]
        old_action_log_probs = data["action_log_probs"]
        response_length = torch.sum(action_mask, dim=1).to(torch.float32)
        train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

        reward = data["reward"].view((-1))
        format_acc = data["format_acc"].view((-1))
        ans_acc = data["ans_acc"].view((-1))

        # [minibatch_size, num_generations]
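        # GRPO advantage: each response's reward is normalized against its own prompt group,
        # A_i = (r_i - mean(group)) / (std(group) + 1e-4), then broadcast back to every sample.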
        group_reward = reward.view(-1, self.num_generations)
        reward_mean = group_reward.mean(dim=1)
        # [minibatch_size x num_generations]
        reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

        reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
        # [minibatch_size x num_generations]
        advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

        # [minibatch_size x num_of_generation]
        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

        # Filter out over-length (truncated) samples.
        if self.filter_truncated_response and action_mask.size(1) == self.max_length:
            loss_mask = torch.logical_and(
                loss_mask,
                action_mask[:, -1] == False,
            )
        if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
            # Filter out samples whose group mean accuracy is outside the configured range.
            # If dynamic batching is enabled, out-of-range groups are already filtered out before training.
            group_ans_acc_mean = (
                ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
            )
            loss_mask = torch.logical_and(
                loss_mask,
                torch.logical_and(
                    group_ans_acc_mean > self.filter_range[0],
                    group_ans_acc_mean < self.filter_range[1],
                ),
            )
        self.effective_prompt_count += group_reward.size(0) * self.dp_size

        mean_kl, mean_loss = [], []

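        # With dynamic batching, an optimizer step is triggered once enough effective prompts
        # have been accumulated; otherwise step on a fixed gradient-accumulation schedule.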
        if self.grpo_config.get("dynamic_batching", True):
            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
        else:
            # If dynamic batching is disabled, we need to use all samples for training.
            need_update = (step_idx + 1) % self.num_microbatches == 0

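        # Count the samples and response tokens that survive filtering across all ranks;
        # the token total is passed to the loss as total_effective_tokens_in_batch.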
        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
        self.effective_sample_count += effective_samples.item()
        pbar.set_postfix(
            {
                "Global Step": self.global_step,
                "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
            }
        )

        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
        ctx = (
            nullcontext()
            if need_update or self.booster.plugin.zero_stage == 2
            else self.booster.no_sync(self.policy_model, self.optimizer)
        )
        with ctx:
            for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                input_ids_forward_micro_batch = data["input_ids"][
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]
                attention_mask_forward_micro_batch = data["attention_mask"][
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]
                action_mask_forward_micro_batch = action_mask[
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]
                loss_mask_forward_micro_batch = (
                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
                    if loss_mask is not None
                    else None
                )
                advantages_forward_micro_batch = advantages[
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]

                if self.plugin.pp_size > 1:
                    # torch.cuda.empty_cache()
                    # Support training with PP.
                    if self.policy_loss_fn.beta > 0:
                        with torch.no_grad():
                            # torch.cuda.reset_peak_memory_stats()
                            reference_model_outputs = self.booster.execute_pipeline(
                                iter(
                                    [
                                        {
                                            "input_ids": input_ids_forward_micro_batch,
                                            "attention_mask": attention_mask_forward_micro_batch,
                                        }
                                    ]
                                ),
                                self.reference_model,
                                criterion=lambda outputs, inputs: torch.tensor(
                                    [0.0], device=action_mask.device
                                ),  # dummy criterion
                                optimizer=None,
                                return_loss=False,
                                return_outputs=True,
                            )
                            self.profiler.log(
                                f"reference_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
                            )

                        if self.booster.plugin.stage_manager.is_last_stage():
                            # breakpoint()
                            # torch.cuda.empty_cache()
                            # current_memory = torch.cuda.memory_allocated()
                            # torch.cuda.reset_peak_memory_stats()

                            # reference_action_log_probs = calc_action_log_probs(
                            #     reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                            #     input_ids_forward_micro_batch,
                            #     num_action,
                            #     self.plugin.shard_config,
                            # )

                            # self.profiler.log(f"reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
                            # torch.cuda.empty_cache()
                            # current_memory = torch.cuda.memory_allocated()
                            # torch.cuda.reset_peak_memory_stats()
                            reference_action_log_probs = memory_efficient_logprob(
                                reference_model_outputs["outputs"]["logits"],
                                input_ids_forward_micro_batch,
                                num_action,
                                shard_config=self.plugin.shard_config,
                            )
                            # self.profiler.log(f"me_reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
                            # if torch.allclose(reference_action_log_probs, me_reference_action_log_probs):
                            #     self.profiler.log("Memory efficient reference action log probs is same as normal reference action log probs")
                            # else:
                            #     self.profiler.log("Memory efficient reference action log probs is different from normal reference action log probs")
                            # breakpoint()
                            # torch.cuda.empty_cache()
                            self.profiler.log(
                                f"reference_action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
                            )
                        else:
                            # Dummy reference logprobs for data iterator.
                            reference_action_log_probs = None
                        del reference_model_outputs
                    else:
                        reference_action_log_probs = None

                    data_policy_forward = {
                        "input_ids": input_ids_forward_micro_batch,
                        "attention_mask": attention_mask_forward_micro_batch,
                        "action_mask": action_mask_forward_micro_batch,
                        "advantages": advantages_forward_micro_batch,
                        "loss_mask": loss_mask_forward_micro_batch,
                        "source": self.rank,
                    }
                    if reference_action_log_probs is not None:
                        data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                    kl = []

                    def _criterion(outputs, inputs):
                        action_logits = outputs.logits
                        # breakpoint()
                        # torch.cuda.empty_cache()
                        # current_memory = torch.cuda.memory_allocated()
                        # torch.cuda.reset_peak_memory_stats()
                        # action_log_probs = calc_action_log_probs(
                        #     action_logits / self.generate_config["temperature"],
                        #     inputs["input_ids"],
                        #     num_action,
                        #     self.plugin.shard_config,
                        # )
                        # # torch.cuda.empty_cache()
                        # self.profiler.log(
                        #     f"action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
                        # )
                        # torch.cuda.empty_cache()
                        # current_memory = torch.cuda.memory_allocated()
                        # torch.cuda.reset_peak_memory_stats()
                        action_log_probs = memory_efficient_logprob(
                            action_logits,
                            inputs["input_ids"],
                            num_action,
                            shard_config=self.plugin.shard_config,
                        )
                        # self.profiler.log(
                        #     f"me_action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
                        # )
                        # if torch.allclose(action_log_probs, me_action_log_probs):
                        #     self.profiler.log("Memory efficient action log probs is same as normal action log probs")
                        # else:
                        #     self.profiler.log("Memory efficient action log probs is different from normal action log probs")
                        # torch.cuda.empty_cache()
                        # breakpoint()
                        # current_memory = torch.cuda.memory_allocated()
                        # torch.cuda.empty_cache()
                        # self.profiler.log(
                        #     f"released by del outputs: {(torch.cuda.memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
                        # )
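                        # Low-variance KL estimate (k3): exp(d) - d - 1 with d = ref_logp - logp,
                        # averaged over response tokens via the action mask.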
if "reference_action_log_probs" in inputs:
|
|
per_token_kl = (
|
|
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
|
|
- (inputs["reference_action_log_probs"] - action_log_probs)
|
|
- 1
|
|
)
|
|
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
|
|
inputs["action_mask"], dim=-1
|
|
)
|
|
kl.append(appox_kl.mean())
|
|
else:
|
|
per_token_kl = 0.0
|
|
kl.append(torch.tensor(0.0))
|
|
|
|
loss, _ = self.policy_loss_fn(
|
|
action_log_probs,
|
|
action_log_probs,
|
|
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
|
per_token_kl,
|
|
inputs["action_mask"],
|
|
loss_mask=inputs["loss_mask"],
|
|
total_effective_tokens_in_batch=total_effective_tokens_count,
|
|
)
|
|
return loss
|
|
|
|
                    policy_model_outputs = self.booster.execute_pipeline(
                        iter([data_policy_forward]),
                        self.policy_model,
                        criterion=_criterion,
                        optimizer=self.optimizer,
                        return_loss=True,
                        return_outputs=False,
                    )
                    self.profiler.log(
                        f"policy_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
                    )
                    loss = policy_model_outputs["loss"]

                    if self.booster.plugin.stage_manager.is_last_stage():
                        if len(kl) > 0:
                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                            mean_kl.append(kl)
                        mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                else:
                    policy_model_logits = self.policy_model(
                        input_ids=input_ids_forward_micro_batch,
                        attention_mask=attention_mask_forward_micro_batch,
                    ).logits
                    action_log_probs = calc_action_log_probs(
                        policy_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
                        self.plugin.shard_config,
                    )

                    if self.policy_loss_fn.beta > 0:
                        with torch.no_grad():
                            reference_model_logits = self.reference_model(
                                input_ids=input_ids_forward_micro_batch,
                                attention_mask=attention_mask_forward_micro_batch,
                            ).logits

                        # torch.cuda.empty_cache()
                        reference_action_log_probs = calc_action_log_probs(
                            reference_model_logits / self.generate_config["temperature"],
                            input_ids_forward_micro_batch,
                            num_action,
                            self.plugin.shard_config,
                        )
                        per_token_kl = (
                            torch.exp(reference_action_log_probs - action_log_probs)
                            - (reference_action_log_probs - action_log_probs)
                            - 1
                        )
                        kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
                            action_mask_forward_micro_batch, dim=-1
                        )
                    else:
                        per_token_kl = 0.0
                        kl = None

                    loss, _ = self.policy_loss_fn(
                        action_log_probs,
                        old_action_log_probs,
                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                        per_token_kl,
                        action_mask_forward_micro_batch,
                        loss_mask=loss_mask_forward_micro_batch,
                        total_effective_tokens_in_batch=total_effective_tokens_count,
                    )

                    self.booster.backward(loss, self.optimizer)
                    loss = all_reduce_mean(loss, self.plugin)
                    # Calculate accumulated values.
                    if kl is not None:
                        kl = all_reduce_mean(kl.mean(), self.plugin)
                        mean_kl.append(kl.data)
                    mean_loss.append(loss.data)

            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
            ):
                reward = all_reduce_mean(reward.mean(), self.plugin)
                format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
                ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                advantages = all_reduce_mean(advantages.mean(), self.plugin)
                response_length = all_reduce_mean(response_length.mean(), self.plugin)
                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                if self.policy_loss_fn.beta > 0:
                    self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                self.accum_advantages.add_(advantages.data)
                self.accum_count += 1
        if need_update:
            # breakpoint()
            # torch.cuda.empty_cache()
            # torch.cuda.reset_peak_memory_stats()
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.global_step += 1
            # No need to run all-reduce, as raw_train_batch_* are not split across dp ranks.
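            # sample_utilization: fraction of generated samples that passed filtering
            # and actually contributed to this optimizer step.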
            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
            self.effective_prompt_count = 0
            self.effective_sample_count = 0
            self.profiler.log(
                f"optimizer_step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
            )
            loss_scalar = self.accum_loss.item()
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
            ):
                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                ):
                    raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
                    raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
                    raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
                    raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
                    raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
                    overlength_samples_ratio = (
                        (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
                    )  # not an exact figure, but a close estimate
                    self.raw_train_batch_reward = []
                    self.raw_train_batch_format_acc = []
                    self.raw_train_batch_ans_acc = []
                    self.raw_train_batch_response_len = []
                    to_log_msg = [
                        f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
                        f"Reward: {raw_batch_reward_mean:.4f}",
                        f"Format Reward: {raw_batch_format_acc_mean:.4f}",
                        f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                        f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                        f"Response Length: {raw_batch_response_len_mean:.4f}",
                        f"Sample utilization: {sample_utilization:.4f}",
                        f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
                    ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                    print("\n".join(to_log_msg))
                    metrics = {
                        "metrics/reward": raw_batch_reward_mean,
                        "metrics/format_acc": raw_batch_format_acc_mean,
                        "metrics/ans_acc": raw_batch_ans_acc_mean,
                        "metrics/response_length": raw_batch_response_len_mean,
                        "train/loss": self.accum_loss.item() / self.accum_count,
                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                        "train/sample_utilization": sample_utilization,
                        "train/overlength_samples_ratio": overlength_samples_ratio,
                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                    }
                    if self.policy_loss_fn.beta > 0:
                        metrics["train/kl"] = self.accum_kl.item() / self.accum_count
                    if self.wandb_run is not None:
                        self.wandb_run.log(metrics)
                self.accum_loss.zero_()
                self.accum_kl.zero_()
                self.accum_advantages.zero_()
                self.accum_count = 0
            return loss_scalar
        else:
            return None

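    # Export the unwrapped policy weights, tagging them with the consumer's global step.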
    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
        return state_dict