diff --git a/inference.py b/inference.py
index 1fb620ab..0fa73025 100644
--- a/inference.py
+++ b/inference.py
@@ -6,11 +6,11 @@ from read import read_config
 from accelerate.utils import set_seed
 from data import load_data_for_inference
 from tqdm import tqdm
-from datasets import concatenate_datasets, Dataset
+from datasets import Dataset
 import torch.distributed as dist
-from transformers.trainer_pt_utils import ShardSampler, distributed_concat, nested_numpify
+from transformers.trainer_pt_utils import nested_numpify
 from transformers import DefaultDataCollator
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 import numpy as np
@@ -46,7 +46,7 @@ def inference(config):
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    train_sampler = ShardSampler(train_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
+    train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
@@ -55,7 +55,7 @@ def inference(config):
         drop_last=True
     )
 
-    val_sampler = ShardSampler(val_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
@@ -69,7 +69,6 @@ def inference(config):
         trust_remote_code=True,
         torch_dtype=torch.bfloat16,
     )
-    model.to(f"cuda:{local_rank}")
 
     with torch.no_grad():
@@ -107,17 +106,23 @@ def inference(config):
 
             torch.cuda.empty_cache()
 
-        dist.barrier()
-        gathered_train = nested_numpify(distributed_concat(train_outputs))
-        gathered_train["index"] = np.concatenate(gathered_train["index"])
-        gathered_train["loss"] = np.concatenate(gathered_train["loss"])
-        gathered_train["embeddings"] = np.concatenate(gathered_train["embeddings"])
+        train_outputs = nested_numpify(train_outputs)
+        # stack since they're 0-dim arrays
+        train_outputs["index"] = np.stack(train_outputs["index"])
+        train_outputs["loss"] = np.stack(train_outputs["loss"])
+        train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
 
-        df_train = Dataset.from_dict(gathered_train)
+        df_train = Dataset.from_dict(train_outputs)
         df_train = df_train.sort("index")
 
-        train_dataset = train_dataset.add_column("embeddings", df_train["embeddings"])
-        train_dataset = train_dataset.add_column("loss", df_train["loss"])
-        train_dataset = train_dataset.add_column("is_train", [True] * len(train_dataset))
+        curr_idx = df_train["index"]
+
+        filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
+
+        filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
+        filtered_train = filtered_train.add_column("loss", df_train["loss"])
+        filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
+
+        filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
 
         val_outputs = {"loss": [], "embeddings": [], "index": []}
         for batch in tqdm(val_dataloader, disable=local_rank != 0):
@@ -153,25 +158,24 @@ def inference(config):
 
             torch.cuda.empty_cache()
 
-        dist.barrier()
-        gathered_val = nested_numpify(distributed_concat(val_outputs))
+        val_outputs = nested_numpify(val_outputs)
+        val_outputs["index"] = np.stack(val_outputs["index"])
+        val_outputs["loss"] = np.stack(val_outputs["loss"])
+        val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
 
-        gathered_val["index"] = np.concatenate(gathered_val["index"])
-        gathered_val["loss"] = np.concatenate(gathered_val["loss"])
-        gathered_val["embeddings"] = np.concatenate(gathered_val["embeddings"])
-
-        df_val = Dataset.from_dict(gathered_val)
+        df_val = Dataset.from_dict(val_outputs)
         df_val = df_val.sort("index")
+        curr_idx = df_val["index"]
 
-        val_dataset = val_dataset.add_column("embeddings", df_val["embeddings"])
-        val_dataset = val_dataset.add_column("loss", df_val["loss"])
-        val_dataset = val_dataset.add_column("is_train", [False] * len(val_dataset))
+        filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
 
-        df = concatenate_datasets([train_dataset, val_dataset])
-        if local_rank == 0:
-            df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
+        filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
+        filtered_val = filtered_val.add_column("loss", df_val["loss"])
+        filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
 
+        filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
+
 
 def main():
     dist.init_process_group("nccl")
     parser = ArgumentParser()
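
Note: after this change each rank writes its own JSONL shard under `inference/` instead of rank 0 writing a single gathered checkpoint, so downstream consumers have to merge the shards themselves. A minimal sketch of how that could look is below; it is not part of the diff, and the glob patterns simply mirror the filenames used in the f-strings above.

```python
# merge_shards.py -- illustrative only; assumes the per-rank JSONL shards
# written by inference.py above exist under inference/.
from glob import glob

from datasets import concatenate_datasets, load_dataset

# Collect every per-rank shard for the train and val splits.
train_files = sorted(glob("inference/epoch_2_embeddings_train_shard_*.jsonl"))
val_files = sorted(glob("inference/epoch_2_embeddings_val_shard_*.jsonl"))

# Each shard is newline-delimited JSON, so the "json" builder reads it directly.
train = concatenate_datasets(
    [load_dataset("json", data_files=f, split="train") for f in train_files]
)
val = concatenate_datasets(
    [load_dataset("json", data_files=f, split="train") for f in val_files]
)

# Restore the original example order within each split via the "index" column
# that the inference script keeps on every row.
train = train.sort("index")
val = val.sort("index")
```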