ColossalAI/applications/Chat/coati/dataset/prompt_dataset.py

import copy
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm

from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload

logger = get_dist_logger()
class PromptDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_datasets_size: int = None,
                 max_length: int = 96):
        super(PromptDataset, self).__init__()
        self.keyed_prompt = defaultdict(list)
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        # Tokenize all instructions in a single batched call, padding/truncating to max_length.
        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
        tokens = tokenizer(instructions,
                           return_tensors='pt',
                           max_length=max_length,
                           padding='max_length',
                           truncation=True)
        # Unbind each stacked tensor (e.g. input_ids, attention_mask) into per-sample
        # tensors stored on the current CUDA device.
        for k, tensor in tokens.items():
            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

    def __len__(self):
        return len(self.keyed_prompt["input_ids"])

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return {k: v[i] for k, v in self.keyed_prompt.items()}
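
A minimal usage sketch (not part of the file above): it assumes PromptDataset is importable from coati.dataset, a GPT-2 tokenizer, a hypothetical data/prompts.json whose records contain an "instruction" field, and an available CUDA device, since the dataset stores its tensors on torch.cuda.current_device(). The default DataLoader collation stacks each per-sample tensor dict back into batched tensors.

# Usage sketch; file path and dataset size are hypothetical.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from coati.dataset import PromptDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token    # GPT-2 defines no pad token by default

dataset = PromptDataset(data_path="data/prompts.json",    # hypothetical JSON file
                        tokenizer=tokenizer,
                        max_datasets_size=1000,
                        max_length=96)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

for batch in dataloader:
    # batch["input_ids"] and batch["attention_mask"] have shape (batch_size, max_length)
    # and live on the current CUDA device.
    print(batch["input_ids"].shape)
    break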