From b920af427bb47803482ab8a41c4517b2b07b7f59 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 13 May 2025 16:51:27 +0800
Subject: [PATCH] update pad seq (#6303)

Co-authored-by: Tong Li
---
 .../coati/distributed/consumer.py             |  5 +----
 .../coati/distributed/inference_backend.py    |  2 +-
 .../coati/distributed/producer.py             |  2 +-
 .../ColossalChat/coati/distributed/utils.py   | 22 -------------------
 4 files changed, 3 insertions(+), 28 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 1cebcb40e..9503d65eb 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, pad_batch, post_recv, unbind_batch
+from .utils import bind_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -125,9 +125,6 @@ class BaseConsumer:
                         batches = self.buffer[
                             self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
                         ]
-                        batch = pad_batch(
-                            batches
-                        )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                         batch = bind_batch(batches)
                         batch = post_recv(batch)
                         loss, num_excessive_prompts = self.step(i, pbar, **batch)
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 7d32cd52a..8ff4b15bf 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -236,7 +236,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
                 log_probs.append(p)
 
         # pad them
-        max_len = max(out_len)
+        max_len = self.generate_config.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
 
         for i, new_token_ids in enumerate(out_tokens):
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index a2d675870..ffe3a4428 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -79,7 +79,7 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
+        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
 
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index ce4923dc4..5f7879669 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from typing import Any, Dict, List
 
 import torch
@@ -27,27 +26,6 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
     return batch
 
 
-def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
-    max_len = defaultdict(int)
-    for sample in batches:
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                max_len[k] = max(max_len[k], sample[k].size(-1))
-    for idx, sample in enumerate(batches):
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                # right pad with 0s
-                if k in ["attention_mask", "action_mask"]:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
-                    )
-                else:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
-                    )
-    return batches
-
-
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
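
Illustration (not part of the patch above): with this change the vLLM backend pads every generation to `generate_config.max_tokens` instead of the longest sequence in the rollout, so the mini-batches the consumer stacks already share one sequence length and the removed `pad_batch` helper is no longer needed. The sketch below is a minimal, hypothetical rendering of that fixed-length padding idea; the names `pad_generations` and `max_new_tokens` are illustrative and do not appear in the repository.

# Minimal sketch (assumptions noted above, not the repository's implementation):
# right-pad each generation to a fixed width and record the valid positions in
# an action mask, so every sample can be stacked without a second padding pass.
from typing import Dict, List

import torch


def pad_generations(
    out_tokens: List[List[int]], max_new_tokens: int, pad_token_id: int = 0
) -> Dict[str, torch.Tensor]:
    input_ids = torch.full((len(out_tokens), max_new_tokens), pad_token_id, dtype=torch.long)
    action_mask = torch.zeros(len(out_tokens), max_new_tokens, dtype=torch.bool)
    for i, ids in enumerate(out_tokens):
        ids = ids[:max_new_tokens]  # defensive truncation; the sampler already caps generation length
        input_ids[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        action_mask[i, : len(ids)] = True  # True marks real tokens, False marks padding
    return {"input_ids": input_ids, "action_mask": action_mask}


# Example: two generations of different lengths come back with identical shapes.
batch = pad_generations([[5, 6, 7], [8, 9]], max_new_tokens=4)
assert batch["input_ids"].shape == (2, 4)
assert batch["action_mask"].sum(dim=1).tolist() == [3, 2]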