update pad seq (#6303)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li 2025-05-13 16:51:27 +08:00 committed by GitHub
parent eb6b5dd62e
commit b920af427b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 3 additions and 28 deletions

View File

@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, pad_batch, post_recv, unbind_batch
from .utils import bind_batch, post_recv, unbind_batch
class BaseConsumer:
@ -125,9 +125,6 @@ class BaseConsumer:
batches = self.buffer[
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
]
batch = pad_batch(
batches
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
batch = bind_batch(batches)
batch = post_recv(batch)
loss, num_excessive_prompts = self.step(i, pbar, **batch)

View File

@ -236,7 +236,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
log_probs.append(p)
# pad them
max_len = max(out_len)
max_len = self.generate_config.max_tokens
action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
for i, new_token_ids in enumerate(out_tokens):

View File

@ -79,7 +79,7 @@ class BaseProducer:
else:
raise ValueError(f"Unexpected backend {backend}")
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
def setup(self) -> None:
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")

View File

@ -1,4 +1,3 @@
from collections import defaultdict
from typing import Any, Dict, List
import torch
@ -27,27 +26,6 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
return batch
def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
max_len = defaultdict(int)
for sample in batches:
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
max_len[k] = max(max_len[k], sample[k].size(-1))
for idx, sample in enumerate(batches):
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
# right pad with 0s
if k in ["attention_mask", "action_mask"]:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
)
else:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
)
return batches
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress mask to save bandwidth
if "attention_mask" in batch: