[chatgpt]fix lora bug (#2974)

* fix lora bug

* polish
Author: BlueRum
Date: 2023-03-02 17:51:44 +08:00
Committed by: GitHub
Parent: 82149e9d1b
Commit: c9e27f0d1b

3 changed files with 8 additions and 6 deletions

@@ -61,8 +61,8 @@ def train(args):
     # prepare for data and dataset
     data = load_dataset(args.dataset)
-    train_data = data["train"].select(range(100))
-    eval_data = data['test'].select(range(5))
+    train_data = data["train"]
+    eval_data = data['test']
     train_dataset = RewardDataset(train_data, tokenizer, max_len)
     eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
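
The hunk above drops the .select(...) subsampling that had capped training at 100 examples and evaluation at 5, so the reward model now runs over the full Dahoas/rm-static splits. A minimal sketch of that loading step, assuming the Hugging Face datasets package (the printed structure is illustrative, not part of this diff):

from datasets import load_dataset

data = load_dataset('Dahoas/rm-static')   # DatasetDict with 'train' and 'test' splits
print(data)                               # shows split sizes and column names
train_data = data['train']                # full split, no .select(range(100)) cap
eval_data = data['test']                  # full split, no .select(range(5)) cap
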
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
-    parser.add_argument('--max_epochs', type=int, default=10)
+    parser.add_argument('--max_epochs', type=int, default=1)
     parser.add_argument('--batch_size', type=int, default=4)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
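
--lora_rank defaults to 0, which leaves low-rank adaptation disabled; a positive rank injects trainable low-rank matrices alongside the base weights. As a rough sketch of the usual LoRA pattern, assuming PyTorch (this is not the repo's actual lora.py; the class name and initialization choices are assumptions):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Base linear layer plus a trainable low-rank update scaled by alpha / r.
    def __init__(self, in_features, out_features, r=0, alpha=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.r = r
        if r > 0:
            # B starts at zero so the module initially matches the base layer.
            self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            self.scaling = alpha / r

    def forward(self, x):
        out = self.linear(x)
        if self.r > 0:
            # Low-rank update: x @ A^T @ B^T, scaled by alpha / r.
            out = out + (x @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return out

With r=0 this reduces to a plain nn.Linear, matching the argparse default above; the LoRA fix itself lands in the other changed files not shown in this hunk.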