-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathutilities.py
More file actions
94 lines (76 loc) · 2.52 KB
/
utilities.py
File metadata and controls
94 lines (76 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.functional as F
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
import flowlib
# Device identifier string; its consumers are not visible in this file.
device = "cuda"

# Scalar hyper-parameters. Their consumers are not visible in this file --
# presumably loss-term weights; confirm against the training script.
# The trailing notes are the original author's tuning history, kept verbatim.
ALPHA = 1.0e13    # previously 12, 2e10
BETA = 1.0e10     # 1e6 / 11,
GAMMA = 3.0e-2    # previously -3
LAMBDA_O = 1.0e6
LAMBDA_F = 1.0e4

# Target size handed to transforms.Resize by transform1 / transform2.
# NOTE(review): Resize interprets a 2-sequence as (height, width), so this
# yields 640 rows x 360 columns -- confirm that is the intended orientation.
IMG_SIZE = (640, 360)
# Per-channel mean / standard deviation applied by normalizeVGG16 (three
# channels). These match the commonly used ImageNet statistics for images
# scaled to [0, 1].
VGG16_MEAN = [0.485, 0.456, 0.406]
VGG16_STD = [0.229, 0.224, 0.225]


def normalizeVGG16(img):
    """Scale a 0-255 image tensor to [0, 1] and standardise it per channel.

    Computes ``(img / 255 - VGG16_MEAN) / VGG16_STD``, with the statistics
    broadcast over the last three dimensions as ``[C, 1, 1]``.

    Args:
        img: tensor whose channel axis is third from last (``[C, H, W]`` or
            ``[B, C, H, W]``) with values on a 0-255 scale. Integer tensors
            are converted to float32 first.

    Returns:
        A new normalised tensor. ``img`` itself is left unmodified; the
        previous in-place ``div_`` silently rescaled the caller's tensor.
    """
    if not img.is_floating_point():
        # Integer tensors cannot hold the fractional result (the in-place
        # division used to raise here), so promote them up front.
        img = img.float()
    mean = img.new_tensor(VGG16_MEAN).view(-1, 1, 1)
    std = img.new_tensor(VGG16_STD).view(-1, 1, 1)
    # Out-of-place division: never write through to the caller's tensor.
    scaled = img.div(255.0)
    return (scaled - mean) / std
# torchvision transform wrapper around normalizeVGG16 so it can be used as a
# step inside a transforms.Compose pipeline.
normalize = transforms.Lambda(lambda image: normalizeVGG16(image))
# Preprocessing pipeline: resize to IMG_SIZE, convert to a tensor, rescale by
# 255 (normalizeVGG16 divides by 255 again), then standardise per channel.
transform1 = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t * 255),
        normalize,
    ]
)
# Same pipeline as transform1, except the tensor is inverted (1 - value)
# before being rescaled by 255 and standardised.
transform2 = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (1 - t) * 255),
        normalize,
    ]
)
def gram_matrix(input):
    """Return the size-normalised Gram matrix of a 4-D feature tensor.

    The tensor is flattened to ``[a * b, c * d]`` (one row per leading-dims
    entry), multiplied by its own transpose and divided by the total number
    of elements ``a * b * c * d``.

    Args:
        input: 4-D tensor of shape ``[a, b, c, d]``. It may be
            non-contiguous (for example the result of ``permute``).

    Returns:
        Tensor of shape ``[a * b, a * b]``. Note that when ``a > 1`` the
        rows of all ``a`` entries share a single matrix.

    Raises:
        ValueError: if ``input`` does not have exactly four dimensions.
    """
    a, b, c, d = input.size()
    # reshape() instead of view(): view() raises on non-contiguous tensors,
    # while reshape() returns the same values and copies only when it must.
    features = input.reshape(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)
def warp(x, flo):
    """Warp an image/tensor (im2) back to im1 according to the optical flow.

    Output pixel ``(h, w)`` is bilinearly sampled from ``x`` at column
    ``w + flo[:, 0, h, w]`` and row ``h + flo[:, 1, h, w]``. Pixels whose
    sampling footprint is not (almost) entirely inside the image are set
    to zero.

    Works on whatever device ``x`` lives on (CPU or any CUDA device).
    Requires PyTorch >= 1.3 for the ``align_corners`` argument.

    Args:
        x: tensor of shape ``[B, C, H, W]`` (im2).
        flo: flow of shape ``[B, 2, H, W]`` in pixels; channel 0 is the
            horizontal displacement, channel 1 the vertical one.

    Returns:
        Tensor of shape ``[B, C, H, W]``: the warped ``x`` multiplied by a
        binary validity mask.
    """
    _, _, H, W = x.size()

    # Base pixel coordinates, built directly on x's device and shaped so that
    # they broadcast against the [B, H, W] flow channels.
    cols = torch.arange(W, dtype=torch.float32, device=x.device).view(1, 1, W)
    rows = torch.arange(H, dtype=torch.float32, device=x.device).view(1, H, 1)

    sample_x = cols + flo[:, 0]
    sample_y = rows + flo[:, 1]

    # Scale to [-1, 1] so that pixel 0 maps to -1 and pixel W-1 / H-1 maps to
    # +1. max(..., 1) avoids dividing by zero for one-pixel-wide inputs.
    sample_x = 2.0 * sample_x / max(W - 1, 1) - 1.0
    sample_y = 2.0 * sample_y / max(H - 1, 1) - 1.0
    vgrid = torch.stack((sample_x, sample_y), dim=-1)  # [B, H, W, 2]

    # align_corners=True matches the (size - 1) normalisation above; with the
    # modern default (False) a zero flow would not be an identity warp.
    output = nn.functional.grid_sample(x, vgrid, align_corners=True)

    # Warp an all-ones image the same way: values below 1 mean the sample
    # touched the zero padding outside the image. No gradient is needed for
    # the resulting binary mask.
    with torch.no_grad():
        coverage = nn.functional.grid_sample(
            torch.ones_like(x), vgrid, align_corners=True
        )
        mask = (coverage >= 0.9999).to(output.dtype)

    return output * mask