diff --git a/gpt4all/inference/inference.py b/gpt4all/inference/inference.py
index 5e351c46..2096819c 100644
--- a/gpt4all/inference/inference.py
+++ b/gpt4all/inference/inference.py
@@ -3,12 +3,13 @@ import torch
 import torch.nn as nn
 from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed
 from gpt4all.utils.data import load_data_for_inference
+from gpt4all.utils.distributed_utils import rank0_print
 from tqdm import tqdm
-from datasets import Dataset
+from datasets import Dataset
 import torch.distributed as dist
-from transformers.trainer_pt_utils import nested_numpify
+from transformers.trainer_pt_utils import nested_numpify
 from transformers import DefaultDataCollator
 from torch.utils.data import DataLoader, DistributedSampler
 import numpy as np
@@ -21,56 +22,64 @@ def calc_cross_entropy_no_reduction(lm_logits, labels):
     shift_logits = lm_logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
     # Flatten the tokens
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
+    loss_fct = nn.CrossEntropyLoss(reduction="none")
     loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
 
     return loss
 
 
-def rank0_print(msg):
-    if dist.get_rank() == 0:
-        print(msg)
-
-
 def inference(config):
-    set_seed(config['seed'])
+    set_seed(config["seed"])
 
     rank0_print(f"World size: {dist.get_world_size()}")
 
-    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
+    tokenizer = AutoTokenizer.from_pretrained(
+        config["tokenizer_name"], model_max_length=config["max_length"]
+    )
     # llama has no pad token, set it to new token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-
-    train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
+    train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
 
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
+    train_sampler = DistributedSampler(
+        train_dataset,
+        shuffle=False,
+        drop_last=True,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
         sampler=train_sampler,
-        drop_last=True
+        drop_last=True,
     )
 
-    val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
+    val_sampler = DistributedSampler(
+        val_dataset,
+        shuffle=False,
+        drop_last=True,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
         sampler=val_sampler,
-        drop_last=True
+        drop_last=True,
     )
 
-
-    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
-                                                 trust_remote_code=True,
-                                                 torch_dtype=torch.bfloat16,
-                                                 )
+    model = AutoModelForCausalLM.from_pretrained(
+        config["model_name"],
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
 
     model.to(f"cuda:{local_rank}")
     with torch.no_grad():
@@ -78,14 +87,18 @@ def inference(config):
         for batch in tqdm(train_dataloader, disable=local_rank != 0):
             batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
             batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
+            outputs = model(
+                input_ids=batch["input_ids"],
+                labels=batch["labels"],
+                output_hidden_states=True,
+            )
             loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
             train_outputs["loss"].extend(loss)
 
             embeddings = outputs.hidden_states[-1]
             batch_size = batch["input_ids"].shape[0]
             sequence_lengths = []
-            # since we use mutiturn with multiple <|endoftext|>, we need to find the place where
+            # since we use multiturn with multiple <|endoftext|>, we need to find the place where
             # <|endoftext|> is repeated
             for item in batch["input_ids"]:
                 indices = torch.where(item == tokenizer.pad_token_id)[0]
@@ -101,7 +114,9 @@ def inference(config):
                     sequence_lengths.append(len(item) - 1)
 
             sequence_lengths = torch.tensor(sequence_lengths)
-            pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
+            pooled_logits = embeddings[
+                torch.arange(batch_size, device=embeddings.device), sequence_lengths
+            ]
 
             train_outputs["embeddings"].append(pooled_logits)
             train_outputs["index"].extend(batch["index"].to(model.device))
@@ -120,29 +135,40 @@ def inference(config):
         # compute mask in pyarrow since it's super fast
         # ty @bmschmidt for showing me this!
        table = train_dataset.data
-        mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+        mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
        filtered_table = table.filter(mask)
        # convert from pyarrow to Dataset
        filtered_train = Dataset.from_dict(filtered_table.to_pydict())
        filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
        filtered_train = filtered_train.add_column("loss", df_train["loss"])
-        filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
+        filtered_train = filtered_train.add_column(
+            "is_train", [True] * len(filtered_train)
+        )
 
-        filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
+        filtered_train.to_json(
+            f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl",
+            lines=True,
+            orient="records",
+            num_proc=64,
+        )
 
 
         val_outputs = {"loss": [], "embeddings": [], "index": []}
         for batch in tqdm(val_dataloader, disable=local_rank != 0):
             batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
             batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
+            outputs = model(
+                input_ids=batch["input_ids"],
+                labels=batch["labels"],
+                output_hidden_states=True,
+            )
             loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
             val_outputs["loss"].extend(loss)
 
             embeddings = outputs.hidden_states[-1]
             batch_size = batch["input_ids"].shape[0]
             sequence_lengths = []
-            # since we use mutiturn with multiple <|endoftext|>, we need to find the place where
+            # since we use multiturn with multiple <|endoftext|>, we need to find the place where
             # <|endoftext|> is repeated
             for item in batch["input_ids"]:
                 indices = torch.where(item == tokenizer.pad_token_id)[0]
@@ -158,7 +184,9 @@ def inference(config):
                     sequence_lengths.append(len(item) - 1)
 
             sequence_lengths = torch.tensor(sequence_lengths)
-            pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
+            pooled_logits = embeddings[
+                torch.arange(batch_size, device=embeddings.device), sequence_lengths
+            ]
 
             val_outputs["embeddings"].append(pooled_logits)
             val_outputs["index"].extend(batch["index"].to(model.device))
@@ -176,7 +204,7 @@ def inference(config):
         # compute mask in pyarrow since it's super fast
         # ty @bmschmidt for showing me this!
         table = val_dataset.data
-        mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+        mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
         filtered_table = table.filter(mask)
         # convert from pyarrow to Dataset
         filtered_val = Dataset.from_dict(filtered_table.to_pydict())
@@ -184,8 +212,13 @@ def inference(config):
         filtered_val = filtered_val.add_column("loss", df_val["loss"])
         filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
 
-        filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
-
+        filtered_val.to_json(
+            f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl",
+            lines=True,
+            orient="records",
+            num_proc=64,
+        )
+
 
 def main():
     dist.init_process_group("nccl")
@@ -201,4 +234,3 @@ def main():
 if __name__ == "__main__":
     # parse arguments by reading in a config
     main()
-
diff --git a/gpt4all/utils/distributed_utils.py b/gpt4all/utils/distributed_utils.py
new file mode 100644
index 00000000..839a7a92
--- /dev/null
+++ b/gpt4all/utils/distributed_utils.py
@@ -0,0 +1,6 @@
+import torch.distributed as dist
+
+
+def rank0_print(msg):
+    if dist.get_rank() == 0:
+        print(msg)
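
For context on the pooling this patch reformats: pooled_logits = embeddings[torch.arange(batch_size, ...), sequence_lengths] gathers one last-layer hidden state per sequence, at the position recorded in sequence_lengths. The sketch below is illustrative and not part of the diff: the tensor shapes, pad_token_id value, and the simplified rule of "last token before the first pad" are assumptions, whereas inference.py derives sequence_lengths from where the <|endoftext|> padding repeats in its multi-turn data.

# Minimal, self-contained sketch of the pooled_logits gather (assumptions noted above).
import torch

batch_size, seq_len, hidden_dim = 2, 6, 4
embeddings = torch.randn(batch_size, seq_len, hidden_dim)  # stand-in for hidden_states[-1]
pad_token_id = 0  # hypothetical; the real script uses tokenizer.pad_token_id
input_ids = torch.tensor(
    [
        [5, 7, 9, 0, 0, 0],  # padded sequence: last real token at index 2
        [3, 4, 6, 8, 2, 1],  # full-length sequence: last token at index 5
    ]
)

sequence_lengths = []
for item in input_ids:
    indices = torch.where(item == pad_token_id)[0]
    if len(indices) > 0:
        # position just before the first pad token (simplified single-turn rule)
        sequence_lengths.append(indices[0].item() - 1)
    else:
        # no padding: use the final position
        sequence_lengths.append(len(item) - 1)

sequence_lengths = torch.tensor(sequence_lengths)
# advanced indexing: one (hidden_dim,) vector per row, taken at that row's sequence length
pooled_logits = embeddings[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 4])

The same gather appears twice in the diff, once for the train loop and once for the validation loop, before the per-shard embeddings are written out as JSONL.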