Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-29 16:57:55 +00:00)

Commit 4c3656870a: address conversation
Parent: 78a06f5ce3
@@ -117,14 +117,12 @@ class BaseConsumer:
                 # receive data from producers
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                    raw_batch = unbind_batch(
-                        ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
-                    )
+                    raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                     recv_effective_count = 0
                     # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                     # we need to calculate the metrics before filtering here for logging
-                    for group in raw_batch:
-                        group_with_reward = self.calculate_group_reward(group)
+                    raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
+                    for group_with_reward in raw_batch_with_reward:
                         group_reward_mean = group_with_reward["reward"].mean().cpu().item()
                         group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
                         group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
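With this hunk the consumer receives one tensor dict per producer, scores the whole batch with calculate_reward, and only then unbinds it into per-prompt groups for logging and filtering. Below is a minimal sketch of what the unbinding step is assumed to do; the real unbind_batch lives elsewhere in ColossalAI and may differ, so this helper is illustrative only: it splits a dict of stacked tensors along dim 0 into one dict per group.

from typing import Dict, List

import torch


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
    # Hypothetical helper for illustration: split a dict of stacked tensors into
    # one dict per prompt group along dim 0.
    num_groups = next(iter(batch.values())).size(0)
    return [{key: value[i] for key, value in batch.items()} for i in range(num_groups)]


# Toy batch: 2 prompt groups, 4 generations each, with the reward columns the
# consumer reads after calculate_reward has run.
batch = {"reward": torch.rand(2, 4, 1), "format_acc": torch.rand(2, 4, 1), "ans_acc": torch.rand(2, 4, 1)}
groups = unbind_batch(batch)
assert len(groups) == 2 and groups[0]["reward"].shape == (4, 1)
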
@@ -139,7 +137,8 @@ class BaseConsumer:
                             .cpu()
                             .item()
                         )
-                        filtered_group = self.prompt_level_filtering(group_with_reward)
+                        if self.grpo_config.get("dynamic_batching", True):
+                            filtered_group = self.prompt_level_filtering(group_with_reward)
                         recv_effective_count += 1 if filtered_group is not None else 0
                         self.buffer.append(
                             [
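Prompt-level filtering is now applied only when dynamic_batching is enabled, and recv_effective_count tracks how many received groups survive it. The sketch below is a stand-in, not the project's implementation: it assumes prompt_level_filtering returns None for groups it drops (consistent with the is not None check above) and uses an all-correct/all-wrong heuristic purely as an example.

from typing import Dict, Optional

import torch


def prompt_level_filtering_sketch(group: Dict[str, torch.Tensor]) -> Optional[Dict[str, torch.Tensor]]:
    # Hypothetical filter: drop groups whose answers are all right or all wrong,
    # since a zero-variance group carries no learning signal for GRPO-style advantages.
    acc = group["ans_acc"].float().mean().item()
    return None if acc in (0.0, 1.0) else group


dynamic_batching = True  # mirrors self.grpo_config.get("dynamic_batching", True)
recv_effective_count = 0
for group in ({"ans_acc": torch.tensor([1.0, 1.0])}, {"ans_acc": torch.tensor([1.0, 0.0])}):
    if dynamic_batching:
        filtered_group = prompt_level_filtering_sketch(group)
        recv_effective_count += 1 if filtered_group is not None else 0
print(recv_effective_count)  # 1: only the mixed-accuracy group is counted as effective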
|
@@ -218,8 +218,6 @@ class GRPOConsumer(BaseConsumer):
 
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
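The removed assert only double-checked that the effective prompt count never overshoots batch_size * dp_size; the trigger for a parameter update is unchanged. A small illustration of the two branches with made-up numbers (not taken from the commit):

# Illustrative values only.
batch_size, dp_size, num_microbatches = 8, 2, 4

# Dynamic batching enabled: update once enough prompts survived filtering.
effective_prompt_count = 16
need_update = effective_prompt_count >= batch_size * dp_size
print(need_update)  # True: 16 >= 16

# Dynamic batching disabled: update on a fixed microbatch schedule.
step_idx = 7
need_update = (step_idx + 1) % num_microbatches == 0
print(need_update)  # True: the 8th microbatch completes an accumulation window
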
@@ -488,7 +486,7 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None
 
-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.
 
@@ -507,20 +505,20 @@ class GRPOConsumer(BaseConsumer):
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
 
-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
 
     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
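The renamed calculate_reward now takes the whole received batch rather than a single group, but the packing logic is unchanged: the reward model returns one (reward, format_acc, ans_acc) triple per generation, and the three values are attached to the rollout dict as (N, 1) column tensors. A minimal, self-contained sketch of that packing step with a stubbed reward-model output (the function name and the stub data are illustrative, not part of the repository):

from typing import Dict, List, Tuple

import torch


def attach_reward_columns(
    rollout: Dict[str, torch.Tensor], reward_model_output: List[Tuple[float, float, float]]
) -> Dict[str, torch.Tensor]:
    # Mirrors the packing done in calculate_reward: one (reward, format_acc, ans_acc)
    # triple per generation becomes three (N, 1) tensors stored on the rollout dict.
    device = rollout["input_ids"].device
    rollout["reward"] = torch.tensor([v[0] for v in reward_model_output], device=device).view((-1, 1))
    rollout["format_acc"] = torch.tensor([v[1] for v in reward_model_output], device=device).view((-1, 1))
    rollout["ans_acc"] = torch.tensor([v[2] for v in reward_model_output], device=device).view((-1, 1))
    return rollout


rollout = {"input_ids": torch.zeros(3, 8, dtype=torch.long)}  # 3 generations, stubbed token ids
scored = attach_reward_columns(rollout, [(1.0, 1.0, 1.0), (0.5, 1.0, 0.0), (0.0, 0.0, 0.0)])
assert scored["reward"].shape == (3, 1) and scored["ans_acc"][1].item() == 0.0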
|