| | |
| | |
| |
|
| | import copy |
| | import math |
| | from collections import namedtuple |
| | from contextlib import contextmanager, nullcontext |
| | from functools import partial, wraps |
| | from pathlib import Path |
| | from random import random |
| |
|
| | from einops import rearrange, repeat, reduce, pack, unpack |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision.transforms as T |
| | from torch import einsum, nn |
| | from beartype.typing import List, Union |
| | from beartype import beartype |
| | from tqdm.auto import tqdm |
| | from pdb import set_trace as st |
| |
|
| | |
| | |
| |
|
| |
|
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    is_missing = val is None
    return not is_missing
| |
|
| |
|
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; any extra arguments are accepted and ignored."""
    del args, kwargs  # accepted but intentionally unused
    return t
| |
|
| |
|
def divisible_by(numer, denom):
    """Return whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0
| |
|
| |
|
def first(arr, d=None):
    """Return the first element of ``arr``, or ``d`` when ``arr`` is empty."""
    return arr[0] if len(arr) != 0 else d
| |
|
| |
|
def maybe(fn):
    """Wrap ``fn`` so that a ``None`` argument is passed through untouched.

    The returned callable takes a single argument ``x``. When ``x`` is
    ``None`` it is returned as-is without calling ``fn``; otherwise the
    result of ``fn(x)`` is returned.
    """
    @wraps(fn)
    def none_passthrough(x):
        return fn(x) if exists(x) else x

    return none_passthrough
| |
|
| |
|
def once(fn):
    """Wrap ``fn`` so that only the first call is forwarded to it.

    The first call invokes ``fn`` with whatever positional and keyword
    arguments were supplied and returns its result. Every later call is a
    no-op that returns ``None``.

    The wrapper forwards ``*args`` and ``**kwargs`` (rather than a single
    positional argument) so that wrapped callables such as ``print`` can be
    invoked with several values or with keyword options.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        # Flip the flag before calling so a raising ``fn`` is not retried.
        called = True
        return fn(*args, **kwargs)

    return inner
| |
|
| |
|
# ``print`` wrapped by ``once``: only the first call prints; later calls
# do nothing and return None.
print_once = once(print)
| |
|
| |
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    When the fallback ``d`` is callable it is invoked with no arguments and
    its result is returned; otherwise ``d`` itself is returned.
    """
    if exists(val):
        return val
    if callable(d):
        return d()
    return d
| |
|
| |
|
def compact(input_dict):
    """Return a new dict holding only the entries whose value is not ``None``."""
    kept = {}
    for key, value in input_dict.items():
        if exists(value):
            kept[key] = value
    return kept
| |
|
| |
|
def maybe_transform_dict_key(input_dict, key, fn):
    """Apply ``fn`` to ``input_dict[key]`` on a shallow copy, if the key exists.

    When ``key`` is absent the original dict object is returned unchanged.
    Otherwise a ``.copy()`` of the dict is returned with the value at ``key``
    replaced by ``fn(value)``; the input dict itself is not modified.
    """
    if key in input_dict:
        updated = input_dict.copy()
        updated[key] = fn(updated[key])
        return updated
    return input_dict
| |
|
| |
|
def cast_uint8_images_to_float(images):
    """Divide ``uint8`` tensors by 255; return any other dtype untouched.

    Only the dtype is inspected: a tensor whose dtype is ``torch.uint8`` is
    returned as ``images / 255``, anything else is returned as the same
    object.
    """
    if images.dtype == torch.uint8:
        return images / 255
    return images
| |
|
| |
|
def module_device(module):
    """Return the device of the first parameter yielded by ``module.parameters()``.

    Raises ``StopIteration`` when the module yields no parameters.
    """
    first_param = next(module.parameters())
    return first_param.device
| |
|
| |
|
def zero_init_(m):
    """Zero ``m.weight`` in place, and ``m.bias`` too when it is not ``None``."""
    nn.init.zeros_(m.weight)
    bias = m.bias
    if exists(bias):
        nn.init.zeros_(bias)
| |
|
| |
|
def eval_decorator(fn):
    """Decorate ``fn(model, ...)`` so it runs with ``model`` in eval mode.

    The wrapper records ``model.training``, calls ``model.eval()``, invokes
    ``fn`` and then calls ``model.train(<recorded flag>)`` to restore the
    previous mode.

    The restore happens in a ``finally`` clause so the model is put back in
    its original mode even when ``fn`` raises; the exception still
    propagates to the caller. ``functools.wraps`` keeps the wrapped
    function's name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner
| |
|
| |
|
def pad_tuple_to_length(t, length, fillvalue=None):
    """Right-pad ``t`` with ``fillvalue`` until it has ``length`` elements.

    When ``t`` already has ``length`` or more elements it is returned as-is
    (no truncation); otherwise a new tuple is returned.
    """
    missing = length - len(t)
    if missing > 0:
        filler = (fillvalue, ) * missing
        return (*t, *filler)
    return t
| |
|
| |
|
| | |
| |
|
| |
|
class Identity(nn.Module):
    """Module whose forward pass returns its first input unchanged.

    Constructor and forward both accept arbitrary extra positional and
    keyword arguments and ignore them.
    """
    def __init__(self, *args, **kwargs):
        del args, kwargs  # accepted but intentionally unused
        super().__init__()

    def forward(self, x, *args, **kwargs):
        del args, kwargs  # accepted but intentionally unused
        return x
| |
|
| |
|
| | |
| |
|
| |
|
def log(t, eps: float = 1e-12):
    """Natural log of ``t`` after clamping its values to at least ``eps``."""
    clamped = t.clamp(min=eps)
    return clamped.log()
| |
|
| |
|
def l2norm(t):
    """Normalize ``t`` along its last dimension with ``F.normalize`` (L2 norm)."""
    return F.normalize(t, p=2.0, dim=-1)
| |
|
| |
|
def right_pad_dims_to(x, t):
    """Append trailing size-1 dims to ``t`` until it has as many dims as ``x``.

    When ``t`` already has at least ``x.ndim`` dimensions it is returned
    unchanged; otherwise a view of ``t`` with the extra trailing singleton
    dimensions is returned.
    """
    missing = x.ndim - t.ndim
    if missing > 0:
        trailing_ones = (1, ) * missing
        return t.view(*t.shape, *trailing_ones)
    return t
| |
|
| |
|
def masked_mean(t, *, dim, mask=None):
    """Mean of ``t`` over ``dim``, counting only positions where ``mask`` is set.

    Without a mask this is simply ``t.mean(dim=dim)``. With a mask, positions
    where the mask is false are zeroed before summing, and the sum is divided
    by the number of true mask entries (clamped to at least ``1e-5`` to avoid
    division by zero).

    The mask must be 2-D (the ``'b n -> b n 1'`` pattern requires it) and
    boolean (it is inverted with ``~``).
    # NOTE(review): presumably ``t`` is (batch, seq, dim) and ``dim=1`` -- confirm against callers.
    """
    if not exists(mask):
        return t.mean(dim=dim)

    valid_count = mask.sum(dim=dim, keepdim=True)
    expanded_mask = rearrange(mask, 'b n -> b n 1')
    summed = t.masked_fill(~expanded_mask, 0.).sum(dim=dim)

    return summed / valid_count.clamp(min=1e-5)
| |
|
| |
|
def resize_image_to(image,
                    target_image_size,
                    clamp_range=None,
                    mode='nearest'):
    """Resize ``image`` with ``F.interpolate`` unless it already has the target size.

    Only the last dimension of ``image`` is compared with
    ``target_image_size``; on a match the input is returned as-is (and is
    not clamped). Otherwise the interpolated result is returned, clamped to
    ``clamp_range`` (unpacked as the arguments of ``Tensor.clamp``) when one
    is given.
    """
    if image.shape[-1] == target_image_size:
        return image

    resized = F.interpolate(image, target_image_size, mode=mode)

    if not exists(clamp_range):
        return resized
    return resized.clamp(*clamp_range)
| |
|
| |
|
def calc_all_frame_dims(downsample_factors: List[int], frames):
    """Compute a 1-tuple ``(frames // factor,)`` for every downsample factor.

    When ``frames`` is ``None`` a tuple of empty tuples (one per factor) is
    returned instead. Otherwise the result is a list, and each factor is
    asserted to divide ``frames`` evenly.
    """
    if not exists(frames):
        return (tuple(), ) * len(downsample_factors)

    frame_dims = []
    for factor in downsample_factors:
        assert divisible_by(frames, factor)
        frame_dims.append((frames // factor, ))
    return frame_dims
| |
|
| |
|
def safe_get_tuple_index(tup, index, default=None):
    """Return ``tup[index]``, or ``default`` when ``index`` is at or past ``len(tup)``."""
    in_range = index < len(tup)
    return tup[index] if in_range else default
| |
|
| |
|
| | |
| | |
| |
|
| |
|
def normalize_neg_one_to_one(img):
    """Map values via ``img * 2 - 1`` (so 0 becomes -1 and 1 becomes 1)."""
    doubled = img * 2
    return doubled - 1
| |
|
| |
|
def unnormalize_zero_to_one(normed_img):
    """Map values via ``(normed_img + 1) * 0.5`` (so -1 becomes 0 and 1 becomes 1)."""
    shifted = normed_img + 1
    return shifted * 0.5
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| |
|
class PixelShuffleUpsample(nn.Module):
    """
    2x upsampling built from a 1x1 conv, SiLU and ``nn.PixelShuffle(2)``.

    code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    def __init__(self, dim, dim_out=None):
        super().__init__()
        out_channels = default(dim_out, dim)
        # The conv emits 4x the output channels; PixelShuffle(2) then folds
        # each group of 4 channels into a 2x2 spatial patch.
        expand_conv = nn.Conv2d(dim, out_channels * 4, 1)

        self.net = nn.Sequential(expand_conv, nn.SiLU(), nn.PixelShuffle(2))

        self.init_conv_(expand_conv)

    def init_conv_(self, conv):
        """Initialise ``conv`` in place: tied weights per shuffle group, zero bias.

        A Kaiming-uniform kernel is drawn for a quarter of the output
        channels and every one of those kernels is repeated 4 times in a
        row, so the 4 channels of each shuffle group start out identical.
        """
        out_ch, in_ch, kernel_h, kernel_w = conv.weight.shape
        base_weight = torch.empty(out_ch // 4, in_ch, kernel_h, kernel_w)
        nn.init.kaiming_uniform_(base_weight)
        tiled_weight = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
| |
|
| |
|
class ResidualBlock(nn.Module):
    """Pre-activation residual block: two 3x3 convs plus a skip connection.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        dim_inter: Channels between the two convs; defaults to ``dim_out``.
        use_norm: When true, each conv is preceded by ``norm_layer`` + ReLU
            and uses reflect padding with the given ``bias`` flag. When
            false, each conv is preceded only by ReLU and uses the
            ``nn.Conv2d`` defaults for padding mode and bias.
        norm_layer: Callable building a norm module from a channel count.
        bias: Bias flag for the convs in the ``use_norm=True`` branch.

    The skip path is the identity when ``dim_in == dim_out`` and a 1x1 conv
    otherwise.
    """
    def __init__(self,
                 dim_in,
                 dim_out,
                 dim_inter=None,
                 use_norm=True,
                 norm_layer=nn.BatchNorm2d,
                 bias=False):
        super().__init__()
        if dim_inter is None:
            dim_inter = dim_out

        if use_norm:
            self.conv = nn.Sequential(
                norm_layer(dim_in),
                nn.ReLU(True),
                nn.Conv2d(dim_in,
                          dim_inter,
                          3,
                          1,
                          1,
                          bias=bias,
                          padding_mode='reflect'),
                norm_layer(dim_inter),
                nn.ReLU(True),
                nn.Conv2d(dim_inter,
                          dim_out,
                          3,
                          1,
                          1,
                          bias=bias,
                          padding_mode='reflect'),
            )
        else:
            self.conv = nn.Sequential(
                # The first ReLU receives the block input itself, so it must
                # NOT be in-place: an in-place op here would overwrite the
                # caller's tensor and corrupt the skip connection in
                # ``forward``. Module positions are unchanged, so state_dict
                # keys stay the same.
                nn.ReLU(False),
                nn.Conv2d(dim_in, dim_inter, 3, 1, 1),
                nn.ReLU(True),
                nn.Conv2d(dim_inter, dim_out, 3, 1, 1),
            )

        self.short_cut = None
        if dim_in != dim_out:
            self.short_cut = nn.Conv2d(dim_in, dim_out, 1, 1)

    def forward(self, feats):
        """Return ``conv(feats)`` plus the (optionally projected) input."""
        feats_out = self.conv(feats)
        if self.short_cut is not None:
            feats_out = self.short_cut(feats) + feats_out
        else:
            feats_out = feats_out + feats
        return feats_out
| |
|
| |
|
class Upsample(nn.Sequential):
    """Upsample module built from conv + pixel-shuffle stages.

    A power-of-two scale ``2^n`` yields ``n`` stages of
    (3x3 conv to ``4 * num_feat`` channels, ``PixelShuffle(2)``); a scale of
    3 yields one stage of (3x3 conv to ``9 * num_feat`` channels,
    ``PixelShuffle(3)``). A scale of 1 yields an empty sequence.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """
    def __init__(self, scale, num_feat):
        m = []
        # ``scale > 0`` is required because 0 also passes the bit test below;
        # it previously surfaced as an opaque "math domain error".
        if scale > 0 and (scale & (scale - 1)) == 0:
            # Exact integer log2 (avoids float rounding from math.log).
            num_doublings = int(scale).bit_length() - 1
            for _ in range(num_doublings):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
| |
|
| |
|
class PixelUnshuffleUpsample(nn.Module):
    """Conv body, ``Upsample(sr_ratio, num_feat)`` stage and output conv.

    Args:
        output_dim: Channel count of the input feature map.
        num_feat: Channel count used inside the upsampling stage.
        num_out_ch: Channel count of the returned tensor.
        sr_ratio: Scale factor handed to ``Upsample``.

    Extra positional and keyword arguments are accepted and ignored.
    """
    def __init__(self, output_dim, num_feat=128, num_out_ch=3, sr_ratio=4, *args, **kwargs) -> None:
        super().__init__()

        self.conv_after_body = nn.Conv2d(output_dim, output_dim,
                                         kernel_size=3, stride=1, padding=1)
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(output_dim, num_feat, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.upsample = Upsample(sr_ratio, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch,
                                   kernel_size=3, stride=1, padding=1)

    def forward(self, x, input_skip_connection=True, *args, **kwargs):
        """Upsample ``x``; ``input_skip_connection`` adds ``x`` back after the first conv."""
        body = self.conv_after_body(x)
        if input_skip_connection:
            body = body + x

        feats = self.conv_before_upsample(body)
        return self.conv_last(self.upsample(feats))
| |
|
| |
|
class Conv3x3TriplaneTransformation(nn.Module):
    """Two 3x3 conv + LeakyReLU stages, the second wrapped in a residual add.

    Args:
        input_dim: Channel count of the incoming latent.
        output_dim: Channel count produced by both stages.
    """

    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()

        self.conv_after_unpachify = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
        )

        self.conv_before_rendering = nn.Sequential(
            nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, unpachified_latent):
        """Project the latent, then add the second stage's output as a residual."""
        projected = self.conv_after_unpachify(unpachified_latent)
        refined = self.conv_before_rendering(projected)
        return refined + projected
| |
|
| |
|
| | |
class NearestConvSR(nn.Module):
    """
    Super-resolution head using nearest-neighbour interpolation followed by 3x3 convs.

    The input goes through a residual 3x3 conv and a conv + LeakyReLU that
    maps ``output_dim`` to ``num_feat`` channels. It is then upsampled 2x
    (nearest) and convolved once, and a second time only when
    ``sr_ratio == 4``. Any other ``sr_ratio`` therefore results in a single
    2x step. Two final convs map ``num_feat`` to ``num_out_ch`` channels.

    Extra positional and keyword arguments are accepted and ignored.
    """
    def __init__(self, output_dim, num_feat=128, num_out_ch=3, sr_ratio=4, *args, **kwargs) -> None:
        super().__init__()

        self.upscale = sr_ratio

        def conv3x3(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

        self.conv_after_body = conv3x3(output_dim, output_dim)
        self.conv_before_upsample = nn.Sequential(
            conv3x3(output_dim, num_feat),
            nn.LeakyReLU(inplace=True),
        )
        self.conv_up1 = conv3x3(num_feat, num_feat)
        if self.upscale == 4:
            # Second 2x stage exists only for a 4x ratio.
            self.conv_up2 = conv3x3(num_feat, num_feat)
        self.conv_hr = conv3x3(num_feat, num_feat)
        self.conv_last = conv3x3(num_feat, num_out_ch)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, *args, **kwargs):
        feats = self.conv_before_upsample(self.conv_after_body(x) + x)

        up_convs = [self.conv_up1]
        if self.upscale == 4:
            up_convs.append(self.conv_up2)

        # Each stage: nearest 2x upsample -> 3x3 conv -> LeakyReLU(0.2).
        for up_conv in up_convs:
            feats = F.interpolate(feats, scale_factor=2, mode='nearest')
            feats = self.lrelu(up_conv(feats))

        return self.conv_last(self.lrelu(self.conv_hr(feats)))
| |
|
| | |
class NearestConvSR_Residual(NearestConvSR):
    """``NearestConvSR`` whose tanh-squashed output is added to an upsampled base image.

    Constructor arguments are forwarded unchanged to ``NearestConvSR``.
    """

    def __init__(self, output_dim, num_feat=128, num_out_ch=3, sr_ratio=4, *args, **kwargs) -> None:
        super().__init__(output_dim, num_feat, num_out_ch, sr_ratio, *args, **kwargs)

        self.act = nn.Tanh()

    def forward(self, x, base_x, *args, **kwargs):
        """Return ``tanh(NearestConvSR(x))`` plus ``base_x`` bilinearly upsampled.

        The upsampling factor is the integer ratio of the last dimensions of
        the two tensors (floor division).
        """
        residual = self.act(super().forward(x))
        scale = residual.shape[-1] // base_x.shape[-1]
        upsampled_base = F.interpolate(base_x,
                                       size=None,
                                       scale_factor=scale,
                                       mode='bilinear',
                                       align_corners=False)
        return residual + upsampled_base
| | |
class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
    Used in lightweight SR to save parameters.

    The conv maps ``num_feat`` channels to ``scale ** 2 * num_out_ch``
    channels and ``PixelShuffle(scale)`` rearranges them into ``num_out_ch``
    channels at ``scale`` times the spatial size.

    Args:
        scale (int): Scale factor passed to ``nn.PixelShuffle``.
        num_feat (int): Channel number of intermediate features.
        num_out_ch (int): Channel number of the output.
        input_resolution (tuple | None): ``(H, W)`` pair, read only by ``flops``.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        conv = nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)
        shuffle = nn.PixelShuffle(scale)
        super(UpsampleOneStep, self).__init__(conv, shuffle)
        self.num_feat = num_feat
        self.input_resolution = input_resolution

    def flops(self):
        """Return ``H * W * num_feat * 3 * 9`` for the stored input resolution.

        # NOTE(review): the factor 3 looks like a hard-coded output channel
        # count and ``scale`` is not included -- confirm this is intended.
        """
        height, width = self.input_resolution
        return height * width * self.num_feat * 3 * 9
| |
|
| | |
| |
|