#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Dataloader for sft, dpo, ppo """ import os from dataclasses import dataclass from typing import Dict, Iterator, List, Optional, Sequence, Union import jsonlines import torch import torch.nn.functional as F from coati.dataset.utils import chuncate_sequence, pad_to_max_len from datasets import Dataset as HFDataset from datasets import dataset_dict, load_from_disk from torch.utils.data import ConcatDataset, Dataset, DistributedSampler from transformers.tokenization_utils import PreTrainedTokenizer DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] PathType = Union[str, os.PathLike] def load_tokenized_dataset( dataset_paths: Union[PathType, List[PathType]], mode: str = "train", **kwargs ) -> Optional[DatasetType]: """ Load pre-tokenized dataset. Each instance of dataset is a dictionary with `{'input_ids': List[int], 'labels': List[int], sequence: str}` format. """ if not dataset_paths: return None mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"}) assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}" if isinstance(dataset_paths, (str, os.PathLike)): dataset_paths = [dataset_paths] datasets = [] # `List[datasets.dataset_dict.Dataset]` for ds_path in dataset_paths: ds_path = os.path.abspath(ds_path) assert os.path.exists(ds_path), f"Not existed file path {ds_path}" ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False) if isinstance(ds_dict, HFDataset): datasets.append(ds_dict) else: if mode_map[mode] in ds_dict: datasets.append(ds_dict[mode_map[mode]]) if len(datasets) == 0: return None if len(datasets) == 1: return datasets.pop() return ConcatDataset(datasets=datasets) @dataclass class DataCollatorForSupervisedDataset(object): """ Collate instances for supervised dataset. Each instance is a tokenized dictionary with fields `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). """ tokenizer: PreTrainedTokenizer max_length: int = 4096 ignore_index: int = -100 def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ Args: instances (`Sequence[Dict[str, List[int]]]`): Mini-batch samples, each sample is stored in an individual dictionary. Returns: (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: `input_ids`: `torch.Tensor` of shape (bsz, max_len); `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. """ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " f"but now `{self.tokenizer.pad_token_id}`" ) # `List[torch.Tensor]` batch_input_ids = [ ( torch.LongTensor(instance["input_ids"][: self.max_length]) if len(instance["input_ids"]) > self.max_length else torch.LongTensor(instance["input_ids"]) ) for instance in instances ] batch_labels = [ ( torch.LongTensor(instance["labels"][: self.max_length]) if len(instance["labels"]) > self.max_length else torch.LongTensor(instance["labels"]) ) for instance in instances ] if self.tokenizer.padding_side == "right": input_ids = torch.nn.utils.rnn.pad_sequence( sequences=batch_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id, ) # (bsz, max_len) labels = torch.nn.utils.rnn.pad_sequence( sequences=batch_labels, batch_first=True, padding_value=self.ignore_index, ) # (bsz, max_len) # pad to max to_pad = self.max_length - input_ids.size(1) input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) labels = F.pad(labels, (0, to_pad), value=self.ignore_index) elif self.tokenizer.padding_side == "left": reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] reversed_input_ids = torch.nn.utils.rnn.pad_sequence( sequences=reversed_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id, ) # (bsz, max_len) input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len) reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels] reversed_labels = torch.nn.utils.rnn.pad_sequence( sequences=reversed_labels, batch_first=True, padding_value=self.ignore_index, ) # (bsz, max_len) labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len) else: raise RuntimeError( f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, " f"but now `{self.tokenizer.padding_side}`" ) attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) @dataclass class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset): def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ Args: instances (`Sequence[Dict[str, List[int]]]`): Mini-batch samples, each sample is stored in an individual dictionary. Returns: (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: `input_ids`: `torch.Tensor` of shape (bsz, max_len); `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); """ gt_answer = [ins.get("gt_answer", None) for ins in instances] instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances] ret = super().__call__(instances=instances) input_ids = F.pad( ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id ) attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False) return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer} @dataclass class DataCollatorForPreferenceDataset(object): """ Collate instances for supervised dataset. Each instance is a tokenized dictionary with fields `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). """ tokenizer: PreTrainedTokenizer max_length: int = 4096 def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ Args: instances (`Sequence[Dict[str, List[int]]]`): Mini-batch samples, each sample is stored in an individual dictionary. Returns: (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: `input_ids`: `torch.Tensor` of shape (bsz, max_len); `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. """ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " f"but now `{self.tokenizer.pad_token_id}`" ) ( chosen_input_ids, chosen_loss_mask, # [batch_size * seq_len] reject_input_ids, reject_loss_mask, ) = ( chuncate_sequence([ins["chosen_input_ids"] for ins in instances], self.max_length, torch.int64), chuncate_sequence([ins["chosen_loss_mask"] for ins in instances], self.max_length, torch.bool), chuncate_sequence([ins["rejected_input_ids"] for ins in instances], self.max_length, torch.int64), chuncate_sequence([ins["rejected_loss_mask"] for ins in instances], self.max_length, torch.bool), ) padding_side = self.tokenizer.padding_side chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids] reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids] ( chosen_input_ids, chosen_attention_mask, chosen_loss_mask, reject_input_ids, reject_attention_mask, reject_loss_mask, ) = ( pad_to_max_len(chosen_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side), pad_to_max_len(chosen_attention_mask, self.max_length, False, padding_side=padding_side), pad_to_max_len(chosen_loss_mask, self.max_length, False, padding_side=padding_side), pad_to_max_len(reject_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side), pad_to_max_len(reject_attention_mask, self.max_length, False, padding_side=padding_side), pad_to_max_len(reject_loss_mask, self.max_length, False, padding_side=padding_side), ) return dict( chosen_input_ids=chosen_input_ids, chosen_attention_mask=chosen_attention_mask, chosen_loss_mask=chosen_loss_mask, reject_input_ids=reject_input_ids, reject_attention_mask=reject_attention_mask, reject_loss_mask=reject_loss_mask, ) @dataclass class DataCollatorForKTODataset(object): """ Collate instances for kto dataset. Each input instance is a tokenized dictionary with fields `prompt`(List[int]), `completion`(List[int]) and `label`(bool). Each output instance is a tokenized dictionary with fields `kl_input_ids`(List[int]), `kl_attention_mask`(List[int]) and `kl_loss_mask`(List[int]). `input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]) and `label`(bool). """ tokenizer: PreTrainedTokenizer max_length: int = 4096 ignore_index: int = -100 def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ Args: instances (`Sequence[Dict[str, List[int]]]`): Mini-batch samples, each sample is stored in an individual dictionary contains the following fields: `prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not). Returns: (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: `input_ids`: `torch.Tensor` of shape (bsz, max_len); `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. """ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " f"but now `{self.tokenizer.pad_token_id}`" ) # prepare the preference data prompt = [torch.LongTensor(instance["prompt"]) for instance in instances] prompt_zeros = [torch.zeros_like(t) for t in prompt] completion = [torch.LongTensor(instance["completion"]) for instance in instances] completion_ones = [torch.ones_like(t) for t in completion] label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances] input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))] loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))] # right padding input_ids = torch.nn.utils.rnn.pad_sequence( sequences=input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id, ) # (bsz, max_len) loss_mask = torch.nn.utils.rnn.pad_sequence( sequences=loss_mask, batch_first=True, padding_value=0 ) # (bsz, max_len) to_pad = self.max_length - input_ids.size(1) input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) loss_mask = F.pad(loss_mask, (0, to_pad), value=0) attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) # prepare kt data kl_completion = completion[::-1] # y' kl_completion_ones = [torch.ones_like(t) for t in kl_completion] kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))] kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))] # right padding kl_input_ids = torch.nn.utils.rnn.pad_sequence( sequences=kl_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id, ) # (bsz, max_len) kl_loss_mask = torch.nn.utils.rnn.pad_sequence( sequences=kl_loss_mask, batch_first=True, padding_value=0 ) # (bsz, max_len) to_pad = self.max_length - kl_input_ids.size(1) kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0) kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) data_dict = { "input_ids": input_ids, "attention_mask": attention_mask, "loss_mask": loss_mask, "label": torch.stack(label), "kl_input_ids": kl_input_ids, "kl_attention_mask": kl_attention_mask, "kl_loss_mask": kl_loss_mask, } return data_dict class StatefulDistributedSampler(DistributedSampler): def __init__( self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False, ) -> None: super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) self.start_index: int = 0 def __iter__(self) -> Iterator: iterator = super().__iter__() indices = list(iterator) indices = indices[self.start_index :] return iter(indices) def __len__(self) -> int: return self.num_samples - self.start_index def set_start_index(self, start_index: int) -> None: self.start_index = start_index def apply_chat_template_and_mask( tokenizer: PreTrainedTokenizer, chat: List[Dict[str, str]], max_length: Optional[int] = None, system_prompt: str = None, padding: bool = True, truncation: bool = True, ignore_idx: int = -100, ) -> Dict[str, torch.Tensor]: if system_prompt is None: system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, i.e., 123 .\n\n" system_element = { "role": "system", "content": system_prompt, } # Format for RL. gt_answer = None if "messages" in chat and "gt_answer" in chat: gt_answer = chat["gt_answer"] chat = [chat["messages"]] tokens = [] assistant_mask = [] for i, msg in enumerate(chat): msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True, add_generation_prompt=True) # remove unexpected bos token if i > 0 and msg_tokens[0] == tokenizer.bos_token_id: msg_tokens = msg_tokens[1:] tokens.extend(msg_tokens) if msg["role"] == "assistant": assistant_mask.extend([True] * len(msg_tokens)) else: assistant_mask.extend([False] * len(msg_tokens)) attention_mask = [1] * len(tokens) if max_length is not None: if padding and len(tokens) < max_length: to_pad = max_length - len(tokens) # Left padding for generation. tokens = [tokenizer.pad_token_id] * to_pad + tokens assistant_mask = [False] * to_pad + assistant_mask attention_mask = [0] * to_pad + attention_mask if truncation and len(tokens) > max_length: tokens = tokens[:max_length] assistant_mask = assistant_mask[:max_length] attention_mask = attention_mask[:max_length] input_ids = torch.tensor(tokens, dtype=torch.long) attention_mask = torch.tensor(attention_mask, dtype=torch.long) labels = input_ids.clone() labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx if gt_answer is not None: gt_answer = tokenizer.encode( gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt" ) gt_answer = gt_answer.squeeze(1) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer} return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } class RawConversationDataset(Dataset): """ Raw conversation dataset. Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`. """ def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None: self.tokenizer = tokenizer self.raw_texts = [] with jsonlines.open(input_file) as f: for line in f: self.raw_texts.append(line) self.tokenized_texts = [None] * len(self.raw_texts) self.max_length = max_length self.system_prompt = system_prompt def __len__(self) -> int: return len(self.raw_texts) def __getitem__(self, index: int): if self.tokenized_texts[index] is None: message = self.raw_texts[index] tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt) self.tokenized_texts[index] = dict(tokens) return self.tokenized_texts[index]