Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-25 23:11:55 +00:00
move prompt-level-filtering to buffer side
parent 957e3a521a
commit 55eee129d2
@@ -113,18 +113,24 @@ class BaseConsumer:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
+                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = unbind_batch(
+                                ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             )
+                            filtered_batch = [
+                                t
+                                for t in [
+                                    self.prompt_level_filtering(self.calculate_group_reward(group))
+                                    for group in raw_batch
+                                ]
+                                if t is not None
+                            ]
+
+                            self.buffer.extend(filtered_batch)
                         while len(self.buffer) >= self.dp_size * self.minibatch_size:
                             batches = self.buffer[
                                 self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
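
With this change the consumer no longer pushes raw producer batches straight into the buffer: each received batch is unbound into per-prompt groups, each group is scored, and only groups that survive prompt-level filtering are buffered. A minimal standalone sketch of that flow follows; the helpers score_group and keep_group and the toy records are illustrative stand-ins for calculate_group_reward and prompt_level_filtering, not the repository's API.

from typing import Any, Dict, List, Optional

def score_group(group: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for calculate_group_reward: attach the group's mean answer accuracy.
    group["ans_acc"] = sum(group["correct"]) / len(group["correct"])
    return group

def keep_group(group: Dict[str, Any], lo: float = 0.1, hi: float = 0.9) -> Optional[Dict[str, Any]]:
    # Stand-in for prompt_level_filtering: drop prompts that are too easy or too hard.
    return group if lo < group["ans_acc"] < hi else None

raw_batch: List[Dict[str, Any]] = [
    {"prompt": "p0", "correct": [1, 1, 1, 1]},  # all generations correct: filtered out
    {"prompt": "p1", "correct": [1, 0, 1, 0]},  # mixed accuracy: kept
]
buffer: List[Dict[str, Any]] = []
filtered = [g for g in (keep_group(score_group(g)) for g in raw_batch) if g is not None]
buffer.extend(filtered)
print([g["prompt"] for g in buffer])  # ['p1']
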
@@ -177,7 +183,7 @@ class BaseConsumer:
                                 )
                         del state_dict
                         torch.cuda.empty_cache()
-                    allow_sync_model = False
+                    allow_sync_model = True


 @ray.remote

@@ -1,5 +1,5 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
@@ -179,7 +179,6 @@ class GRPOConsumer(BaseConsumer):
         Format:
             [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
-
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
         data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
         action_mask = data["action_mask"]
@@ -188,15 +187,9 @@ class GRPOConsumer(BaseConsumer):
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
         train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-        reward_group = self.reward_model(
-            data["input_ids"],
-            gt_answer=data["gt_answer"],
-            response_idx=data["response_idx"],
-        )
-
-        reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+        reward = data["reward"].view((-1))
+        format_acc = data["format_acc"].view((-1))
+        ans_acc = data["ans_acc"].view((-1))

         # [minibatch_size, num_generations]

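
Because calculate_group_reward (added below) stores reward, format_acc, and ans_acc as [num_of_generation, 1] tensors on each group before buffering, the collated minibatch carries them with one column per sample, and view((-1)) recovers a flat per-sample vector. A small shape-only sketch with toy sizes, not the repository's code:

import torch

minibatch_size, num_generations = 2, 4
# as stored per group and then collated: one value per sample, kept as a column
reward_col = torch.rand(minibatch_size * num_generations, 1)
reward = reward_col.view((-1))  # shape: [minibatch_size * num_generations]
print(reward_col.shape, reward.shape)  # torch.Size([8, 1]) torch.Size([8])
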
@@ -213,11 +206,7 @@ class GRPOConsumer(BaseConsumer):
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
         # [minibatch_size x num_of_generation]
-        loss_mask = (
-            torch.ones(action_mask.size(0), device=action_mask.device).bool()
-            if self.filter_range is None
-            else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
-        )
+        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
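
The accuracy-range condition disappears from loss_mask because prompts with out-of-range group accuracy are now dropped by prompt_level_filtering before they ever reach the trainer; only the overlength filter below still applies. A small sketch contrasting the old and new masks on toy tensors (illustrative values only):

import torch

action_mask = torch.ones(8, 16, dtype=torch.bool)
group_ans_acc = torch.tensor([0.0, 0.0, 1.0, 1.0, 0.5, 0.5, 0.25, 0.25])
filter_range = (0.01, 0.99)

# Old behaviour: mask out samples whose group accuracy falls outside filter_range.
old_loss_mask = torch.logical_and(group_ans_acc > filter_range[0], group_ans_acc < filter_range[1])

# New behaviour: keep everything; out-of-range prompts were filtered before buffering.
new_loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
print(old_loss_mask.tolist())
print(new_loss_mask.tolist())
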
@@ -525,6 +514,68 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None, excessive_prompts_idx

+    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                contains the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with the calculated reward.
+        """
+        reward_group = self.reward_model(
+            rollout_group["input_ids"],
+            gt_answer=rollout_group["gt_answer"],
+            response_idx=rollout_group["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+
+        rollout_group["reward"] = reward.view((-1, 1))
+        rollout_group["format_acc"] = format_acc.view((-1, 1))
+        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout_group
+
+    def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        rollout_group: Dict[str, Any]
+            a group of samples generated by the model from the same prompt
+            contains the following keys:
+            "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+            "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+            "action_mask": torch.Tensor, [num_of_generation, response_length]
+            "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+            "response_idx": int, torch.Tensor, [num_of_generation, 2]
+            "gt_answer": torch.Tensor, [num_of_generation, 128]
+            "temperature": torch.Tensor, [] (scalar)
+            "reward": torch.Tensor, [num_of_generation]
+            "format_acc": torch.Tensor, [num_of_generation]
+            "ans_acc": torch.Tensor, [num_of_generation]
+        """
+        if self.filter_range is not None:
+            # filter out prompts whose accuracy is too high or too low (out of range)
+            group_ans_acc = torch.mean(rollout_group["ans_acc"])
+            if group_ans_acc < self.filter_range[0] or group_ans_acc > self.filter_range[1]:
+                # filter out the prompt
+                return None
+            else:
+                return rollout_group
+        else:
+            # no filter
+            return rollout_group
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
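
To make the new filtering rule concrete, here is a standalone sketch that reproduces the logic of prompt_level_filtering on a toy ans_acc tensor; the _sketch suffix marks it as illustrative rather than the class method itself.

from typing import Any, Dict, Optional, Tuple

import torch

def prompt_level_filtering_sketch(
    rollout_group: Dict[str, Any], filter_range: Optional[Tuple[float, float]]
) -> Optional[Dict[str, Any]]:
    # Keep the prompt only if the mean answer accuracy of its group lies inside filter_range.
    if filter_range is None:
        return rollout_group
    group_ans_acc = torch.mean(rollout_group["ans_acc"])
    if group_ans_acc < filter_range[0] or group_ans_acc > filter_range[1]:
        return None
    return rollout_group

too_easy = {"ans_acc": torch.tensor([[1.0], [1.0], [1.0], [1.0]])}
mixed = {"ans_acc": torch.tensor([[1.0], [0.0], [1.0], [0.0]])}
print(prompt_level_filtering_sketch(too_easy, (0.01, 0.99)))           # None: prompt dropped
print(prompt_level_filtering_sketch(mixed, (0.01, 0.99)) is not None)  # True: prompt kept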