mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
update pad seq (#6303)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from typing import Any, Dict, List
 
 import torch
@@ -27,27 +26,6 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
     return batch
 
 
-def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
-    max_len = defaultdict(int)
-    for sample in batches:
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                max_len[k] = max(max_len[k], sample[k].size(-1))
-    for idx, sample in enumerate(batches):
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                # right pad with 0s
-                if k in ["attention_mask", "action_mask"]:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
-                    )
-                else:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
-                    )
-    return batches
-
-
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
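The removed pad_batch right-padded every tensor under the keys input_ids, attention_mask, action_log_probs, and action_mask to the longest length in the batch, padding values with 0 and masks with False. A minimal sketch of that behavior; the sample tensors below are illustrative assumptions, not repo data:

import torch
import torch.nn.functional as F

# Two illustrative samples with different sequence lengths (assumed data).
a = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.ones(1, 3, dtype=torch.bool)}
b = {"input_ids": torch.tensor([[4, 5]]), "attention_mask": torch.ones(1, 2, dtype=torch.bool)}

max_len = max(s["input_ids"].size(-1) for s in (a, b))  # longest sequence in the batch
for s in (a, b):
    extra = max_len - s["input_ids"].size(-1)
    # right-pad: 0 for token ids, False for the boolean mask
    s["input_ids"] = F.pad(s["input_ids"], (0, extra), "constant", 0)
    s["attention_mask"] = F.pad(s["attention_mask"], (0, extra), "constant", False)

print(b["input_ids"])       # tensor([[4, 5, 0]])
print(b["attention_mask"])  # tensor([[ True,  True, False]])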
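The second hunk's trailing context stops inside pre_send, so its body is not shown above. Purely as a hypothetical illustration of the "# compress mask to save bandwidth" idea, and not the repo's actual implementation, a boolean mask can be bit-packed before transfer and unpacked on the receiving side:

import numpy as np
import torch

def compress_mask(mask: torch.Tensor) -> torch.Tensor:
    # Pack a boolean mask into uint8 bits (roughly 8x smaller payload).
    # Hypothetical helper, not from the repo.
    packed = np.packbits(mask.cpu().numpy(), axis=-1)
    return torch.from_numpy(packed)

def decompress_mask(packed: torch.Tensor, orig_len: int) -> torch.Tensor:
    # Unpack back to bool and trim the bit-padding to the original length.
    bits = np.unpackbits(packed.numpy(), axis=-1)[..., :orig_len]
    return torch.from_numpy(bits).bool()

mask = torch.tensor([[True, True, True, False, False]])
packed = compress_mask(mask)  # shape (1, 1), dtype uint8
assert torch.equal(decompress_mask(packed, mask.size(-1)), mask)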