update pad seq (#6303)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li
2025-05-13 16:51:27 +08:00
committed by GitHub
parent eb6b5dd62e
commit b920af427b
4 changed files with 3 additions and 28 deletions

View File

@@ -1,4 +1,3 @@
from collections import defaultdict
from typing import Any, Dict, List
import torch
@@ -27,27 +26,6 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
return batch
def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
max_len = defaultdict(int)
for sample in batches:
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
max_len[k] = max(max_len[k], sample[k].size(-1))
for idx, sample in enumerate(batches):
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
# right pad with 0s
if k in ["attention_mask", "action_mask"]:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
)
else:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
)
return batches
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress mask to save bandwidth
if "attention_mask" in batch: