chore: pull out common dist print fn
commit 4671f4e82f (parent e255e0a805)
@@ -5,6 +5,7 @@ from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
 from accelerate.utils import set_seed
 from gpt4all.utils.data import load_data_for_inference
+from gpt4all.utils.distributed_utils import rank0_print
 from tqdm import tqdm
 from datasets import Dataset
 import torch.distributed as dist
@@ -21,53 +22,61 @@ def calc_cross_entropy_no_reduction(lm_logits, labels):
     shift_logits = lm_logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
     # Flatten the tokens
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
+    loss_fct = nn.CrossEntropyLoss(reduction="none")
     loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
 
     return loss
 
 
-def rank0_print(msg):
-    if dist.get_rank() == 0:
-        print(msg)
-
-
 def inference(config):
-    set_seed(config['seed'])
+    set_seed(config["seed"])
 
     rank0_print(f"World size: {dist.get_world_size()}")
 
-    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
+    tokenizer = AutoTokenizer.from_pretrained(
+        config["tokenizer_name"], model_max_length=config["max_length"]
+    )
     # llama has no pad token, set it to new token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
 
     train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
 
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
+    train_sampler = DistributedSampler(
+        train_dataset,
+        shuffle=False,
+        drop_last=True,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
         sampler=train_sampler,
-        drop_last=True
+        drop_last=True,
     )
 
-    val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
+    val_sampler = DistributedSampler(
+        val_dataset,
+        shuffle=False,
+        drop_last=True,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
         sampler=val_sampler,
-        drop_last=True
+        drop_last=True,
     )
 
 
-    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
+    model = AutoModelForCausalLM.from_pretrained(
+        config["model_name"],
         trust_remote_code=True,
         torch_dtype=torch.bfloat16,
     )
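Note: `calc_cross_entropy_no_reduction`, the enclosing function in the hunk above, returns one loss per example instead of a batch-mean scalar, which is what lets the script record a loss per row. A minimal self-contained sketch of the same pattern (toy shapes are assumptions for illustration, not values from the commit):

import torch
import torch.nn as nn

def calc_cross_entropy_no_reduction(lm_logits, labels):
    # Shift so tokens < n predict token n.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # reduction="none" keeps a loss per token; permute puts the class
    # (vocab) dimension second, as CrossEntropyLoss expects.
    loss_fct = nn.CrossEntropyLoss(reduction="none")
    # mean over the sequence dim leaves one scalar per example.
    return loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)

lm_logits = torch.randn(2, 8, 100)          # (batch, seq_len, vocab), toy sizes
labels = torch.randint(0, 100, (2, 8))      # (batch, seq_len)
print(calc_cross_entropy_no_reduction(lm_logits, labels).shape)  # torch.Size([2])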
@@ -78,7 +87,11 @@ def inference(config):
     for batch in tqdm(train_dataloader, disable=local_rank != 0):
         batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
         batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-        outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
+        outputs = model(
+            input_ids=batch["input_ids"],
+            labels=batch["labels"],
+            output_hidden_states=True,
+        )
         loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
         train_outputs["loss"].extend(loss)
 
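Note: the sampler/loader setup two hunks up shards the dataset so each rank scores a disjoint slice, which is why each loop above only sees part of the data. A standalone sketch of how `DistributedSampler` partitions indices; passing `num_replicas`/`rank` explicitly means no process group is needed, and the sizes here are toy values:

from torch.utils.data import DistributedSampler

dataset = list(range(10))  # any object with a len() works as a stand-in
for rank in range(4):
    sampler = DistributedSampler(
        dataset, num_replicas=4, rank=rank, shuffle=False, drop_last=True
    )
    print(rank, list(sampler))
# drop_last=True trims to a length divisible by num_replicas (here 8),
# so rank r yields indices r and r+4: disjoint, evenly sized shards.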
@@ -101,7 +114,9 @@ def inference(config):
             sequence_lengths.append(len(item) - 1)
 
         sequence_lengths = torch.tensor(sequence_lengths)
-        pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
+        pooled_logits = embeddings[
+            torch.arange(batch_size, device=embeddings.device), sequence_lengths
+        ]
 
         train_outputs["embeddings"].append(pooled_logits)
         train_outputs["index"].extend(batch["index"].to(model.device))
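Note: `pooled_logits` selects, for each sequence, the hidden state at its last real (pre-padding) token. A toy sketch of that advanced-indexing step, with shapes assumed for illustration:

import torch

batch_size, seq_len, hidden = 3, 5, 4
embeddings = torch.randn(batch_size, seq_len, hidden)  # e.g. a hidden-state layer
sequence_lengths = torch.tensor([4, 2, 3])             # index of each last real token

# Row i of the result is embeddings[i, sequence_lengths[i], :].
pooled = embeddings[torch.arange(batch_size), sequence_lengths]
print(pooled.shape)  # torch.Size([3, 4]), one vector per sequence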
@@ -120,22 +135,33 @@ def inference(config):
     # compute mask in pyarrow since it's super fast
     # ty @bmschmidt for showing me this!
     table = train_dataset.data
-    mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+    mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
     filtered_table = table.filter(mask)
     # convert from pyarrow to Dataset
     filtered_train = Dataset.from_dict(filtered_table.to_pydict())
 
     filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
     filtered_train = filtered_train.add_column("loss", df_train["loss"])
-    filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
+    filtered_train = filtered_train.add_column(
+        "is_train", [True] * len(filtered_train)
+    )
 
-    filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
+    filtered_train.to_json(
+        f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl",
+        lines=True,
+        orient="records",
+        num_proc=64,
+    )
 
     val_outputs = {"loss": [], "embeddings": [], "index": []}
     for batch in tqdm(val_dataloader, disable=local_rank != 0):
         batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
         batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-        outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
+        outputs = model(
+            input_ids=batch["input_ids"],
+            labels=batch["labels"],
+            output_hidden_states=True,
+        )
         loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
         val_outputs["loss"].extend(loss)
 
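Note: the `pc.is_in` mask above filters rows inside pyarrow rather than in a Python loop, which is the speedup the comment credits to @bmschmidt. A self-contained sketch with a toy table; the column names here mirror the script's but the values are made up:

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "index": pa.array([0, 1, 2, 3], pa.int32()),
    "text": ["a", "b", "c", "d"],
})
curr_idx = [1, 3]  # indices this rank actually processed
mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
print(table.filter(mask).to_pydict())  # {'index': [1, 3], 'text': ['b', 'd']}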
@@ -158,7 +184,9 @@ def inference(config):
             sequence_lengths.append(len(item) - 1)
 
         sequence_lengths = torch.tensor(sequence_lengths)
-        pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
+        pooled_logits = embeddings[
+            torch.arange(batch_size, device=embeddings.device), sequence_lengths
+        ]
 
         val_outputs["embeddings"].append(pooled_logits)
         val_outputs["index"].extend(batch["index"].to(model.device))
@@ -176,7 +204,7 @@ def inference(config):
     # compute mask in pyarrow since it's super fast
     # ty @bmschmidt for showing me this!
     table = val_dataset.data
-    mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+    mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
     filtered_table = table.filter(mask)
     # convert from pyarrow to Dataset
     filtered_val = Dataset.from_dict(filtered_table.to_pydict())
@@ -184,7 +212,12 @@ def inference(config):
     filtered_val = filtered_val.add_column("loss", df_val["loss"])
     filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
 
-    filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
+    filtered_val.to_json(
+        f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl",
+        lines=True,
+        orient="records",
+        num_proc=64,
+    )
 
 
 def main():
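Note: each rank writes its own `..._shard_{local_rank}.jsonl`, so a full split is the union of the per-rank shards. One plausible way downstream code could reassemble them; this is an assumed usage, not part of this commit:

from datasets import load_dataset

# Glob over the per-rank shard files written by inference().
combined = load_dataset(
    "json",
    data_files="inference/epoch_2_embeddings_train_shard_*.jsonl",
    split="train",
)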
@@ -201,4 +234,3 @@ def main():
 
 if __name__ == "__main__":
-    # parse arguments by reading in a config
     main()
gpt4all/utils/distributed_utils.py (new file, +6)
@@ -0,0 +1,6 @@
+import torch.distributed as dist
+
+
+def rank0_print(msg):
+    if dist.get_rank() == 0:
+        print(msg)
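Usage note: `rank0_print` assumes the default process group is already initialized; calling it earlier makes `dist.get_rank()` raise. A hedged sketch of typical use, where the launcher and backend are assumptions rather than anything this commit specifies:

import torch.distributed as dist
from gpt4all.utils.distributed_utils import rank0_print

# Assumes launch via e.g. torchrun, which sets RANK/WORLD_SIZE env vars.
dist.init_process_group(backend="nccl")
rank0_print(f"World size: {dist.get_world_size()}")  # prints once, on rank 0
dist.destroy_process_group()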