diff --git a/configs/deepspeed/ds_config_pythiaseek.json b/configs/deepspeed/ds_config_pythiaseek.json new file mode 100644 index 00000000..e6ec172d --- /dev/null +++ b/configs/deepspeed/ds_config_pythiaseek.json @@ -0,0 +1,39 @@ +{ + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "train_micro_batch_size_per_gpu": "auto", + "fp16": { + "enabled": "auto", + "min_loss_scale": 1, + "loss_scale_window": 1000, + "hysteresis": 2, + "initial_scale_power": 32 + }, + "bf16": { + "enabled": "auto" + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "offload_param": { + "device": "none" + }, + "offload_optimizer": { + "device": "none" + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-08 + } + } +} \ No newline at end of file diff --git a/configs/train/finetune_pythiaseek.yaml b/configs/train/finetune_pythiaseek.yaml new file mode 100644 index 00000000..4c50b724 --- /dev/null +++ b/configs/train/finetune_pythiaseek.yaml @@ -0,0 +1,46 @@ +# model/tokenizer +model_name: "EleutherAI/pythia-1b" +tokenizer_name: "EleutherAI/pythia-1b" +version: null +gradient_checkpointing: true +save_name: "nomic-ai/pythiaseek-large-bs" +push_to_hub: false +encoder_dim: 384 +learnable_alpha: true +cross_attn_layer: 9 +freeze_pretrained: false + +# dataset +streaming: false +num_proc: 604 +dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train" +max_length: 1024 +batch_size: 32 +pct_test: 0.05 +q_column: "question" +a_column: "answers" +encoder_column: "neighbor_embeddings" + + +# train dynamics +lr: 1.0e-4 +min_lr: 0 +weight_decay: 0.0 +eval_every: 50 +save_every: -1 +log_grads_every: 100 +log_lr_every: 10 +output_dir: "ckpts/learnable_alpha" +checkpoint: null +lora: false +warmup_steps: 500 +num_epochs: 5 +debug: false +scheduler: false + +# logging +wandb: true +wandb_entity: gpt4all +wandb_project_name: retrieval +seed: 42 +