Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
fix metric calculation
@@ -120,24 +120,85 @@ class BaseConsumer:
                         raw_batch = unbind_batch(
                             ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                         )
-                        processed_batch = [
-                            self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
-                        ]
-                        filtered_batch = [t for t in processed_batch if t is not None]
+                        recv_effective_count = 0
+                        # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
+                        # we need to calculate the metrics before filtering here for logging
+                        for group in raw_batch:
+                            group_with_reward = self.calculate_group_reward(group)
+                            group_reward_mean = group_with_reward["reward"].mean().cpu().item()
+                            group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
+                            group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
+                            group_response_len = (
+                                (
+                                    group_with_reward["response_idx"][:, 1]
+                                    - group_with_reward["response_idx"][:, 0]
+                                    + 1
+                                )
+                                .type(torch.float32)
+                                .mean()
+                                .cpu()
+                                .item()
+                            )
+                            filtered_group = self.prompt_level_filtering(group_with_reward)
+                            recv_effective_count += 1 if filtered_group is not None else 0
+                            self.buffer.append(
+                                [
+                                    filtered_group,
+                                    group_reward_mean,
+                                    group_format_acc_mean,
+                                    group_ans_acc_mean,
+                                    group_response_len,
+                                ]
+                            )
                         if self.filter_range is not None:
                             print(
-                                f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                                f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}"
                             )
-                        self.buffer.extend(filtered_batch)
+                        # mapping the effective group to the raw group for indexing
+                        effective_group_to_raw_group_mapping = {}
+                        for buffer_idx in range(len(self.buffer)):
+                            if self.buffer[buffer_idx][0] is not None:
+                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                    buffer_idx
+                                )

-                    while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                        batches = self.buffer[
-                            self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
-                        ]
-                        batch = bind_batch(batches)
+                    while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                        # on each dp_rank, we use minibatch_size effective samples to form a batch
+                        batches = [
+                            self.buffer[effective_group_to_raw_group_mapping[i]]
+                            for i in range(
+                                self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                            )
+                        ]
+                        # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                        # each mini-batch use the first self.dp_size * minibatch_size effective samples
+                        raw_mini_batches = self.buffer[
+                            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                        ]  # include the last effective sample
+                        raw_mini_batches_metric_dict = {
+                            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                        }
+                        batch = bind_batch([t[0] for t in batches])
                         batch = post_recv(batch)
-                        loss = self.step(i, pbar, **batch)
-                        self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                        loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                        self.buffer = self.buffer[
+                            effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                        ]
+                        # recalculate the effective group to raw group mapping
+                        effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                        effective_group_to_raw_group_mapping = {}
+                        for buffer_idx in range(len(self.buffer)):
+                            if self.buffer[buffer_idx][0] is not None:
+                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                    buffer_idx
+                                )
+                        assert (
+                            len(effective_group_to_raw_group_mapping)
+                            == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                        )
                         if loss is not None:
                             pbar.set_postfix({"loss": loss})
                         i += 1
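
For reference, a standalone sketch (not from the repository, toy tensors only) of the per-group metrics the new code computes before prompt-level filtering. It assumes, as the "+ 1" above implies, that "response_idx" holds inclusive [start, end] token positions for each sample in the group.

import torch

# Toy group dict; in the real consumer these tensors come from calculate_group_reward().
group_with_reward = {
    "reward": torch.tensor([1.0, 0.0, 1.0, 1.0]),
    "response_idx": torch.tensor([[5, 20], [5, 12], [5, 33], [5, 9]]),  # inclusive [start, end]
}

group_reward_mean = group_with_reward["reward"].mean().cpu().item()
group_response_len = (
    (group_with_reward["response_idx"][:, 1] - group_with_reward["response_idx"][:, 0] + 1)
    .type(torch.float32)
    .mean()
    .cpu()
    .item()
)
print(group_reward_mean, group_response_len)  # 0.75 14.5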
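
And a minimal sketch (toy data, hypothetical group names, a single data-parallel rank) of the buffer bookkeeping the new while loop relies on: each buffer entry is [filtered_group_or_None, metric...], the mapping assigns dense indices to the effective (non-None) entries, minibatches are gathered through that mapping, raw metrics are taken over every raw entry up to the last one consumed, and the buffer is then cut just past that entry before the mapping is rebuilt.

dp_size, dp_rank, minibatch_size = 1, 0, 2

# None marks a group dropped by prompt-level filtering; the float stands in for the metrics.
buffer = [["g0", 0.5], [None, 0.1], ["g2", 0.9], [None, 0.2], ["g4", 0.7]]

def build_mapping(buffer):
    # dense effective index -> raw buffer index of a non-None group
    mapping = {}
    for buffer_idx in range(len(buffer)):
        if buffer[buffer_idx][0] is not None:
            mapping[len(mapping)] = buffer_idx
    return mapping

mapping = build_mapping(buffer)  # {0: 0, 1: 2, 2: 4}
while len(mapping) >= dp_size * minibatch_size:
    # this rank's minibatch: minibatch_size effective entries looked up through the mapping
    batch = [buffer[mapping[i]][0] for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)]
    # metrics are logged over every raw entry up to the last effective one consumed, filtered or not
    last_raw = mapping[dp_size * minibatch_size - 1]
    raw_metrics = [entry[1] for entry in buffer[: last_raw + 1]]
    print(batch, raw_metrics)  # ['g0', 'g2'] [0.5, 0.1, 0.9]
    buffer = buffer[last_raw + 1 :]  # drop everything already accounted for
    mapping = build_mapping(buffer)  # rebuilt over the shrunken buffer; here {0: 1}, so the loop exits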