[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -15,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data):
     if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
         offset = 0
         for key in keys:
-            assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
+            assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
             size = data[key].size()
             for i, s in enumerate(size):
                 sizes[i + offset] = s
@@ -23,9 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data):

     # Move to GPU and broadcast.
     sizes_cuda = torch.cuda.LongTensor(sizes)
-    torch.distributed.broadcast(sizes_cuda,
-                                gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
-                                group=gpc.get_group(ParallelMode.TENSOR))
+    torch.distributed.broadcast(
+        sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
+    )

     # Move back to cpu and unpack.
     sizes_cpu = sizes_cuda.cpu()
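For context, the call being reformatted here keeps its behavior: torch.distributed.broadcast(tensor, src, group=...) copies tensor from the process whose global rank is src to every other member of group; only the line layout changes. A minimal standalone sketch of the same primitive (not ColossalAI code; it assumes a two-process gloo group launched with torchrun) looks like this:

# sketch_broadcast.py -- hypothetical standalone example, run with:
#   torchrun --nproc_per_node=2 sketch_broadcast.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # torchrun supplies rank/world size via env vars
rank = dist.get_rank()

# Rank 0 knows the real sizes; the other ranks start from zeros and receive them.
sizes = torch.tensor([8, 128, 0, 0], dtype=torch.long) if rank == 0 else torch.zeros(4, dtype=torch.long)
dist.broadcast(sizes, src=0)  # same semantics as the call above, minus the TENSOR group argument

print(f"rank {rank} sizes after broadcast: {sizes.tolist()}")
dist.destroy_process_group()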
@@ -73,9 +73,9 @@ def broadcast_data(keys, data, datatype):
     flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)

     # Broadcast
-    torch.distributed.broadcast(flatten_data,
-                                gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
-                                group=gpc.get_group(ParallelMode.TENSOR))
+    torch.distributed.broadcast(
+        flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
+    )

     # Unpack
     output = {}
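Both hunks above reformat the same pack/broadcast/unpack pattern: the source rank flattens every value into one contiguous buffer, a single broadcast ships it, and each rank slices the buffer back apart by element count and restores the shapes. A condensed, non-distributed sketch of the unpack step (hypothetical key names and shapes, CPU tensors only) is:

import torch

# Hypothetical per-key shapes, as the size/numel dictionaries would describe them.
key_size = {"text": (2, 8), "labels": (2, 8)}
key_numel = {k: torch.Size(v).numel() for k, v in key_size.items()}

# Pretend this arrived via torch.distributed.broadcast as one flat buffer.
flatten_data = torch.arange(sum(key_numel.values()), dtype=torch.int64)

# Unpack: walk the buffer, carving out numel elements per key and restoring the shape.
output, offset = {}, 0
for key in key_size:
    numel = key_numel[key]
    output[key] = flatten_data[offset:offset + numel].view(key_size[key])
    offset += numel

print({k: tuple(v.shape) for k, v in output.items()})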
@@ -93,7 +93,7 @@ def get_batch(data_iterator):
     """Build the batch."""

     # Items and their type.
-    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
+    keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
     datatype = torch.int64

     # Broadcast data.
@@ -104,12 +104,12 @@ def get_batch(data_iterator):
     data_b = broadcast_data(keys, data, datatype)

     # Unpack.
-    tokens = data_b['text'].long()
-    types = data_b['types'].long()
-    sentence_order = data_b['is_random'].long()
-    loss_mask = data_b['loss_mask'].float()
-    lm_labels = data_b['labels'].long()
-    padding_mask = data_b['padding_mask'].long()
+    tokens = data_b["text"].long()
+    types = data_b["types"].long()
+    sentence_order = data_b["is_random"].long()
+    loss_mask = data_b["loss_mask"].float()
+    lm_labels = data_b["labels"].long()
+    padding_mask = data_b["padding_mask"].long()

     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

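These two hunks only touch quote style in get_batch, which expects the iterator to yield a dict holding exactly these six integer tensors; broadcast_data then returns the same dict on every tensor-parallel rank before the dtype casts. A hedged single-process sketch of the unpack-and-cast step (made-up shapes, no broadcasting) is:

import torch

# Hypothetical batch as a BERT-style dataloader might yield it (batch of 2, sequence length 8).
data_b = {
    "text": torch.randint(0, 30000, (2, 8)),
    "types": torch.zeros(2, 8, dtype=torch.int64),
    "labels": torch.full((2, 8), -1, dtype=torch.int64),
    "is_random": torch.tensor([0, 1]),
    "loss_mask": torch.zeros(2, 8, dtype=torch.int64),
    "padding_mask": torch.ones(2, 8, dtype=torch.int64),
}

# Same casts as the new lines above: token-like fields stay integral, loss_mask becomes float.
tokens = data_b["text"].long()
loss_mask = data_b["loss_mask"].float()
print(tokens.dtype, loss_mask.dtype)  # torch.int64 torch.float32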
@@ -118,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator):
     """Build the batch."""

     # Items and their type.
-    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
+    keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
     datatype = torch.int64

     # Broadcast data.
@@ -134,24 +134,23 @@ def get_batch_for_sequence_parallel(data_iterator):
     global_rank = torch.distributed.get_rank()
     local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR)
     local_rank = global_rank % local_world_size
-    seq_length = data_b['text'].size(1)
+    seq_length = data_b["text"].size(1)
     sub_seq_length = seq_length // local_world_size
     sub_seq_start = local_rank * sub_seq_length
     sub_seq_end = (local_rank + 1) * sub_seq_length
     #
     # # Unpack.
-    tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long()
-    types = data_b['types'][:, sub_seq_start:sub_seq_end].long()
-    sentence_order = data_b['is_random'].long()
-    loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float()
-    lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long()
-    padding_mask = data_b['padding_mask'].long()
+    tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long()
+    types = data_b["types"][:, sub_seq_start:sub_seq_end].long()
+    sentence_order = data_b["is_random"].long()
+    loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float()
+    lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long()
+    padding_mask = data_b["padding_mask"].long()

     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


 class SequenceParallelDataIterator:
-
     def __init__(self, data_iter):
         self.data_iter = data_iter

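The sub-sequence arithmetic kept as context in this last hunk is the heart of the sequence-parallel batch split: each tensor-parallel rank retains only its seq_length // local_world_size slice of the token dimension, while per-sample fields such as is_random and padding_mask stay whole. A small sketch with plain tensors (hypothetical sizes, no distributed setup) illustrates the indexing:

import torch

# Hypothetical shapes: batch of 2, sequence length 8, 4 tensor-parallel ranks.
batch_size, seq_length, local_world_size = 2, 8, 4
text = torch.arange(batch_size * seq_length).reshape(batch_size, seq_length)

sub_seq_length = seq_length // local_world_size  # 2 tokens per rank
for local_rank in range(local_world_size):
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = (local_rank + 1) * sub_seq_length
    tokens = text[:, sub_seq_start:sub_seq_end]  # each rank sees a disjoint slice of every sample
    print(f"rank {local_rank} -> columns [{sub_seq_start}, {sub_seq_end}) shape {tuple(tokens.shape)}")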