[llama] update training script (#5360)

* [llama] update training script

* [doc] polish docstr
Hongxin Liu authored 2024-02-05 16:33:18 +08:00, committed by GitHub
parent 6c0fa7b9a8
commit 73f9f23fc6
5 changed files with 105 additions and 475 deletions

@@ -58,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
     tokenizer: PreTrainedTokenizer
     max_length: int = 4096
     ignore_index: int = -100
+    padding: str = "max_length"
 
     def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
         """
@@ -102,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
                 batch_first=True,
                 padding_value=self.ignore_index,
             )  # (bsz, max_len)
-            # pad to max
-            to_pad = self.max_length - input_ids.size(1)
-            input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
-            labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+            if self.padding == "max_length":
+                # pad to max
+                to_pad = self.max_length - input_ids.size(1)
+                input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+                labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
         elif self.tokenizer.padding_side == "left":
             reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
             reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
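
The control flow above can be reproduced standalone. Here is a small sketch of both paths, with made-up token ids and max_length=8: pad_sequence pads at the end of each sequence, so the left-padding branch flips each sequence, right-pads the flipped copies, and flips the batch back.

    import torch
    import torch.nn.functional as F
    from torch.nn.utils.rnn import pad_sequence

    pad_token_id, ignore_index, max_length = 0, -100, 8
    batch_input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    batch_labels = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

    # Right padding to the longest sequence in the batch.
    input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id)
    labels = pad_sequence(batch_labels, batch_first=True, padding_value=ignore_index)

    padding = "max_length"
    if padding == "max_length":
        # Extend every row out to the fixed max_length, as in the hunk above.
        to_pad = max_length - input_ids.size(1)
        input_ids = F.pad(input_ids, (0, to_pad), value=pad_token_id)
        labels = F.pad(labels, (0, to_pad), value=ignore_index)
    print(input_ids)  # tensor([[1, 2, 3, 0, 0, 0, 0, 0], [4, 5, 0, 0, 0, 0, 0, 0]])

    # Left padding: flip, right-pad, flip back.
    reversed_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
    left_padded = pad_sequence(reversed_ids, batch_first=True, padding_value=pad_token_id).flip(dims=(1,))
    print(left_padded)  # tensor([[1, 2, 3], [0, 4, 5]])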