[chat] add distributed impl (#6210)
applications/ColossalChat/coati/distributed/utils.py (new file, +40 lines)
@@ -0,0 +1,40 @@
from typing import Dict, List

import torch


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
    """Split a batched dict of tensors into a list of per-sample dicts."""
    batches = []
    for k, v in batch.items():
        if len(batches) == 0:
            unbinded_tensors = v.unbind(0)
            batches = [{k: tensor} for tensor in unbinded_tensors]
        else:
            unbinded_tensors = v.unbind(0)
            assert len(batches) == len(unbinded_tensors)
            for i, tensor in enumerate(unbinded_tensors):
                batches[i][k] = tensor
    return batches


def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack a list of per-sample dicts back into a single batched dict."""
    batch = {}
    for k in batches[0].keys():
        batch[k] = torch.stack([b[k] for b in batches], dim=0)
    return batch


def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # compress attention_mask to bool to save bandwidth before sending
    if "attention_mask" in batch:
        attention_mask = batch["attention_mask"]
        batch["attention_mask"] = attention_mask.to(torch.bool)
    return batch


def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # decompress attention_mask back to int after receiving
    if "attention_mask" in batch:
        attention_mask = batch["attention_mask"]
        batch["attention_mask"] = attention_mask.to(torch.int)
    return batch
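A minimal usage sketch (not part of the commit) showing how these helpers compose: a rollout batch is unbound into per-sample dicts, re-bound into a batch, and the attention mask is compressed to bool before transfer and restored afterwards. The key names, tensor shapes, and values below are illustrative assumptions; only the import path follows from the file's location in the repo.

# Hypothetical round-trip check for the helpers above; keys and shapes are made up.
import torch

from coati.distributed.utils import bind_batch, post_recv, pre_send, unbind_batch

batch = {
    "input_ids": torch.randint(0, 100, (4, 16)),
    "attention_mask": torch.ones(4, 16, dtype=torch.int),
}

samples = unbind_batch(batch)      # list of 4 dicts, each holding (16,) tensors
rebuilt = bind_batch(samples)      # stacked back into (4, 16) tensors
assert torch.equal(rebuilt["input_ids"], batch["input_ids"])

compressed = pre_send(rebuilt)     # attention_mask cast to torch.bool (1 byte/element)
restored = post_recv(compressed)   # attention_mask cast back to torch.int
assert restored["attention_mask"].dtype == torch.int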