import argparse
import glob
import os
import time
from itertools import islice
from multiprocessing import cpu_count

import numpy as np
import scann
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
from ldm.util import instantiate_from_config, parallel_data_prefetch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm, trange

# Database names accepted by ``Searcher`` and by the ``--database`` CLI flag:
# the OpenImages pool followed by one database per ArtBench style (order matters
# only for display, e.g. in argparse's ``choices`` listing).
DATABASES = ["openimages"] + [
    f"artbench-{style}"
    for style in (
        "art_nouveau",
        "baroque",
        "expressionism",
        "impressionism",
        "post_impressionism",
        "realism",
        "romanticism",
        "renaissance",
        "surrealism",
        "ukiyo_e",
    )
]


def chunk(it, size):
    """Split an iterable into consecutive tuples of up to ``size`` items.

    The last tuple may be shorter than ``size``. The input is turned into an
    iterator immediately, while the tuples themselves are produced lazily.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                break
            yield batch

    return _batches()


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load its weights from a checkpoint.

    Args:
        config: config object whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint whose loaded object has a ``"state_dict"``
            entry and, optionally, a ``"global_step"`` entry.
        verbose: when True, print the state-dict keys that were missing from,
            or unexpected by, the model.

    Returns:
        The model in eval mode, moved to CUDA when a GPU is available and
        left on CPU otherwise.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # torch.load relies on pickle: only load checkpoints from trusted sources.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # Only move to the GPU when one exists, so CPU-only hosts do not crash here.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


class Searcher(object):
    """Nearest-neighbour search over a precomputed embedding database.

    Embeddings are read from ``.npz`` file(s) under
    ``data/rdm/retrieval_databases/<database>``. A serialized ScaNN searcher is
    read from ``data/rdm/searchers/<database>`` when that directory exists.
    Otherwise ``search`` fits a brute-force searcher on the fly, but only for
    databases with fewer than 20,000 entries.
    """

    def __init__(self, database, retriever_version="ViT-L/14"):
        """Load the database, retriever model and searcher for ``database``.

        Args:
            database: one of the names in ``DATABASES``.
            retriever_version: model name passed to ``FrozenClipImageEmbedder``.
        """
        assert database in DATABASES
        self.database_name = database
        self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
        self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
        # NOTE(review): the retriever is loaded but not used by any method of
        # this class; confirm whether external callers rely on it.
        self.retriever = self.load_retriever(version=retriever_version)
        # The keys of this placeholder decide which arrays the multi-file
        # loader collects.
        self.database = {"embedding": [], "img_id": [], "patch_coords": []}
        # Explicit "no searcher yet" state checked by ``search``.
        self.searcher = None
        self.load_database()
        self.load_searcher()

    def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
        """Build a brute-force ScaNN searcher over the normalized embeddings.

        Every database embedding is L2-normalized first. With the default
        ``dot_product`` metric, scores are therefore cosine similarities
        (queries are normalized in ``search`` too).

        Args:
            k: number of neighbours passed to the ScaNN builder.
            metric: ScaNN distance measure name.
            searcher_savedir: if given, the built searcher is serialized there.
        """
        print("Start training searcher")
        searcher = scann.scann_ops_pybind.builder(
            self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
        )
        self.searcher = searcher.score_brute_force().build()
        print("Finish training searcher")

        if searcher_savedir is not None:
            print(f'Save trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        """Replace ``self.database`` with every array stored in one ``.npz`` file."""
        # The context manager closes the underlying file once the arrays have
        # been read; the NpzFile would otherwise keep the handle open.
        with np.load(saved_embeddings) as compressed:
            self.database = {key: compressed[key] for key in compressed.files}
        print("Finished loading of clip embeddings.")

    def load_multi_files(self, data_archive):
        """Collect, per database key, the arrays of each loaded ``.npz`` archive.

        Returns:
            A dict mapping each key of ``self.database`` to a list of arrays,
            one per archive in ``data_archive``.
        """
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                out_data[key].append(d[key])

        return out_data

    def load_database(self):
        """Load the embedding database from the ``.npz`` file(s) in ``self.database_path``.

        Raises:
            ValueError: if the directory contains no ``.npz`` files.
        """
        print(f'Load saved patch embedding from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, "*.npz"))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(
                self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
            )

            # NOTE(review): the axis=1 concatenation followed by [0] depends on
            # the structure parallel_data_prefetch returns for dict targets;
            # confirm against ldm.util.
            self.database = {
                key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
            }
        else:
            raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')

        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(
        self,
        version="ViT-L/14",
    ):
        """Return a ``FrozenClipImageEmbedder`` in eval mode, on CUDA when available."""
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        """Load the serialized ScaNN searcher, if one has been saved.

        When ``self.searcher_savedir`` does not exist, ``self.searcher`` is set
        to ``None``. ``search`` can then fit a brute-force searcher on the fly
        for small databases, instead of construction failing up front.
        """
        if not os.path.isdir(self.searcher_savedir):
            print(f"No saved searcher found under {self.searcher_savedir}; it will be fit on demand if possible.")
            self.searcher = None
            return
        print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print("Finished loading searcher.")

    def search(self, x, k):
        """Return the ``k`` nearest database entries for each query in ``x``.

        Args:
            x: numpy array or torch tensor of query embeddings. A 2-D input is
                used as-is. For a 3-D input, only index 0 along axis 1 is kept.
            k: number of neighbours to return per query.

        Returns:
            A dict with these entries:

            - ``nn_embeddings``: L2-normalized neighbour embeddings.
            - ``img_ids`` and ``patch_coords``: the neighbours' entries in
              the database.
            - ``queries``: the (possibly sliced) raw queries.
            - ``exec_time``: ScaNN search time in seconds.
            - ``nns``: neighbour indices.
            - ``q_embeddings``: the normalized queries.
        """
        if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, "Cannot search with uninitialized searcher"
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]

        start = time.time()
        nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        out_embeddings = self.database["embedding"][nns]
        out_img_ids = self.database["img_id"][nns]
        out_pc = self.database["patch_coords"][nns]

        out = {
            "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
            "img_ids": out_img_ids,
            "patch_coords": out_pc,
            "queries": x,
            "exec_time": end - start,
            "nns": nns,
            "q_embeddings": query_embeddings,
        }

        return out

    def __call__(self, x, n):
        """Shorthand for ``self.search(x, n)``."""
        return self.search(x, n)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )

    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
    )

    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )

    # NOTE(review): --n_repeat is parsed but not used anywhere in this script.
    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )

    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )

    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )

    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )

    parser.add_argument(
        "--config",
        type=str,
        default="configs/retrieval-augmented-diffusion/768x768.yaml",
        help="path to config which constructs model",
    )

    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/rdm/rdm768x768/model.ckpt",
        help="path to checkpoint of model",
    )

    parser.add_argument(
        "--clip_type",
        type=str,
        default="ViT-L/14",
        help="which CLIP model to use for retrieval and NN encoding",
    )
    parser.add_argument(
        "--database",
        type=str,
        default="artbench-surrealism",
        choices=DATABASES,
        help="The database used for the search, only applied when --use_neighbors=True",
    )
    parser.add_argument(
        "--use_neighbors",
        default=False,
        action="store_true",
        help="Include neighbors in addition to text prompt for conditioning",
    )
    parser.add_argument(
        "--knn",
        default=10,
        type=int,
        help="The number of included neighbors, only applied when --use_neighbors=True",
    )

    opt = parser.parse_args()

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        # "--prompt" given with no value yields None (nargs="?" without const).
        if prompt is None:
            parser.error("--prompt was given without a value")
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r", encoding="utf-8") as f:
            # One prompt per line, grouped into batches of n_samples (the last
            # batch may be smaller).
            data = list(chunk(f.read().splitlines(), batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    # Subtract one for the "samples" subdirectory itself.
    grid_count = len(os.listdir(outpath)) - 1

    print(f"sampling scale for cfg is {opt.scale:.2f}")

    searcher = None
    if opt.use_neighbors:
        # Use the same CLIP variant for retrieval as for text encoding, as the
        # --clip_type help text promises.
        searcher = Searcher(opt.database, retriever_version=opt.clip_type)

    with torch.no_grad():
        with model.ema_scope():
            for _ in trange(opt.n_iter, desc="Sampling"):
                all_samples = list()
                for prompts in tqdm(data, desc="data"):
                    print("sampling prompts:", prompts)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = clip_text_encoder.encode(prompts)
                    uc = None
                    if searcher is not None:
                        nn_dict = searcher(c, opt.knn)
                        # Match the text embedding's device and dtype instead of
                        # hard-coding .cuda(). This keeps CPU runs working and
                        # casts database embeddings stored in another float dtype.
                        nn_embeddings = torch.from_numpy(nn_dict["nn_embeddings"]).to(device=c.device, dtype=c.dtype)
                        # Assumes c is (batch, tokens, dim) so neighbours append
                        # along the token axis -- TODO confirm against the encoder.
                        c = torch.cat([c, nn_embeddings], dim=1)
                    if opt.scale != 1.0:
                        uc = torch.zeros_like(c)
                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for f16 model
                    samples_ddim, _ = sampler.sample(
                        S=opt.ddim_steps,
                        conditioning=c,
                        batch_size=c.shape[0],
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=opt.scale,
                        unconditional_conditioning=uc,
                        eta=opt.ddim_eta,
                    )

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png")
                        )
                        base_count += 1
                    all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid. torch.cat (rather than stack)
                    # accepts a smaller final batch from --from-file.
                    grid = torch.cat(all_samples, dim=0)
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")