Mirror of https://github.com/hpcaitech/ColossalAI.git
update pad seq (#6303)
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
commit b920af427b (parent eb6b5dd62e)
@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, pad_batch, post_recv, unbind_batch
+from .utils import bind_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -125,9 +125,6 @@ class BaseConsumer:
             batches = self.buffer[
                 self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
             ]
-            batch = pad_batch(
-                batches
-            )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
             batch = bind_batch(batches)
             batch = post_recv(batch)
             loss, num_excessive_prompts = self.step(i, pbar, **batch)
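The pad_batch call removed here existed because samples sliced from the buffer can differ in sequence length, and stacking tensors with unequal trailing sizes fails. Once the inference backend pads every rollout to a fixed width (next hunk), this per-minibatch repad is redundant. A minimal sketch of the underlying constraint, independent of how bind_batch is actually implemented (its body is not part of this diff):

import torch

# two samples whose sequences differ in length cannot be concatenated directly
a = torch.ones(1, 5, dtype=torch.long)
b = torch.ones(1, 7, dtype=torch.long)
try:
    torch.cat([a, b], dim=0)
except RuntimeError as err:
    print("cat failed:", err)  # sizes must match except in the cat dimension

# right-padding to a common width makes the samples stackable
width = max(a.size(-1), b.size(-1))
a_pad = torch.nn.functional.pad(a, (0, width - a.size(-1)), "constant", 0)
b_pad = torch.nn.functional.pad(b, (0, width - b.size(-1)), "constant", 0)
print(torch.cat([a_pad, b_pad], dim=0).shape)  # torch.Size([2, 7])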
@@ -236,7 +236,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
                 log_probs.append(p)
 
             # pad them
-            max_len = max(out_len)
+            max_len = self.generate_config.max_tokens
             action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
 
             for i, new_token_ids in enumerate(out_tokens):
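Pinning max_len to generate_config.max_tokens instead of the longest completion in the current batch gives every rollout tensor the same width regardless of which prompts happen to be generated together, which is what allows the consumer to drop its own padding step. A self-contained sketch of that fixed-width right-padding; the names mirror the hunk, but the loop body and the bool mask dtype are illustrative rather than the file's exact code:

import torch
import torch.nn.functional as F

max_tokens = 8  # stand-in for self.generate_config.max_tokens
out_tokens = [[11, 12, 13], [21, 22, 23, 24, 25]]  # variable-length completions

max_len = max_tokens  # fixed width, not max(out_len)
action_mask = torch.ones(len(out_tokens), max_len, dtype=torch.bool)

rows = []
for i, new_token_ids in enumerate(out_tokens):
    ids = torch.tensor(new_token_ids, dtype=torch.long)
    rows.append(F.pad(ids, (0, max_len - ids.size(-1)), "constant", 0))  # right-pad ids with 0
    action_mask[i, ids.size(-1) :] = False  # padded positions contribute no actions

out = torch.stack(rows)
print(out.shape, action_mask.shape)  # torch.Size([2, 8]) torch.Size([2, 8])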
@@ -79,7 +79,7 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
+        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
 
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
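Switching to .get with a default of 1 lets the producer accept consumer plugin configs that simply omit pp_size, treating them as "no pipeline parallelism" rather than raising a KeyError. In isolation (the config contents below are invented for illustration):

consumer_plugin_config = {"tp_size": 2}  # hypothetical config without an explicit pp_size

# old behaviour: consumer_plugin_config["pp_size"] raises KeyError
# new behaviour: falls back to a pipeline-parallel size of 1
consumer_pp_size = consumer_plugin_config.get("pp_size", 1)
print(consumer_pp_size)  # 1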
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from typing import Any, Dict, List
 
 import torch
@@ -27,27 +26,6 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
     return batch
 
 
-def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
-    max_len = defaultdict(int)
-    for sample in batches:
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                max_len[k] = max(max_len[k], sample[k].size(-1))
-    for idx, sample in enumerate(batches):
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                # right pad with 0s
-                if k in ["attention_mask", "action_mask"]:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
-                    )
-                else:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
-                    )
-    return batches
-
-
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
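The deleted helper right-padded each listed key to the longest length found in the batch, using torch.nn.functional.pad with a (left, right) pad of (0, n) on the last dimension, filling masks with False and everything else with 0. For reference, that call pattern in isolation (the tensor shapes here are invented for the example):

import torch
import torch.nn.functional as F

input_ids = torch.tensor([[5, 6, 7]])                # shape (1, 3)
attention_mask = torch.ones(1, 3, dtype=torch.bool)
target_len = 6

# (0, n) pads only the right side of the last dimension
ids_padded = F.pad(input_ids, (0, target_len - input_ids.size(-1)), "constant", 0)
mask_padded = F.pad(attention_mask, (0, target_len - attention_mask.size(-1)), "constant", False)

print(ids_padded)   # tensor([[5, 6, 7, 0, 0, 0]])
print(mask_padded)  # first three positions True, the padded tail False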