Commit 5d67df9d authored by KOPANAS Georgios

Add python training implementation

parent 42268ec6
Showing 2027 additions and 0 deletions
*.pyc
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#define BINMASK_H
// A BitMask represents a bool array of shape (H, W, N). We pack values into
// the bits of unsigned ints; a single unsigned int has B = 32 bits, so to hold
// all values we use H * W * (N / B) = H * W * D values. We want to store
// BitMasks in shared memory, so we assume that the memory has already been
// allocated for it elsewhere.
class BitMask {
public:
  __device__ BitMask(unsigned int* data, int H, int W, int N)
      : data(data), H(H), W(W), B(8 * sizeof(unsigned int)), D(N / B) {
    // TODO: check if the data is null.
    // Note: D = N / B assumes N is a multiple of B (i.e. of 32).
    block_clear(); // clear the data
  }
// Use all threads in the current block to clear all bits of this BitMask
__device__ void block_clear() {
for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
data[i] = 0;
}
__syncthreads();
}
__device__ int _get_elem_idx(int y, int x, int d) {
return y * W * D + x * D + d / B;
}
__device__ int _get_bit_idx(int d) {
return d % B;
}
// Turn on a single bit (y, x, d)
__device__ void set(int y, int x, int d) {
int elem_idx = _get_elem_idx(y, x, d);
int bit_idx = _get_bit_idx(d);
const unsigned int mask = 1U << bit_idx;
atomicOr(data + elem_idx, mask);
}
// Turn off a single bit (y, x, d)
__device__ void unset(int y, int x, int d) {
int elem_idx = _get_elem_idx(y, x, d);
int bit_idx = _get_bit_idx(d);
const unsigned int mask = ~(1U << bit_idx);
atomicAnd(data + elem_idx, mask);
}
// Check whether the bit (y, x, d) is on or off
__device__ bool get(int y, int x, int d) {
int elem_idx = _get_elem_idx(y, x, d);
int bit_idx = _get_bit_idx(d);
return (data[elem_idx] >> bit_idx) & 1U;
}
// Compute the number of bits set in the row (y, x, :)
__device__ int count(int y, int x) {
int total = 0;
for (int i = 0; i < D; ++i) {
int elem_idx = y * W * D + x * D + i;
unsigned int elem = data[elem_idx];
total += __popc(elem);
}
return total;
}
private:
unsigned int* data;
int H, W, B, D;
};
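For reference, a minimal pure-Python sketch of the same packing scheme (added for illustration only, not part of this commit); it mirrors _get_elem_idx, _get_bit_idx, set, get and count so the indexing can be sanity-checked on the CPU.

# Pure-Python sketch of the BitMask layout above: B bits per word and
# D = ceil(N / B) words per (y, x) cell, stored as a flat list of H * W * D words.
class PyBitMask:
    def __init__(self, H, W, N, B=32):
        self.H, self.W, self.B = H, W, B
        self.D = (N + B - 1) // B
        self.data = [0] * (H * W * self.D)

    def _elem_idx(self, y, x, d):
        return y * self.W * self.D + x * self.D + d // self.B

    def set(self, y, x, d):
        self.data[self._elem_idx(y, x, d)] |= 1 << (d % self.B)

    def get(self, y, x, d):
        return (self.data[self._elem_idx(y, x, d)] >> (d % self.B)) & 1

    def count(self, y, x):
        base = y * self.W * self.D + x * self.D
        return sum(bin(w).count("1") for w in self.data[base:base + self.D])

mask = PyBitMask(H=2, W=3, N=40)
mask.set(1, 2, 37)
assert mask.get(1, 2, 37) == 1 and mask.count(1, 2) == 1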
#include <torch/extension.h>
#include "soft_depth_test.h"
#include "rasterize_points.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("soft_depth_test", &SoftDepthTest);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
__device__ inline float PixToNdc(int i, int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0f) / S;
}
__device__ inline float NdcToPix(float i, int S) {
return ((i + 1.0)*S - 1.0)/2.0;
}
// The maximum number of points per pixel that we can return. Since we use
// thread-local arrays to hold and sort points, the maximum size of the array
// needs to be known at compile time. There might be some fancy template magic
// we could use to make this more dynamic, but for now just fix a constant.
// TODO: is this value large enough? Would increasing it have performance implications?
const int32_t kMaxPointsPerPixel = 101;
const int32_t kMaxPointPerPixelLocal = 101;
template <typename T>
__device__ inline void BubbleSort(T* arr, int n) {
bool already_sorted;
// Bubble sort. We only use it for small thread-local arrays; in this regime
// we care more about warp divergence than computational complexity.
for (int i = 0; i < n - 1; ++i) {
already_sorted=true;
for (int j = 0; j < n - i - 1; ++j) {
if (arr[j + 1] < arr[j]) {
already_sorted = false;
T temp = arr[j];
arr[j] = arr[j + 1];
arr[j + 1] = temp;
}
}
if (already_sorted)
break;
}
}
__device__ inline void BubbleSort2(int32_t* arr, const float* points, int n) {
bool already_sorted;
// Bubble sort. We only use it for small thread-local arrays; in this regime
// we care more about warp divergence than computational complexity.
for (int i = 0; i < n - 1; ++i) {
already_sorted=true;
for (int j = 0; j < n - i - 1; ++j) {
float p_j0_z = points[arr[j]*3 + 2];
float p_j1_z = points[arr[j+1]*3 + 2];
if (p_j1_z < p_j0_z) {
already_sorted = false;
int32_t temp = arr[j];
arr[j] = arr[j + 1];
arr[j + 1] = temp;
}
}
if (already_sorted)
break;
}
}
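A pure-Python mirror of the helpers in this header (illustration only, not part of this commit): the pixel/NDC conversion round-trips exactly, and the depth sort reproduces what BubbleSort2 does to an index array over packed (x, y, z) points.

# Pure-Python mirrors of PixToNdc / NdcToPix and BubbleSort2 above.
def pix_to_ndc(i, S):
    return -1.0 + (2.0 * i + 1.0) / S

def ndc_to_pix(x, S):
    return ((x + 1.0) * S - 1.0) / 2.0

def sort_indices_by_depth(idxs, points_flat):
    # Same effect as BubbleSort2: order point indices by increasing z.
    return sorted(idxs, key=lambda p: points_flat[p * 3 + 2])

S = 4
centers = [pix_to_ndc(i, S) for i in range(S)]          # [-0.75, -0.25, 0.25, 0.75]
assert all(abs(ndc_to_pix(c, S) - i) < 1e-6 for i, c in enumerate(centers))

points_flat = [0, 0, 0.9,  0, 0, 0.1,  0, 0, 0.5]        # three points with z = 0.9, 0.1, 0.5
assert sort_indices_by_depth([0, 1, 2], points_flat) == [1, 2, 0]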
This diff is collapsed.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsGKCuda(
const torch::Tensor& points,
const torch::Tensor& colors,
const torch::Tensor& point_score,
const torch::Tensor& inv_cov,
const int max_radius,
const int image_height,
const int image_width,
const int points_per_pixel,
const float zfar,
const float znear,
const float gamma);
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& colors,
const torch::Tensor& inv_cov,
const int max_radius,
const torch::Tensor& idxs,
const torch::Tensor& k_idxs,
const float znear,
const float zfar,
const float gamma,
const torch::Tensor& grad_out_color);
// Args:
//  points: Tensor of shape (P, 3) giving (packed) positions for points in all
//          N pointclouds in the batch, where P is the total number of points
//          across all pointclouds. These points are expected to be in NDC
//          coordinates in the range [-1, 1].
//  colors: Tensor of shape (P, C) giving per-point features/colors.
//  inv_cov: Tensor of shape (P, 4) giving per-point inverse covariance entries.
//  max_radius: maximum point radius used during rasterization.
//  idxs, k_idxs: int32 index Tensors returned by the forward pass.
//  znear, zfar: depth clipping values.
//  gamma: blending parameter.
//  grad_out_color: upstream gradient d(loss)/d(color) of the color image
//          returned by the forward pass.
//
// Returns:
//  grad_points, grad_colors, grad_inv_cov: downstream gradients with the same
//          shapes as points, colors, and inv_cov.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsBackward(
const torch::Tensor& points,
const torch::Tensor& colors,
const torch::Tensor& inv_cov,
const int max_radius,
const torch::Tensor& idxs,
const torch::Tensor& k_idxs,
const float znear,
const float zfar,
const float gamma,
const torch::Tensor& grad_out_color) {
if (points.type().is_cuda()) {
return RasterizePointsBackwardCuda(points, colors,
inv_cov, max_radius,
idxs, k_idxs, znear, zfar,
gamma, grad_out_color);
} else {
AT_ERROR("No CPU support");
}
}
// ****************************************************************************
// * MAIN ENTRY POINT *
// ****************************************************************************
// This is the main entry point for the forward pass of the point rasterizer.
//
// Args:
//  points: Tensor of shape (P, 3) giving (packed) positions for points in all
//          N pointclouds in the batch, where P is the total number of points
//          across all pointclouds. These points are expected to be in NDC
//          coordinates in the range [-1, 1].
//  colors: Tensor of shape (P, C) giving per-point features/colors.
//  point_score: Tensor of shape (P, 1) giving a per-point score.
//  inv_cov: Tensor of shape (P, 4) giving per-point inverse covariance entries.
//  max_radius: maximum point radius used during rasterization.
//  image_height, image_width: size of the image to return (in pixels).
//  points_per_pixel: (K) the number of points to keep for each pixel.
//  zfar, znear: depth clipping values.
//  gamma: blending parameter.
//
// Returns:
//  idx: int32 Tensor of shape (N, H, W, K) giving the indices of the closest
//       K points along the z-axis for each pixel, padded with -1 for pixels
//       hit by fewer than K points. The indices refer to points in points
//       packed, i.e. a tensor of shape (P, 3) representing the flattened
//       points for all pointclouds in the batch.
//  color: float32 Tensor giving the rasterized per-pixel features.
//  k_idxs, depth_gmms, num_gmms, point_score: auxiliary per-pixel buffers used
//       by the soft depth test and the backward pass.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePoints(
const torch::Tensor& points,
const torch::Tensor& colors,
const torch::Tensor& point_score,
const torch::Tensor& inv_cov,
const int max_radius,
const int image_height,
const int image_width,
const int points_per_pixel,
const float zfar,
const float znear,
const float gamma)
{
return RasterizePointsGKCuda(
points,
colors,
point_score,
inv_cov,
max_radius,
image_height,
image_width,
points_per_pixel,
zfar,
znear,
gamma);
}
import torch
from diff_rasterization import _C
def rasterize_points(
points,
features,
point_score,
inv_cov,
max_radius,
image_height: int = 256,
image_width: int = 256,
points_per_pixel: int = 8,
zfar: float = -0.5,
znear: float = -0.5,
gamma: float = None
):
return _RasterizePoints.apply(
points,
features,
point_score,
inv_cov,
max_radius,
image_height,
image_width,
points_per_pixel,
zfar,
znear,
gamma
)
class _RasterizePoints(torch.autograd.Function):
@staticmethod
def forward(
ctx,
points, # (P, 3)
colors, # (P, C)
point_score, # (P, 1)
inv_cov, # (P, 4)
max_radius,
image_height: int = 256,
image_width: int = 256,
points_per_pixel: int = 8,
zfar: float = -0.5,
znear: float = -0.5,
gamma: float = None
):
# TODO: Add better error handling for when there are more than
# max_points_per_bin in any bin.
args = (
points,
colors,
point_score,
inv_cov,
max_radius,
image_height,
image_width,
points_per_pixel,
zfar,
znear,
gamma
)
idx, color, k_idxs, depth_gmms, num_gmms, point_score = _C.rasterize_points(*args)
ctx.znear = znear
ctx.zfar = zfar
ctx.gamma = gamma
ctx.max_radius = max_radius
ctx.save_for_backward(points, colors, inv_cov, idx, k_idxs)
return idx, color, depth_gmms, num_gmms, point_score
@staticmethod
def backward(ctx, grad_idx, grad_out_color, grad_depth_gmms, grad_num_gmms, grad_point_score):
grad_points = None
grad_colors = None
grad_inv_cov = None
grad_max_radius = None
grad_image_height = None
grad_image_width = None
grad_points_per_pixel = None
grad_bin_size = None
grad_zfar = None
grad_znear = None
grad_gamma = None
znear = ctx.znear
zfar = ctx.zfar
gamma = ctx.gamma
max_radius = ctx.max_radius
points, colors, inv_cov, idx, k_idxs = ctx.saved_tensors
args = (points, colors, inv_cov, max_radius, idx, k_idxs, znear, zfar, gamma, grad_out_color)
grad_points, grad_colors, grad_inv_cov = _C.rasterize_points_backward(*args)
grads = (
grad_points,
grad_colors,
grad_inv_cov,
grad_max_radius,
grad_image_height,
grad_image_width,
grad_points_per_pixel,
grad_bin_size,
grad_zfar,
grad_znear,
grad_gamma
)
return grads
from typing import NamedTuple
import torch.nn as nn
from .rasterize_points import rasterize_points
class PointsRasterizationSettings(NamedTuple):
image_height: int = 256
image_width: int =256
points_per_pixel: int = 8
zfar: float = None
znear: float = None
gamma: float = None
class PointsRasterizer(nn.Module):
def __init__(self, raster_settings):
super().__init__()
self.raster_settings = raster_settings
def forward(self, points_screen, features, point_score, inv_cov, max_radius):
raster_settings = self.raster_settings
idx, color, depth_gmms, num_gmms, mask = rasterize_points(
points_screen,
features,
point_score,
inv_cov,
max_radius,
image_height=raster_settings.image_height,
image_width=raster_settings.image_width,
points_per_pixel=raster_settings.points_per_pixel,
zfar=raster_settings.zfar,
znear=raster_settings.znear,
gamma=raster_settings.gamma
)
return color, depth_gmms, num_gmms, mask
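A hedged usage sketch of the rasterizer wrapper above. It needs the compiled diff_rasterization._C extension and a CUDA device; the tensor shapes follow the comments in _RasterizePoints.forward, and every concrete value below (image size, depth range, gamma, max_radius) is an assumption chosen for the example, not a default from this commit.

if __name__ == "__main__":
    import torch
    P = 1000
    points_screen = torch.rand(P, 3, device="cuda") * 2 - 1   # NDC positions in [-1, 1]
    features = torch.rand(P, 3, device="cuda")                # per-point colors, (P, C)
    point_score = torch.rand(P, 1, device="cuda")             # per-point score, (P, 1)
    inv_cov = torch.rand(P, 4, device="cuda")                 # per-point inverse covariance, (P, 4)
    settings = PointsRasterizationSettings(
        image_height=128, image_width=128, points_per_pixel=8,
        zfar=1.0, znear=-1.0, gamma=1e-3)
    rasterizer = PointsRasterizer(settings)
    color, depth_gmms, num_gmms, mask = rasterizer(
        points_screen, features, point_score, inv_cov, max_radius=20)
    print(color.shape, num_gmms.shape)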
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
setup(
name="diff_rast",
ext_modules=[
CUDAExtension(
name="diff_rasterization._C",
sources=["soft_depth_test.cu", "rasterize_points.cu", "ext.cpp"],
extra_compile_args={"nvcc": [], "cxx": []})
],
cmdclass={
'build_ext': BuildExtension
}
)
#include <stdio.h>
#include <torch/extension.h>
#include "rasterization_utils.cuh"
__device__ inline float H_step(float d0, float d1) {
if (d0 < d1)
return 0;
else if (d0 > d1)
return 1;
return 0.5;
}
__device__ inline float H_tri(float d0, float d1) {
float sigma = 7.0;
if (d0 < d1 - sigma)
return 0;
else if (d0 > d1 + sigma)
return 1;
else if (d0 < d1) {
return (d0 - d1 + sigma) * (d0 - d1 + sigma) / (2 * sigma * sigma);
}
else if (d0 > d1) {
return 1.0 - (d1 + sigma - d0) * (d1 + sigma - d0) / (2 * sigma * sigma);
}
return 0.5;
}
__device__ float recurseProd(int start_v, int current_v,
const float* depth_gmms, // (N, H, W, 150, 2)
const int* num_gmms,
float d,
const int N,
const int H,
const int W,
const int yi,
const int xi
) {
if (start_v == current_v)
return 1.0;
const int id_NHW = current_v * H * W + yi * W + xi;
const int v_point_num = num_gmms[id_NHW];
float inner_sum = 0;
float lower_prod = recurseProd(start_v, (current_v + 1) % N, depth_gmms, num_gmms, d, N, H, W, yi, xi);
for (int k = v_point_num - 1; k >= 0; k--) {
const int id_NHWP2 =
current_v * H * W * (kMaxPointsPerPixel + 1) * 2 +
yi * W * (kMaxPointsPerPixel + 1) * 2 +
xi * (kMaxPointsPerPixel + 1) * 2 +
k * 2;
float d_new = depth_gmms[id_NHWP2];
float h = H_tri(d_new, d);
if (h == 0) {
break; // Entries are visited from largest depth to smallest; once h is zero, all remaining (closer) entries are zero too.
}
else {
float alpha_new = depth_gmms[id_NHWP2 + 1];
inner_sum += alpha_new * h;
}
}
return inner_sum * lower_prod;
}
__device__ float linProd(int start_v,
const float* depth_gmms, // (N, H, W, 150, 2)
const int* num_gmms,
float d,
const int N,
const int H,
const int W,
const int yi,
const int xi
) {
float prod = 1.0;
for (int v_it = 1; v_it < N; v_it++) {
const int current_v = (start_v + v_it) % N;
const int id_NHW = current_v * H * W + yi * W + xi;
const int v_point_num = num_gmms[id_NHW];
float inner_sum = 0;
for (int k = v_point_num - 1; k >= 0; k--) {
const int id_NHWP2 =
current_v * H * W * (kMaxPointsPerPixel + 1) * 2 +
yi * W * (kMaxPointsPerPixel + 1) * 2 +
xi * (kMaxPointsPerPixel + 1) * 2 +
k * 2;
float d_new = depth_gmms[id_NHWP2];
float h = H_tri(d_new, d);
if (h == 0) {
break; // Entries are visited from largest depth to smallest; once h is zero, all remaining (closer) entries are zero too.
}
else {
float alpha_new = depth_gmms[id_NHWP2 + 1];
inner_sum += alpha_new * h;
}
}
prod *= inner_sum;
}
return prod;
}
__global__ void ComputeProbabilityDMM(
const float* depth_gmms, // (N, H, W, 150, 2)
const int* num_gmms,
const int N,
const int H,
const int W,
float* output)
{
// One thread per output pixel
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < H * W; i += num_threads) {
const int yi = i / W;
const int xi = i % W;
for (int v = 0; v < N; v++) { //Loop over all the views
float p = 0.0;
const int id_NHW = v * H * W + yi * W + xi;
const int v_point_num = num_gmms[id_NHW];
for (int k = 0; k < v_point_num; k++) {
const int id_NHWP2 =
v * H * W * (kMaxPointsPerPixel + 1) * 2 +
yi * W * (kMaxPointsPerPixel + 1) * 2 +
xi * (kMaxPointsPerPixel + 1) * 2 +
k * 2;
float d = depth_gmms[id_NHWP2];
float alpha = depth_gmms[id_NHWP2 + 1];
//float prod = recurseProd(v, (v+1)%N,depth_gmms, num_gmms, d, N, H, W, yi, xi);
float prod = linProd(v, depth_gmms, num_gmms, d, N, H, W, yi, xi);
p += alpha * prod;
}
output[id_NHW] = p;
}
}
}
torch::Tensor SoftDepthTestCuda(
const torch::Tensor& depth_gmms,
const torch::Tensor& num_gmms
)
{
const int N = depth_gmms.size(0);
const int H = depth_gmms.size(1);
const int W = depth_gmms.size(2);
const int P = depth_gmms.size(3);
auto float_opts = depth_gmms.options().dtype(torch::kFloat32);
torch::Tensor out_probability = torch::full({ N, 1, H, W }, 0.0, float_opts);
const size_t blocks = 1024;
const size_t threads = 64;
ComputeProbabilityDMM<<<blocks, threads>>>(
depth_gmms.contiguous().data<float>(),
num_gmms.contiguous().data<int32_t>(),
N,
H,
W,
out_probability.contiguous().data<float>());
return out_probability;
}
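A CPU reference of the quantity computed by ComputeProbabilityDMM above (a sketch for illustration only, assuming the (N, H, W, K, 2) depth/alpha layout used by the kernel): for each view v and pixel, p = sum_k alpha_k * prod over the other views of (sum_j alpha_j * H_tri(d_j, d_k)). The break-at-zero shortcut in the CUDA loops is omitted since it does not change the result.

import torch

def h_tri_ref(d0, d1, sigma=7.0):
    # Same smoothed step as H_tri above: ramps from 0 to 1 over
    # [d1 - sigma, d1 + sigma] with piecewise quadratics.
    if d0 < d1 - sigma:
        return 0.0
    if d0 > d1 + sigma:
        return 1.0
    if d0 < d1:
        return (d0 - d1 + sigma) ** 2 / (2 * sigma * sigma)
    if d0 > d1:
        return 1.0 - (d1 + sigma - d0) ** 2 / (2 * sigma * sigma)
    return 0.5

def soft_depth_test_reference(depth_gmms, num_gmms):
    # depth_gmms: (N, H, W, K, 2) holding (depth, alpha) pairs; num_gmms: (N, H, W).
    N, H, W = num_gmms.shape
    out = torch.zeros(N, 1, H, W)
    for v in range(N):
        for y in range(H):
            for x in range(W):
                p = 0.0
                for k in range(int(num_gmms[v, y, x])):
                    d, alpha = depth_gmms[v, y, x, k].tolist()
                    prod = 1.0
                    for u in range(N):  # product over all the *other* views
                        if u == v:
                            continue
                        inner = sum(
                            depth_gmms[u, y, x, j, 1].item() *
                            h_tri_ref(depth_gmms[u, y, x, j, 0].item(), d)
                            for j in range(int(num_gmms[u, y, x])))
                        prod *= inner
                    p += alpha * prod
                out[v, 0, y, x] = p
    return out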
#pragma once
#include <torch/extension.h>
#include <cstdio>
torch::Tensor SoftDepthTestCuda(
const torch::Tensor& depth_gmms,
const torch::Tensor& num_gmms
);
torch::Tensor SoftDepthTest(
const torch::Tensor& depth_gmms,
const torch::Tensor& num_gmms)
{
return SoftDepthTestCuda(
depth_gmms,
num_gmms);
}
import torch
from diff_rasterization import _C
class _SoftDepthTest(torch.autograd.Function):
@staticmethod
def forward(
ctx,
depth_gmms,
num_gmms,
):
args = (
depth_gmms,
num_gmms,
)
prob_map = _C.soft_depth_test(*args)
return prob_map
"""
@staticmethod
def backward(ctx, grad_idx, grad_out_color, grad_depth, grad_mask):
grad_points = None
grad_colors = None
grad_inv_cov = None
grad_max_radius = None
grad_image_height = None
grad_image_width = None
grad_points_per_pixel = None
grad_bin_size = None
grad_zfar = None
grad_znear = None
grad_gamma = None
znear = ctx.znear
zfar = ctx.zfar
gamma = ctx.gamma
max_radius = ctx.max_radius
points, colors, inv_cov, idx, k_idxs = ctx.saved_tensors
args = (points, colors, inv_cov, max_radius, idx, k_idxs, znear, zfar, gamma, grad_out_color)
grad_points, grad_colors, grad_inv_cov = _C.rasterize_points_backward(*args)
grads = (
grad_points,
grad_colors,
grad_inv_cov,
grad_max_radius,
grad_image_height,
grad_image_width,
grad_points_per_pixel,
grad_bin_size,
grad_zfar,
grad_znear,
grad_gamma
)
return grads
"""
import torch
from .modules.lpips import LPIPS
def lpips(x: torch.Tensor,
y: torch.Tensor,
net_type: str = 'alex',
version: str = '0.1'):
r"""Function that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
x, y (torch.Tensor): the input tensors to compare.
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
device = x.device
criterion = LPIPS(net_type, version).to(device)
return criterion(x, y)
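A quick usage sketch of the helper above (illustration only): it downloads the pretrained LPIPS linear-layer weights on first call, and the random inputs merely demonstrate the expected (N, 3, H, W) shapes; lower scores mean perceptually closer images.

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    y = torch.rand(1, 3, 256, 256)
    print(lpips(x, y, net_type='vgg').item())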
import torch
import torch.nn as nn
from .networks import get_network, LinLayers
from .utils import get_state_dict
class LPIPS(nn.Module):
r"""Creates a criterion that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
def __init__(self, net_type: str = 'alex', version: str = '0.1'):
assert version in ['0.1'], 'v0.1 is only supported now'
super(LPIPS, self).__init__()
# pretrained network
self.net = get_network(net_type)
# linear layers
self.lin = LinLayers(self.net.n_channels_list)
self.lin.load_state_dict(get_state_dict(net_type, version))
def forward(self, x: torch.Tensor, y: torch.Tensor):
feat_x, feat_y = self.net(x), self.net(y)
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
return torch.sum(torch.cat(res, 0), 0, True)
from typing import Sequence
from itertools import chain
import torch
import torch.nn as nn
from torchvision import models
from .utils import normalize_activation
def get_network(net_type: str):
if net_type == 'alex':
return AlexNet()
elif net_type == 'squeeze':
return SqueezeNet()
elif net_type == 'vgg':
return VGG16()
else:
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
class LinLayers(nn.ModuleList):
def __init__(self, n_channels_list: Sequence[int]):
super(LinLayers, self).__init__([
nn.Sequential(
nn.Identity(),
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
) for nc in n_channels_list
])
for param in self.parameters():
param.requires_grad = False
class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()
# register buffer
self.register_buffer(
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer(
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
def set_requires_grad(self, state: bool):
for param in chain(self.parameters(), self.buffers()):
param.requires_grad = state
def z_score(self, x: torch.Tensor):
return (x - self.mean) / self.std
def forward(self, x: torch.Tensor):
x = self.z_score(x)
output = []
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
x = layer(x)
if i in self.target_layers:
output.append(normalize_activation(x))
if len(output) == len(self.target_layers):
break
return output
class SqueezeNet(BaseNet):
def __init__(self):
super(SqueezeNet, self).__init__()
self.layers = models.squeezenet1_1(True).features
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
self.set_requires_grad(False)
class AlexNet(BaseNet):
def __init__(self):
super(AlexNet, self).__init__()
self.layers = models.alexnet(True).features
self.target_layers = [2, 5, 8, 10, 12]
self.n_channels_list = [64, 192, 384, 256, 256]
self.set_requires_grad(False)
class VGG16(BaseNet):
def __init__(self):
super(VGG16, self).__init__()
self.layers = models.vgg16(True).features
self.target_layers = [4, 9, 16, 23, 30]
self.n_channels_list = [64, 128, 256, 512, 512]
self.set_requires_grad(False)
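A quick feature-shape check for the extractors above (sketch only; downloads torchvision weights on first use). The channel counts of the returned activations should line up with n_channels_list, which is what LinLayers relies on.

if __name__ == "__main__":
    net = get_network('alex')
    feats = net(torch.rand(1, 3, 64, 64))
    print([f.shape[1] for f in feats])  # expected: [64, 192, 384, 256, 256], i.e. net.n_channels_list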
from collections import OrderedDict
import torch
def normalize_activation(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
# build url
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ f'master/lpips/weights/v{version}/{net_type}.pth'
# download
old_state_dict = torch.hub.load_state_dict_from_url(
url, progress=True,
map_location=None if torch.cuda.is_available() else torch.device('cpu')
)
# rename keys
new_state_dict = OrderedDict()
for key, val in old_state_dict.items():
new_key = key
new_key = new_key.replace('lin', '')
new_key = new_key.replace('model.', '')
new_state_dict[new_key] = val
return new_state_dict
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms.functional as tf
import pytorch_ssim
from lpips_pytorch import lpips
import json
def mse(img1, img2):
return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
def psnr(img1, img2):
mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
return 20 * torch.log10(1.0 / torch.sqrt(mse))
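A quick worked example for the helpers above (added for illustration): a uniform error of 0.1 over the image gives MSE = 0.01 and PSNR = 20 * log10(1 / 0.1) = 20 dB.

# Sanity check for mse/psnr above (illustration only).
_a = torch.zeros(1, 3, 4, 4)
_b = torch.full((1, 3, 4, 4), 0.1)
assert abs(psnr(_a, _b).item() - 20.0) < 1e-4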
device = torch.device("cuda:0")
torch.cuda.set_device(device)
scenes = [Path(r"F:\results\results_db\truck\leave_one_out")]
full_dict = {}
for scene in scenes:
scene_name = str(scene).split("\\")[3]
print(scene_name)
full_dict[scene_name] = {}
for algorithm in os.listdir(scene):
print(algorithm)
if algorithm == "masks":
continue
full_dict[scene_name][algorithm] = {}
gt_dir = scene/algorithm/"gt"
renders_dir = scene/algorithm/"renders"
renders = []
gts = []
for fname in os.listdir(renders_dir):
render = Image.open(renders_dir/fname)
gt = Image.open(gt_dir/fname)
renders.append(tf.to_tensor(render).unsqueeze(0).to(device))
gts.append(tf.to_tensor(gt).unsqueeze(0).to(device))
ssims = []
psnrs = []
lpipss = []
for idx in range(len(renders)):
print(idx)
ssims.append(pytorch_ssim.ssim(renders[idx], gts[idx]))
psnrs.append(psnr(renders[idx], gts[idx]))
lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
print("SSIM: {}".format(torch.tensor(ssims).mean()))
print("PSNR: {}".format(torch.tensor(psnrs).mean()))
print("LPIPS: {}".format(torch.tensor(lpipss).mean()))
full_dict[scene_name][algorithm].update({"SSIM": torch.tensor(ssims).mean().item(),
"PSNR": torch.tensor(psnrs).mean().item(),
"LPIPS": torch.tensor(lpipss).mean().item()})
with open('F:/results/results_db/leave_one_out_ablation_data.json', 'w') as fp:
json.dump(full_dict, fp, indent=True)
import torch
def l1_loss(network_output, gt):
return torch.abs((network_output - gt)).mean()
import torch.nn as nn
from model_defs.netblocs import ConvModule, FixupResidualChain
class FixUpResNet(nn.Module):
def __init__(self, in_channels, internal_depth, blocks, kernel_size, dropout):
super(FixUpResNet, self).__init__()
self.encoder = nn.Sequential(
ConvModule(in_channels, internal_depth, ksize=kernel_size, pad=True, activation="relu", norm_layer=None, padding_mode="reflect"),
nn.Dropout(dropout),
FixupResidualChain(internal_depth, ksize=kernel_size, depth=int(blocks/4), padding_mode="reflect", dropout=dropout),
)
self.decoder = nn.Sequential(
FixupResidualChain(internal_depth, ksize=kernel_size, depth=int(blocks/4), padding_mode="reflect", dropout=dropout),
ConvModule(internal_depth, 3, ksize=kernel_size, pad=True, activation=None, norm_layer=None, padding_mode="reflect")
)
def forward(self, x, w):
x_encoded = self.encoder(x)
x_scaled = (x_encoded*w).sum(dim=0, keepdim=True)/(w.sum(dim=0, keepdim=True)+0.0000001)
x_out = self.decoder(x_scaled)
return x_out
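A shape sketch for FixUpResNet.forward (illustration only; it assumes x stacks per-view feature maps along dim 0 and w holds matching per-pixel blend weights, so the encoded views are averaged into a single image before decoding; V below is an example choice).

if __name__ == "__main__":
    import torch
    V, C, H, W = 4, 16, 64, 64          # V = number of views (assumption for the example)
    net = FixUpResNet(in_channels=C, internal_depth=32, blocks=8, kernel_size=3, dropout=0.0)
    x = torch.rand(V, C, H, W)          # per-view inputs stacked along dim 0
    w = torch.rand(V, 1, H, W)          # per-view, per-pixel blending weights
    assert net(x, w).shape == (1, 3, H, W)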
import torch.nn as nn
import torch as th
from collections import OrderedDict
import numpy as np
class ConvModule(nn.Module):
"""Basic convolution module with conv + norm(optional) + activation(optional).
Args:
n_in(int): number of input channels.
n_out(int): number of output channels.
ksize(int): size of the convolution kernel (square).
stride(int): downsampling factor
pad(bool): if True, zero pad the convolutions to maintain a constant size.
activation(Union[str, None]): nonlinear activation function between convolutions.
norm_layer(Union[str, None]): normalization to apply between the convolution modules.
"""
def __init__(self, n_in, n_out, ksize=3, stride=1, pad=True,
activation=None, norm_layer=None, padding_mode="reflect", use_bias = False):
super(ConvModule, self).__init__()
assert isinstance(
n_in, int) and n_in > 0, "Input channels should be a positive integer got {}".format(n_in)
assert isinstance(
n_out, int) and n_out > 0, "Output channels should be a positive integer got {}".format(n_out)
assert isinstance(
ksize, int) and ksize > 0, "Kernel size should be a positive integer got {}".format(ksize)
layers = OrderedDict()
padding = (ksize - 1) // 2 if pad else 0
if padding_mode=="reflect":
layers["pad"] = nn.ReflectionPad2d(padding)
padding=0
layers["conv"] = nn.Conv2d(n_in, n_out, ksize, stride=stride,
padding=padding, bias=use_bias, padding_mode=padding_mode)
if norm_layer is not None:
layers["norm"] = _get_norm_layer(norm_layer, n_out)
if activation is not None:
layers["activation"] = _get_activation(activation)
# Initialize parameters
_init_fc_or_conv(layers["conv"], activation)
self.net = nn.Sequential(layers)
def forward(self, x):
x=self.net(x)
return x
class ConvChain(nn.Module):
"""Linear chain of convolution layers.
Args:
n_in(int): number of input channels.
ksize(int or list of int): size of the convolution kernel (square).
width(int or list of int): number of features channels in the intermediate layers.
depth(int): number of layers
strides(list of int): stride between kernels. If None, defaults to 1 for all.
pad(bool): if True, zero pad the convolutions to maintain a constant size.
activation(str): nonlinear activation function between convolutions.
norm_layer(str): normalization to apply between the convolution modules.
"""
def __init__(self, n_in, ksize=3, width=64, depth=3, strides=None, pad=True,
activation="relu", norm_layer=None, padding_mode="reflect"):
super(ConvChain, self).__init__()
assert isinstance(
n_in, int) and n_in > 0, "Input channels should be a positive integer"
assert (isinstance(ksize, int) and ksize > 0) or isinstance(
ksize, list), "Kernel size should be a positive integer or a list of integers"
assert isinstance(
depth, int) and depth > 0, "Depth should be a positive integer"
assert isinstance(width, int) or isinstance(
width, list), "Width should be a list or an int"
_in = [n_in]
if strides is None:
_strides = [1]*depth
else:
assert isinstance(strides, list), "strides should be a list"
assert len(strides) == depth, "strides should have `depth` elements"
_strides = strides
if isinstance(width, int):
_in = _in + [width]*(depth-1)
_out = [width]*depth
elif isinstance(width, list):
assert len(width) == depth, "Specifying width with a list should have `depth` elements"
_in = _in + width[:-1]
_out = width
if isinstance(ksize, int):
_ksizes = [ksize]*depth
elif isinstance(ksize, list):
assert len(
ksize) == depth, "kernel size list should have 'depth' entries"
_ksizes = ksize
_activations = [activation]*depth
_padding_modes = [padding_mode]*depth
# dont normalize in/out layers
_norms = [norm_layer]*depth
# Core processing layers, no norm at the first layer
layers=OrderedDict()
for lvl in range(depth):
layers["conv{}".format(lvl)] = ConvModule(_in[lvl], _out[lvl], _ksizes[lvl], stride=_strides[lvl], pad=pad,
activation=_activations[lvl], norm_layer=_norms[lvl], padding_mode=_padding_modes[lvl], use_bias=False)
self.net = nn.Sequential(layers)
def forward(self, x):
x = self.net(x)
return x
class FixupBasicBlock(nn.Module):
# https://openreview.net/pdf?id=H1gsz30cKX
expansion = 1
def __init__(self, n_features, ksize=3, padding=True, padding_mode="reflect",
activation="relu", dropout=0.0):
super(FixupBasicBlock, self).__init__()
self.bias1a = nn.Parameter(th.zeros(1))
self.conv1 = ConvModule(n_features, n_features, ksize=ksize, stride=1,
pad=padding, activation=None, norm_layer=None,
padding_mode=padding_mode)
self.dropout1 = nn.Dropout(dropout)
self.bias1b = nn.Parameter(th.zeros(1))
self.activation = _get_activation(activation)
self.bias2a = nn.Parameter(th.zeros(1))
self.conv2 = ConvModule(n_features, n_features, ksize=ksize, stride=1,
pad=padding, activation=None, norm_layer=None,
padding_mode=padding_mode)
self.dropout2 = nn.Dropout(dropout)
self.scale = nn.Parameter(th.ones(1))
self.bias2b = nn.Parameter(th.zeros(1))
self.activation2 = _get_activation(activation)
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
out = self.dropout1(out)
out = self.activation(out + self.bias1b)
out = self.conv2(out + self.bias2a)
out = self.dropout2(out)
out = out * self.scale + self.bias2b
out += identity
out = self.activation2(out)
return out
class FixupResidualChain(nn.Module):
"""Linear chain of residual blocks.
Args:
n_features(int): number of input channels.
ksize(int): size of the convolution kernel (square).
depth(int): number of residual blocks
convs_per_block(int): number of convolution per residual block
activation(str): nonlinear activation function between convolutions.
"""
def __init__(self, n_features, depth=3, ksize=3, activation="relu", padding_mode="reflect", dropout=0.0):
super(FixupResidualChain, self).__init__()
assert isinstance(
n_features, int) and n_features > 0, "Number of feature channels should be a positive integer"
assert (isinstance(ksize, int) and ksize > 0) or isinstance(
ksize, list), "Kernel size should be a positive integer or a list of integers"
assert isinstance(
depth, int) and depth > 0 and depth < 16, "Depth should be a positive integer lower than 16"
self.depth = depth
# Core processing layers
layers = OrderedDict()
for lvl in range(depth):
blockname="resblock{}".format(lvl)
layers[blockname]=FixupBasicBlock(
n_features, ksize=ksize, activation=activation,
padding_mode=padding_mode, dropout=dropout)
self.net=nn.Sequential(layers)
self._reset_weights()
def _reset_weights(self):
for m in self.net.modules():
if isinstance(m, FixupBasicBlock):
nn.init.normal_(m.conv1.net.conv.weight, mean=0, std=np.sqrt(2 /
(m.conv1.net.conv.weight.shape[0] * np.prod(m.conv1.net.conv.weight.shape[2:]))) * self.depth ** (-0.5))
nn.init.constant_(m.conv2.net.conv.weight, 0)
def forward(self, x):
x = self.net(x)
return x
class FixUNet(nn.Module):
"""Simple UNet with downsampling and concat operations.
Args:
n_in(int): number of input channels.
n_out(int): number of input channels.
ksize(int): size of the convolution kernel (square).
width(int): number of features channels in the first hidden layers.
increase_factor(float): ratio of feature increase between scales.
num_convs(int): number of conv layers per level
num_levels(int): number of scales
activation(str): nonlinear activation function between convolutions.
norm_layer(str): normalization to apply between the convolution modules.
"""
def __init__(self, n_in, ksize=3,
num_convs=1, num_levels=4, activation="relu",
interp_mode="bilinear",padding_mode="reflect"):
super(FixUNet, self).__init__()
child = None
lvl_in = []
for lvl in range(num_levels-1, -1, -1):
n_child_out = n_in
lvl_in = n_in
if lvl == num_levels-1:
n_child_out = 0
u_lvl = FixUNet._FixUNetLevel(
lvl_in, ksize, num_convs, activation,
child=child, n_child_out=n_child_out,
interp_mode=interp_mode,padding_mode=padding_mode)
child = u_lvl
self.top_level = u_lvl
def forward(self, x):
return self.top_level(x)
class _FixUNetLevel(nn.Module):
def __init__(self, n_in, ksize, num_convs, activation,
child=None, n_child_out=0, interp_mode="bilinear",padding_mode="reflect"):
super(FixUNet._FixUNetLevel, self).__init__()
self.left = FixupResidualChain(n_features=n_in,depth=num_convs,ksize=ksize,activation=activation,padding_mode=padding_mode)
if n_child_out>0 :
self.right = nn.Sequential(
FixupResidualChain(n_features=n_in+n_child_out,ksize=ksize,depth=num_convs,padding_mode=padding_mode,activation=activation),
ConvModule(n_in+n_child_out, n_in, ksize=3, pad=True, activation=None, norm_layer=None, padding_mode="reflect")
)
else:
self.right = FixupResidualChain(n_features=n_in,ksize=ksize,depth=num_convs,padding_mode=padding_mode,activation=activation)
self.child = nn.Identity()
self.hasChild=False
if child is not None:
self.child = child
self.hasChild=True
self.interp_mode = interp_mode
def forward(self, x):
left_features = self.left(x)
if self.hasChild :
ds = nn.functional.interpolate(
left_features, scale_factor=0.5, recompute_scale_factor=True, mode=self.interp_mode, align_corners=True)
#ds = nn.functional.adaptive_max_pool2d(left_features,
# output_size=(left_features.shape[-2]//2,left_features.shape[-1]//2)
# )
child_features = self.child(ds)
us = nn.functional.interpolate(
child_features, size=left_features.shape[-2:],
mode=self.interp_mode, align_corners=True)
# skip connection
left_features = th.cat([left_features, us], dim=1)
output = self.right(left_features)
return output
class _Interpolate(nn.Module):
def __init__(self, scale_factor, mode, align_corners):
super(UNet._UNetLevel._Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.align_corners = align_corners
self.mode = mode
def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
return x
class UNet(nn.Module):
"""Simple UNet with downsampling and concat operations.
Args:
n_in(int): number of input channels.
n_out(int): number of input channels.
ksize(int): size of the convolution kernel (square).
width(int): number of features channels in the first hidden layers.
increase_factor(float): ratio of feature increase between scales.
num_convs(int): number of conv layers per level
num_levels(int): number of scales
activation(str): nonlinear activation function between convolutions.
norm_layer(str): normalization to apply between the convolution modules.
"""
def _width(self, lvl):
return int(min(self.base_width*self.increase_factor**(lvl), self.max_width))
def __init__(self, n_in, n_out, ksize=3, base_width=64, max_width=512, increase_factor=2,
num_convs=1, num_levels=4, activation="relu", norm_layer=None,
interp_mode="bilinear",padding_mode="reflect"):
super(UNet, self).__init__()
self.increase_factor = increase_factor
self.max_width = max_width
self.base_width = base_width
child = None
lvl_in = []
for lvl in range(num_levels-1, -1, -1):
lvl_w = self._width(lvl)
n_child_out = self._width(lvl)
if lvl == 0:
lvl_in = n_in
lvl_out = n_out
else:
lvl_in = self._width(lvl-1)
lvl_out = self._width(lvl-1)
if lvl == num_levels-1:
n_child_out = 0
u_lvl = UNet._UNetLevel(
lvl_in, lvl_out, lvl_w, ksize, num_convs, activation,
norm_layer, child=child, n_child_out=n_child_out,
interp_mode=interp_mode,padding_mode=padding_mode)
child = u_lvl
self.top_level = u_lvl
def forward(self, x):
return self.top_level(x)
class _UNetLevel(nn.Module):
def __init__(self, n_in, n_out, width, ksize, num_convs, activation,
norm_layer, child=None, n_child_out=0, interp_mode="bilinear",padding_mode="reflect"):
super(UNet._UNetLevel, self).__init__()
self.left = ConvChain(n_in, ksize=ksize, width=width, depth=num_convs, pad=True,
activation=activation, norm_layer=norm_layer,padding_mode=padding_mode)
w = [width] * (num_convs-1) + [n_out]
self.right = ConvChain(width + n_child_out, ksize=ksize, width=w, depth=num_convs, pad=True,
activation=activation, norm_layer=norm_layer,padding_mode=padding_mode)
self.child = nn.Identity()
self.hasChild=False
if child is not None:
self.child = child
self.hasChild=True
self.interp_mode = interp_mode
def forward(self, x):
left_features = self.left(x)
if self.hasChild :
ds = nn.functional.interpolate(
left_features, scale_factor=0.5, recompute_scale_factor=True, mode=self.interp_mode, align_corners=True)
child_features = self.child(ds)
us = nn.functional.interpolate(
child_features, size=left_features.shape[-2:],
mode=self.interp_mode, align_corners=True)
# skip connection
left_features = th.cat([left_features, us], dim=1)
output = self.right(left_features)
return output
class _Interpolate(nn.Module):
def __init__(self, scale_factor, mode, align_corners):
super(UNet._UNetLevel._Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.align_corners = align_corners
self.mode = mode
def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
return x
# Helpers ---------------------------------------------------------------------
def _get_norm_layer(norm_layer, channels):
valid = ["instance", "batch"]
assert norm_layer in valid, "norm_layer should be one of {}".format(valid)
if norm_layer == "instance":
layer = nn.InstanceNorm2d(channels, affine=True)
elif norm_layer == "batch":
layer = nn.BatchNorm2d(channels, affine=True)
nn.init.constant_(layer.bias, 0.0)
nn.init.constant_(layer.weight, 1.0)
return layer
def _get_activation(activation):
valid = ["relu", "leaky_relu", "lrelu", "elu", "selu"]
assert activation in valid, "activation should be one of {}".format(valid)
if activation == "relu":
return nn.ReLU(inplace=True)
if activation == "leaky_relu" or activation == "lrelu":
return nn.LeakyReLU(inplace=True)
if activation == "elu":
return nn.ELU(inplace=True)
if activation == "selu":
return nn.SELU(inplace=True)
return None
def _init_fc_or_conv(fc_conv, activation):
gain = 1.0
if activation is not None:
try:
gain = nn.init.calculate_gain(activation)
except:
print("Warning using gain of ",gain," for activation: ",activation)
nn.init.xavier_uniform_(fc_conv.weight, gain)
if fc_conv.bias is not None:
nn.init.constant_(fc_conv.bias, 0.0)
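A shape sketch for the UNet above (illustration only; all constructor values are example choices, not defaults from this commit): three levels halve the resolution twice, and the skip connections bring the output back to the input resolution with n_out channels.

if __name__ == "__main__":
    net = UNet(n_in=8, n_out=3, base_width=32, num_convs=2, num_levels=3)
    x = th.rand(1, 8, 64, 64)
    assert net(x).shape == (1, 3, 64, 64)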