mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
@@ -61,8 +61,8 @@ def train(args):
|
||||
|
||||
# prepare for data and dataset
|
||||
data = load_dataset(args.dataset)
|
||||
train_data = data["train"].select(range(100))
|
||||
eval_data = data['test'].select(range(5))
|
||||
train_data = data["train"]
|
||||
eval_data = data['test']
|
||||
train_dataset = RewardDataset(train_data, tokenizer, max_len)
|
||||
eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
|
||||
|
||||
@@ -93,7 +93,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
|
||||
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
|
||||
parser.add_argument('--max_epochs', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=1)
|
||||
parser.add_argument('--batch_size', type=int, default=4)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
args = parser.parse_args()
|
||||
|
Reference in New Issue
Block a user