Mirror of https://github.com/hpcaitech/ColossalAI.git
[llama] update training script (#5360)
* [llama] update training script
* [doc] polish docstr
@@ -58,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
     tokenizer: PreTrainedTokenizer
     max_length: int = 4096
     ignore_index: int = -100
+    padding: str = "max_length"
 
     def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
         """
@@ -102,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
             batch_first=True,
             padding_value=self.ignore_index,
         )  # (bsz, max_len)
-        # pad to max
-        to_pad = self.max_length - input_ids.size(1)
-        input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
-        labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+        if self.padding == "max_length":
+            # pad to max
+            to_pad = self.max_length - input_ids.size(1)
+            input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+            labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
         elif self.tokenizer.padding_side == "left":
             reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
             reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
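For reference, a minimal standalone sketch of the behavior the second hunk introduces: with padding="max_length" every batch is right-padded to a fixed length, otherwise it stays at the longest sequence in the batch. The collate helper, toy batch, and default values below are hypothetical illustrations, not code from the commit, and they assume sequences were already truncated to max_length upstream.

import torch
import torch.nn.functional as F

def collate(batch_input_ids, pad_token_id=0, max_length=16, padding="max_length"):
    # First pad to the longest sequence in the batch -> (bsz, longest_len).
    input_ids = torch.nn.utils.rnn.pad_sequence(
        batch_input_ids, batch_first=True, padding_value=pad_token_id
    )
    if padding == "max_length":
        # Then pad out to max_length so every batch has a fixed shape
        # (assumes no sequence is longer than max_length).
        to_pad = max_length - input_ids.size(1)
        input_ids = F.pad(input_ids, (0, to_pad), value=pad_token_id)
    return input_ids

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(collate(batch, padding="max_length").shape)  # torch.Size([2, 16])
print(collate(batch, padding="longest").shape)     # torch.Size([2, 3])

Fixed-shape batches are useful when downstream components expect static shapes; padding only to the longest sequence saves compute on short batches.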