Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-06 20:09:58 +00:00
chore: pull out common dist print fn
commit 4671f4e82f
parent e255e0a805
@@ -3,12 +3,13 @@ import torch
 import torch.nn as nn
 from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
 from accelerate.utils import set_seed
 from gpt4all.utils.data import load_data_for_inference
+from gpt4all.utils.distributed_utils import rank0_print
 from tqdm import tqdm
 from datasets import Dataset
 import torch.distributed as dist
 from transformers.trainer_pt_utils import nested_numpify
 from transformers import DefaultDataCollator
 from torch.utils.data import DataLoader, DistributedSampler
 import numpy as np
@@ -21,56 +22,64 @@ def calc_cross_entropy_no_reduction(lm_logits, labels):
     shift_logits = lm_logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
     # Flatten the tokens
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
+    loss_fct = nn.CrossEntropyLoss(reduction="none")
     loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)

     return loss


-def rank0_print(msg):
-    if dist.get_rank() == 0:
-        print(msg)
-
-
 def inference(config):
-    set_seed(config['seed'])
+    set_seed(config["seed"])

     rank0_print(f"World size: {dist.get_world_size()}")

-    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
+    tokenizer = AutoTokenizer.from_pretrained(
+        config["tokenizer_name"], model_max_length=config["max_length"]
+    )
     # llama has no pad token, set it to new token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

     train_dataset, val_dataset = load_data_for_inference(config, tokenizer)

     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()

-    train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
+    train_sampler = DistributedSampler(
+        train_dataset,
+        shuffle=False,
+        drop_last=True,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
         sampler=train_sampler,
-        drop_last=True
+        drop_last=True,
     )

-    val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
+    val_sampler = DistributedSampler(
+        val_dataset,
+        shuffle=False,
+        drop_last=True,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
         sampler=val_sampler,
-        drop_last=True
+        drop_last=True,
     )

-    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
+    model = AutoModelForCausalLM.from_pretrained(
+        config["model_name"],
         trust_remote_code=True,
         torch_dtype=torch.bfloat16,
     )
     model.to(f"cuda:{local_rank}")

     with torch.no_grad():
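For context on the sampler settings in the hunk above: with shuffle=False, DistributedSampler deals indices out round-robin, so each rank gets a disjoint slice of the dataset, and drop_last=True trims the tail so every rank sees the same number of samples. A minimal standalone sketch (not part of this commit; no process group needed when num_replicas and rank are passed explicitly):

    from torch.utils.data import DistributedSampler

    dataset = list(range(10))  # anything with __len__ works for the sampler
    for rank in range(2):
        sampler = DistributedSampler(
            dataset, num_replicas=2, rank=rank, shuffle=False, drop_last=True
        )
        print(rank, list(sampler))
    # rank 0 -> [0, 2, 4, 6, 8]; rank 1 -> [1, 3, 5, 7, 9]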
@@ -78,7 +87,11 @@ def inference(config):
         for batch in tqdm(train_dataloader, disable=local_rank != 0):
             batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
             batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
+            outputs = model(
+                input_ids=batch["input_ids"],
+                labels=batch["labels"],
+                output_hidden_states=True,
+            )
             loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
             train_outputs["loss"].extend(loss)

@@ -101,7 +114,9 @@ def inference(config):
                 sequence_lengths.append(len(item) - 1)

             sequence_lengths = torch.tensor(sequence_lengths)
-            pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
+            pooled_logits = embeddings[
+                torch.arange(batch_size, device=embeddings.device), sequence_lengths
+            ]

             train_outputs["embeddings"].append(pooled_logits)
             train_outputs["index"].extend(batch["index"].to(model.device))
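The pooled_logits indexing above selects, for each row in the batch, the hidden state at that row's last-token position. A minimal sketch of the same advanced-indexing trick, with made-up shapes:

    import torch

    batch_size, seq_len, hidden = 2, 5, 4
    embeddings = torch.randn(batch_size, seq_len, hidden)
    sequence_lengths = torch.tensor([4, 2])  # last-token index per sequence
    pooled = embeddings[torch.arange(batch_size), sequence_lengths]
    print(pooled.shape)  # torch.Size([2, 4])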
@@ -120,22 +135,33 @@ def inference(config):
         # compute mask in pyarrow since it's super fast
         # ty @bmschmidt for showing me this!
         table = train_dataset.data
-        mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+        mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
         filtered_table = table.filter(mask)
         # convert from pyarrow to Dataset
         filtered_train = Dataset.from_dict(filtered_table.to_pydict())

         filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
         filtered_train = filtered_train.add_column("loss", df_train["loss"])
-        filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
+        filtered_train = filtered_train.add_column(
+            "is_train", [True] * len(filtered_train)
+        )

-        filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
+        filtered_train.to_json(
+            f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl",
+            lines=True,
+            orient="records",
+            num_proc=64,
+        )

         val_outputs = {"loss": [], "embeddings": [], "index": []}
         for batch in tqdm(val_dataloader, disable=local_rank != 0):
             batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
             batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
+            outputs = model(
+                input_ids=batch["input_ids"],
+                labels=batch["labels"],
+                output_hidden_states=True,
+            )
             loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
             val_outputs["loss"].extend(loss)

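The pyarrow filtering in the hunk above builds a boolean mask in one vectorized pc.is_in call and applies it with Table.filter, which is far faster than looping in Python. A self-contained sketch with toy data (column values are illustrative):

    import pyarrow as pa
    import pyarrow.compute as pc

    table = pa.table({"index": pa.array([0, 1, 2, 3], pa.int32()),
                      "text": ["a", "b", "c", "d"]})
    curr_idx = [1, 3]
    mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
    print(table.filter(mask).to_pydict())  # {'index': [1, 3], 'text': ['b', 'd']}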
@@ -158,7 +184,9 @@ def inference(config):
                 sequence_lengths.append(len(item) - 1)

             sequence_lengths = torch.tensor(sequence_lengths)
-            pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
+            pooled_logits = embeddings[
+                torch.arange(batch_size, device=embeddings.device), sequence_lengths
+            ]

             val_outputs["embeddings"].append(pooled_logits)
             val_outputs["index"].extend(batch["index"].to(model.device))
@@ -176,7 +204,7 @@ def inference(config):
         # compute mask in pyarrow since it's super fast
         # ty @bmschmidt for showing me this!
         table = val_dataset.data
-        mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+        mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
         filtered_table = table.filter(mask)
         # convert from pyarrow to Dataset
         filtered_val = Dataset.from_dict(filtered_table.to_pydict())
@@ -184,7 +212,12 @@ def inference(config):
         filtered_val = filtered_val.add_column("loss", df_val["loss"])
         filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))

-        filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
+        filtered_val.to_json(
+            f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl",
+            lines=True,
+            orient="records",
+            num_proc=64,
+        )


 def main():
@@ -201,4 +234,3 @@ def main():
 if __name__ == "__main__":
     # parse arguments by reading in a config
     main()
-
gpt4all/utils/distributed_utils.py (new file, 6 additions)
@@ -0,0 +1,6 @@
+import torch.distributed as dist
+
+
+def rank0_print(msg):
+    if dist.get_rank() == 0:
+        print(msg)
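A minimal usage sketch for the new helper; it assumes the default process group has already been initialized (dist.get_rank() raises otherwise), e.g. when launched under torchrun:

    import torch.distributed as dist
    from gpt4all.utils.distributed_utils import rank0_print

    dist.init_process_group(backend="gloo")  # "nccl" for GPU jobs
    rank0_print(f"World size: {dist.get_world_size()}")  # prints once, on rank 0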