[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,22 +1,22 @@
import argparse, os, sys, glob
import clip
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import scann
import argparse
import glob
import os
import time
from itertools import islice
from multiprocessing import cpu_count
from ldm.util import instantiate_from_config, parallel_data_prefetch
import numpy as np
import scann
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
from ldm.util import instantiate_from_config, parallel_data_prefetch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm, trange
DATABASES = [
"openimages",
@@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False):
class Searcher(object):
def __init__(self, database, retriever_version='ViT-L/14'):
def __init__(self, database, retriever_version="ViT-L/14"):
assert database in DATABASES
# self.database = self.load_database(database)
self.database_name = database
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
self.retriever = self.load_retriever(version=retriever_version)
self.database = {'embedding': [],
'img_id': [],
'patch_coords': []}
self.database = {"embedding": [], "img_id": [], "patch_coords": []}
self.load_database()
self.load_searcher()
def train_searcher(self, k,
metric='dot_product',
searcher_savedir=None):
print('Start training searcher')
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
k, metric)
def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
print("Start training searcher")
searcher = scann.scann_ops_pybind.builder(
self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
)
self.searcher = searcher.score_brute_force().build()
print('Finish training searcher')
print("Finish training searcher")
if searcher_savedir is not None:
print(f'Save trained searcher under "{searcher_savedir}"')
@@ -91,36 +86,40 @@ class Searcher(object):
def load_single_file(self, saved_embeddings):
compressed = np.load(saved_embeddings)
self.database = {key: compressed[key] for key in compressed.files}
print('Finished loading of clip embeddings.')
print("Finished loading of clip embeddings.")
def load_multi_files(self, data_archive):
out_data = {key: [] for key in self.database}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
out_data[key].append(d[key])
return out_data
def load_database(self):
print(f'Load saved patch embedding from "{self.database_path}"')
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
file_content = glob.glob(os.path.join(self.database_path, "*.npz"))
if len(file_content) == 1:
self.load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
prefetched_data = parallel_data_prefetch(
self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
self.database}
self.database = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
}
else:
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
def load_retriever(self, version='ViT-L/14', ):
def load_retriever(
self,
version="ViT-L/14",
):
model = FrozenClipImageEmbedder(model=version)
if torch.cuda.is_available():
model.cuda()
@@ -128,14 +127,14 @@ class Searcher(object):
return model
def load_searcher(self):
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
print('Finished loading searcher.')
print("Finished loading searcher.")
def search(self, x, k):
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, "Cannot search with uninitialized searcher"
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if len(x.shape) == 3:
@@ -146,17 +145,19 @@ class Searcher(object):
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
end = time.time()
out_embeddings = self.database['embedding'][nns]
out_img_ids = self.database['img_id'][nns]
out_pc = self.database['patch_coords'][nns]
out_embeddings = self.database["embedding"][nns]
out_img_ids = self.database["img_id"][nns]
out_pc = self.database["patch_coords"][nns]
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
'img_ids': out_img_ids,
'patch_coords': out_pc,
'queries': x,
'exec_time': end - start,
'nns': nns,
'q_embeddings': query_embeddings}
out = {
"nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
"img_ids": out_img_ids,
"patch_coords": out_pc,
"queries": x,
"exec_time": end - start,
"nns": nns,
"q_embeddings": query_embeddings,
}
return out
@@ -173,20 +174,16 @@ if __name__ == "__main__":
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
@@ -206,7 +203,7 @@ if __name__ == "__main__":
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
@@ -287,14 +284,14 @@ if __name__ == "__main__":
parser.add_argument(
"--database",
type=str,
default='artbench-surrealism',
default="artbench-surrealism",
choices=DATABASES,
help="The database used for the search, only applied when --use_neighbors=True",
)
parser.add_argument(
"--use_neighbors",
default=False,
action='store_true',
action="store_true",
help="Include neighbors in addition to text prompt for conditioning",
)
parser.add_argument(
@@ -358,41 +355,43 @@ if __name__ == "__main__":
uc = None
if searcher is not None:
nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1)
if opt.scale != 1.0:
uc = torch.zeros_like(c)
if isinstance(prompts, tuple):
prompts = list(prompts)
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")