mirror of https://github.com/nomic-ai/gpt4all.git

feat: update embeddings

commit 2daecd6066 (parent 2ac939529d)
Training config YAML (file name not shown in this view):

@@ -8,8 +8,8 @@ save_name: "zpn/vicuna-lora"
 streaming: false
 num_proc: 64
 dataset_path: "data"
-max_length: 512
-batch_size: 8
+max_length: 1024
+batch_size: 4
 
 # train dynamics
 lr: 5.0e-5
@@ -22,8 +22,7 @@ lora: true
 warmup_steps: 100
 
 # logging
-wandb: false
-wandb_entity: zanussbaum
-wandb_project: llama
+wandb: true
+wandb_entity: vicuna
+wandb_project_name: vicuna
 seed: 42
-
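The first hunk doubles max_length (512 to 1024) while halving batch_size (8 to 4), so tokens per device batch stay at 4096 (512 x 8 = 1024 x 4) while activation memory shifts toward longer sequences. The second hunk turns wandb logging on and renames wandb_project to wandb_project_name, which the training script must now read under the new key. A minimal sketch of how those logging keys could be consumed through Accelerate's tracker API; the inline YAML and the loading code are illustrative, not this repo's actual code:

```python
import yaml
from accelerate import Accelerator

# Illustrative only: mirrors the renamed logging keys from this commit.
config = yaml.safe_load("""
wandb: true
wandb_entity: vicuna
wandb_project_name: vicuna
""")

accelerator = Accelerator(log_with="wandb" if config["wandb"] else None)
if config["wandb"]:
    accelerator.init_trackers(
        project_name=config["wandb_project_name"],  # was wandb_project
        init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
    )
```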
train.py (23 lines changed)
@@ -47,11 +47,10 @@ def train(accelerator, config):
     accelerator.print(f"Using {accelerator.num_processes} GPUs")
 
     tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'])
-    # llama has no pad token, set it to eos
+    # llama has no pad token, set it to new token
     if tokenizer.pad_token is None:
         # these tokens are already in the vocab, just not mapped correctly
-        tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})
-        tokenizer.pad_token = tokenizer.eos_token
+        added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
 
 
     with accelerator.main_process_first():
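This hunk replaces the old reuse of EOS as the pad token with a dedicated <pad> token: when pad and EOS share an id, loss masking of padding can also mask real EOS tokens, so the model may never learn to stop generating. add_special_tokens returns the number of tokens that were genuinely new to the vocabulary, which the commit captures as added_tokens. A minimal standalone sketch of that return value; the model id is illustrative:

```python
from transformers import AutoTokenizer

# Illustrative model id; any LLaMA-family tokenizer behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

if tokenizer.pad_token is None:
    # Returns the count of tokens that were not already in the vocab.
    added_tokens = tokenizer.add_special_tokens(
        {"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"}
    )
    # <s> and </s> already exist for LLaMA, so typically only <pad> is new.
    print(added_tokens)  # expected: 1
```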
@@ -62,6 +61,9 @@ def train(accelerator, config):
     model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                  use_cache=False if checkpoint else True,
                                                  trust_remote_code=True)
 
+    if added_tokens > 0:
+        model.resize_token_embeddings(len(tokenizer))
+
     if checkpoint:
         model.gradient_checkpointing_enable()
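The resize_token_embeddings call is the necessary counterpart to the tokenizer change: if <pad> was genuinely new, the embedding matrix (and any tied LM head) must grow to cover the new id, or that id would index out of range. A minimal sketch of the pairing, again with an illustrative model id:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "huggyllama/llama-7b"  # illustrative, not necessarily this repo's model
tokenizer = AutoTokenizer.from_pretrained(model_id)
added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})

model = AutoModelForCausalLM.from_pretrained(model_id)
if added_tokens > 0:
    # Grow the input embeddings (and tied output head) to cover the new id;
    # the fresh rows are randomly initialized and learned during fine-tuning.
    model.resize_token_embeddings(len(tokenizer))

assert model.get_input_embeddings().num_embeddings == len(tokenizer)
```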
@@ -108,21 +110,28 @@ def train(accelerator, config):
 
     train_loss = MeanMetric().to(model.device)
 
+    if accelerator.state.deepspeed_plugin is not None:
+        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
+            "gradient_accumulation_steps"
+        ]
+
     for step, batch in enumerate(tqdm(train_dataloader)):
         model.train()
         outputs = model(**batch)
         loss = outputs.loss
+        loss = loss / gradient_accumulation_steps
 
         accelerator.backward(loss)
-        optimizer.step()
 
         # log LR in case something weird happens
-        if step % (config["eval_every"] // 10) == 0:
+        if step > 0 and step % (config["eval_every"] // 10) == 0:
             if config["wandb"]:
                 accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)
 
-        scheduler.step()
-        optimizer.zero_grad()
+        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad()
 
         loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
         train_loss.update(loss_values["loss"])
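The loop change implements manual gradient accumulation. Each micro-batch loss is divided by gradient_accumulation_steps so that gradients summed over the accumulation window equal their mean, and optimizer.step(), scheduler.step(), and optimizer.zero_grad() now run only once per window (or on the final batch, so a short tail still produces an update). When DeepSpeed is active, the window size is read from its config so both sides agree; the added step > 0 guard also stops the LR from being logged at step 0. A self-contained sketch of the pattern in plain PyTorch; the tiny model and synthetic batches are placeholders:

```python
import torch
from torch import nn

# Minimal sketch of the accumulation pattern added in this commit.
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
gradient_accumulation_steps = 4  # train.py reads this from the DeepSpeed config

batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(10)]

for step, (x, y) in enumerate(batches):
    model.train()
    loss = nn.functional.mse_loss(model(x), y)
    # Scale so gradients summed over N micro-batches equal their mean.
    loss = loss / gradient_accumulation_steps
    loss.backward()  # train.py calls accelerator.backward(loss) instead

    # Step only every N micro-batches, or on the final (possibly short) group.
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(batches) - 1:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```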