mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-29 08:47:17 +00:00
fix metric calculation
This commit is contained in:
parent
116621d004
commit
37663386bc
@ -120,24 +120,85 @@ class BaseConsumer:
|
|||||||
raw_batch = unbind_batch(
|
raw_batch = unbind_batch(
|
||||||
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
|
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
|
||||||
)
|
)
|
||||||
processed_batch = [
|
recv_effective_count = 0
|
||||||
self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
|
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
|
||||||
]
|
# we need to calculate the metrics before filtering here for logging
|
||||||
filtered_batch = [t for t in processed_batch if t is not None]
|
for group in raw_batch:
|
||||||
|
group_with_reward = self.calculate_group_reward(group)
|
||||||
|
group_reward_mean = group_with_reward["reward"].mean().cpu().item()
|
||||||
|
group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
|
||||||
|
group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
|
||||||
|
group_response_len = (
|
||||||
|
(
|
||||||
|
group_with_reward["response_idx"][:, 1]
|
||||||
|
- group_with_reward["response_idx"][:, 0]
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
.type(torch.float32)
|
||||||
|
.mean()
|
||||||
|
.cpu()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
filtered_group = self.prompt_level_filtering(group_with_reward)
|
||||||
|
recv_effective_count += 1 if filtered_group is not None else 0
|
||||||
|
self.buffer.append(
|
||||||
|
[
|
||||||
|
filtered_group,
|
||||||
|
group_reward_mean,
|
||||||
|
group_format_acc_mean,
|
||||||
|
group_ans_acc_mean,
|
||||||
|
group_response_len,
|
||||||
|
]
|
||||||
|
)
|
||||||
if self.filter_range is not None:
|
if self.filter_range is not None:
|
||||||
print(
|
print(
|
||||||
f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
|
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}"
|
||||||
)
|
)
|
||||||
|
# mapping the effective group to the raw group for indexing
|
||||||
|
effective_group_to_raw_group_mapping = {}
|
||||||
|
for buffer_idx in range(len(self.buffer)):
|
||||||
|
if self.buffer[buffer_idx][0] is not None:
|
||||||
|
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||||
|
buffer_idx
|
||||||
|
)
|
||||||
|
|
||||||
self.buffer.extend(filtered_batch)
|
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||||
while len(self.buffer) >= self.dp_size * self.minibatch_size:
|
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
||||||
batches = self.buffer[
|
batches = [
|
||||||
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
|
self.buffer[effective_group_to_raw_group_mapping[i]]
|
||||||
|
for i in range(
|
||||||
|
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
|
||||||
|
)
|
||||||
]
|
]
|
||||||
batch = bind_batch(batches)
|
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
||||||
|
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
||||||
|
raw_mini_batches = self.buffer[
|
||||||
|
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
||||||
|
] # include the last effective sample
|
||||||
|
raw_mini_batches_metric_dict = {
|
||||||
|
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
||||||
|
}
|
||||||
|
batch = bind_batch([t[0] for t in batches])
|
||||||
batch = post_recv(batch)
|
batch = post_recv(batch)
|
||||||
loss = self.step(i, pbar, **batch)
|
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||||
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
|
self.buffer = self.buffer[
|
||||||
|
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||||
|
]
|
||||||
|
# recalculate the effective group to raw group mapping
|
||||||
|
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||||
|
effective_group_to_raw_group_mapping = {}
|
||||||
|
for buffer_idx in range(len(self.buffer)):
|
||||||
|
if self.buffer[buffer_idx][0] is not None:
|
||||||
|
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||||
|
buffer_idx
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(effective_group_to_raw_group_mapping)
|
||||||
|
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||||
|
)
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
pbar.set_postfix({"loss": loss})
|
pbar.set_postfix({"loss": loss})
|
||||||
i += 1
|
i += 1
|
||||||
|
@ -72,12 +72,12 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.policy_model.gradient_checkpointing_enable()
|
self.policy_model.gradient_checkpointing_enable()
|
||||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||||
self.accum_loss = torch.zeros(1, device=self.device)
|
self.accum_loss = torch.zeros(1, device=self.device)
|
||||||
self.accum_reward = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_kl = torch.zeros(1, device=self.device)
|
self.accum_kl = torch.zeros(1, device=self.device)
|
||||||
self.accum_format_acc = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_ans_acc = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
self.raw_train_batch_reward = []
|
||||||
|
self.raw_train_batch_format_acc = []
|
||||||
|
self.raw_train_batch_ans_acc = []
|
||||||
|
self.raw_train_batch_response_len = []
|
||||||
self.accum_count = 0
|
self.accum_count = 0
|
||||||
self.generate_config = generate_config
|
self.generate_config = generate_config
|
||||||
self.grpo_config = grpo_config
|
self.grpo_config = grpo_config
|
||||||
@ -186,7 +186,11 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||||
"""
|
"""
|
||||||
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
||||||
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
|
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
|
||||||
|
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
|
||||||
|
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
|
||||||
|
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
|
||||||
|
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
|
||||||
action_mask = data["action_mask"]
|
action_mask = data["action_mask"]
|
||||||
num_action = action_mask.shape[1]
|
num_action = action_mask.shape[1]
|
||||||
old_action_log_probs = data["action_log_probs"]
|
old_action_log_probs = data["action_log_probs"]
|
||||||
@ -430,11 +434,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||||
if self.policy_loss_fn.beta > 0:
|
if self.policy_loss_fn.beta > 0:
|
||||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||||
self.accum_reward.add_(reward.data)
|
|
||||||
self.accum_format_acc.add_(format_acc.data)
|
|
||||||
self.accum_ans_acc.add_(ans_acc.data)
|
|
||||||
self.accum_advantages.add_(advantages.data)
|
self.accum_advantages.add_(advantages.data)
|
||||||
self.accum_response_length.add_(response_length.data)
|
|
||||||
self.accum_count += 1
|
self.accum_count += 1
|
||||||
if need_update:
|
if need_update:
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
@ -452,21 +452,33 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
):
|
):
|
||||||
|
raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
|
||||||
|
raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
|
||||||
|
self.raw_train_batch_format_acc
|
||||||
|
)
|
||||||
|
raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
|
||||||
|
raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
|
||||||
|
self.raw_train_batch_response_len
|
||||||
|
)
|
||||||
|
self.raw_train_batch_reward = []
|
||||||
|
self.raw_train_batch_format_acc = []
|
||||||
|
self.raw_train_batch_ans_acc = []
|
||||||
|
self.raw_train_batch_response_len = []
|
||||||
to_log_msg = [
|
to_log_msg = [
|
||||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||||
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
|
f"Reward: {raw_batch_reward_mean:.4f}",
|
||||||
f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
|
f"format Reward: {raw_batch_format_acc_mean:.4f}",
|
||||||
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
|
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
|
||||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||||
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
||||||
f"Sample_utilization: {sample_utilization:.4f}",
|
f"Sample_utilization: {sample_utilization:.4f}",
|
||||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||||
print("\n".join(to_log_msg))
|
print("\n".join(to_log_msg))
|
||||||
metrics = {
|
metrics = {
|
||||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
"metrics/reward": raw_batch_reward_mean,
|
||||||
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
|
"metrics/format_acc": raw_batch_format_acc_mean,
|
||||||
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
|
"metrics/ans_acc": raw_batch_ans_acc_mean,
|
||||||
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
"metrics/response_length": raw_batch_response_len_mean,
|
||||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||||
@ -478,12 +490,8 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
if self.wandb_run is not None:
|
if self.wandb_run is not None:
|
||||||
self.wandb_run.log(metrics)
|
self.wandb_run.log(metrics)
|
||||||
self.accum_loss.zero_()
|
self.accum_loss.zero_()
|
||||||
self.accum_reward.zero_()
|
|
||||||
self.accum_ans_acc.zero_()
|
|
||||||
self.accum_format_acc.zero_()
|
|
||||||
self.accum_kl.zero_()
|
self.accum_kl.zero_()
|
||||||
self.accum_advantages.zero_()
|
self.accum_advantages.zero_()
|
||||||
self.accum_response_length.zero_()
|
|
||||||
self.accum_count = 0
|
self.accum_count = 0
|
||||||
return loss_scalar
|
return loss_scalar
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ def launch_distributed(
|
|||||||
eval_save_dir: Optional[str] = None,
|
eval_save_dir: Optional[str] = None,
|
||||||
eval_generation_config: Optional[Dict[str, Any]] = None,
|
eval_generation_config: Optional[Dict[str, Any]] = None,
|
||||||
log_rollout_interval: int = 20,
|
log_rollout_interval: int = 20,
|
||||||
rollout_log_file: str = "./rollout_log.jsonl",
|
rollout_save_dir: str = "./rollout",
|
||||||
):
|
):
|
||||||
if core_algo not in ALGO_MAP:
|
if core_algo not in ALGO_MAP:
|
||||||
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
||||||
@ -74,6 +75,10 @@ def launch_distributed(
|
|||||||
|
|
||||||
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
||||||
wandb_group_name = str(uuid.uuid4())
|
wandb_group_name = str(uuid.uuid4())
|
||||||
|
rollout_log_file = os.path.join(
|
||||||
|
rollout_save_dir,
|
||||||
|
f"{project_name}_run_{wandb_group_name}.jsonl",
|
||||||
|
)
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
for i in range(num_producers):
|
for i in range(num_producers):
|
||||||
|
@ -121,6 +121,34 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
|
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-tp",
|
||||||
|
"--tensor-parallel-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-pp",
|
||||||
|
"--pipeline-parallel-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-zero",
|
||||||
|
"--zero-stage",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-ptp",
|
||||||
|
"--produce-tensor-parallel-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.train_minibatch_size is None:
|
if args.train_minibatch_size is None:
|
||||||
@ -178,7 +206,7 @@ if __name__ == "__main__":
|
|||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=args.produce_tensor_parallel_size,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
generate_config.update(
|
generate_config.update(
|
||||||
@ -228,7 +256,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
launch_distributed(
|
launch_distributed(
|
||||||
num_producers=args.num_inferencer,
|
num_producers=args.num_inferencer,
|
||||||
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
|
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
|
||||||
num_consumer_procs=args.num_trainers,
|
num_consumer_procs=args.num_trainers,
|
||||||
num_episodes=args.num_episodes,
|
num_episodes=args.num_episodes,
|
||||||
inference_batch_size=args.inference_batch_size,
|
inference_batch_size=args.inference_batch_size,
|
||||||
@ -247,17 +275,14 @@ if __name__ == "__main__":
|
|||||||
train_model_config=train_model_config,
|
train_model_config=train_model_config,
|
||||||
grpo_config=grpo_config,
|
grpo_config=grpo_config,
|
||||||
plugin_config={
|
plugin_config={
|
||||||
"zero_stage": 2,
|
"tp_size": args.tensor_parallel_size,
|
||||||
}, # for zero
|
"pp_size": args.pipeline_parallel_size,
|
||||||
# plugin_config={
|
"microbatch_size": max(
|
||||||
# "tp_size": 2,
|
1, args.train_microbatch_size // args.pipeline_parallel_size
|
||||||
# "pp_size": 2,
|
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||||
# "microbatch_size": max(
|
"zero_stage": args.zero_stage,
|
||||||
# 1, args.train_microbatch_size // 2
|
"max_norm": 1.0,
|
||||||
# ), # microbatch size should be set to train_microbatch_size // pp_size
|
}, # for pp, tp
|
||||||
# "zero_stage": 0,
|
|
||||||
# "max_norm": 1.0,
|
|
||||||
# }, # for pp, tp
|
|
||||||
inference_backend=args.backend,
|
inference_backend=args.backend,
|
||||||
master_addr="localhost",
|
master_addr="localhost",
|
||||||
master_port=args.master_port,
|
master_port=args.master_port,
|
||||||
@ -273,5 +298,5 @@ if __name__ == "__main__":
|
|||||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||||
eval_generation_config=eval_generation_config,
|
eval_generation_config=eval_generation_config,
|
||||||
log_rollout_interval=20,
|
log_rollout_interval=20,
|
||||||
rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
|
rollout_save_dir=args.rollout_save_dir,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user