import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from os import listdir
from os.path import isfile, join
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.image as mpimg
import math
def ToGrayscale(sample):
    """Convert an image to a zero-mean, max-normalised, padded 2-D array.

    Colour channels (the last axis) are summed. An input that is already
    2-D is treated as single-channel and left as is; previously it was
    collapsed to 1-D. The mean is removed, the result is divided by its
    maximum, and 7 pixels are added on every side by symmetric reflection.

    Args:
        sample: image array-like of shape (H, W) or (H, W, C).

    Returns:
        np.ndarray of shape (H + 14, W + 14). For a constant image
        (maximum 0 after mean removal) the result is all zeros rather than
        the NaNs a 0/0 division would give.
    """
    sample = np.asarray(sample)
    if sample.ndim > 2:
        # Collapse colour channels into a single intensity channel.
        sample = sample.sum(axis=-1)
    sample = sample - np.mean(sample)
    # After mean removal the maximum is >= 0; it is 0 only for a flat image.
    peak = np.max(sample)
    if peak > 0:
        sample = sample / peak
    return np.pad(sample, (7, 7), 'symmetric')
class Net(nn.Module):
    """Fixed filter bank followed by pairwise spatial correlation of responses.

    No parameter is trainable: both filter banks are stored as
    ``nn.Parameter`` with ``requires_grad=False``.
    """

    def __init__(self, filters_temp, filters_notemp, num_rfs):
        """Store the filter banks as conv3d weights.

        Args:
            filters_temp: numpy array for the spatio-temporal filters. Axes 0
                and 1 are swapped, so the stored weight has shape
                (shape[1], shape[0], shape[2], shape[3], shape[4]). This
                presumably expects (1, n_filters, depth, H, W); TODO confirm.
            filters_notemp: numpy array for the purely spatial filters, with
                the same axis handling.
            num_rfs: multiplier on the filter size that sets the maximum
                correlation displacement computed in ``forward``.
        """
        super().__init__()
        self.filts_temp = nn.Parameter(
            torch.from_numpy(filters_temp).permute(1, 0, 2, 3, 4).float(),
            requires_grad=False)
        self.filts_notemp = nn.Parameter(
            torch.from_numpy(filters_notemp).permute(1, 0, 2, 3, 4).float(),
            requires_grad=False)
        self.num_rfs = num_rfs

    def forward(self, x):
        """Compute normalised filter responses and their pairwise correlations.

        Args:
            x: image tensor that squeezes to 2-D (H, W).

        Returns:
            (x1, x2)
            x1: shape (1, 1, num_filts, Ho, Wo). Rectified filter responses,
                each pixel divided by its sum over all filters (plus eps).
            x2: shape (1, num_filts, num_filts, 2*corr_y + 1, 2*corr_x + 1).
                For every filter pair, the central crop of one response map is
                slid over the other map. The result is divided by the number
                of pixels in the crop.
        """
        # NOTE(review): corr_x comes from weight dim 3 but crops the last
        # (width) axis, and corr_y the reverse. This only matters for
        # non-square filters.
        corr_x = self.num_rfs * self.filts_temp.shape[3]
        corr_y = self.num_rfs * self.filts_temp.shape[4]
        num_filts = self.filts_temp.shape[0] + self.filts_notemp.shape[0]

        # Repeat the static image as two identical frames -> (1, 1, 2, H, W).
        x = torch.squeeze(x)
        frames = x.expand(2, x.size(0), x.size(1))
        x = frames.view(1, 1, frames.size(0), frames.size(1), frames.size(2)).float()

        x_temp = F.relu(F.conv3d(x, self.filts_temp) / 2)
        # The spatial-only filters see just the last frame.
        x_notemp = F.relu(F.conv3d(x[:, :, x.size(2) - 1, :, :].unsqueeze(2),
                                   self.filts_notemp))
        # NOTE(review): the concat requires equal output depth, i.e. temporal
        # filters whose depth matches the 2 frames built above.
        x = torch.cat((x_temp, x_notemp), dim=1)

        # Normalise each pixel's responses across filters; eps avoids 0/0.
        x1 = torch.div(x, torch.sum(x, dim=1).unsqueeze(1) + np.finfo(float).eps)
        x_max = x1.size(4)
        y_max = x1.size(3)
        crop_h = y_max - 2 * corr_y
        crop_w = x_max - 2 * corr_x

        # Central crop of every response map becomes a conv kernel.
        # Output channel f correlates kernel f with input map d (depth axis).
        x1_filts = x1[:, :, :, corr_y:y_max - corr_y, corr_x:x_max - corr_x] \
            .contiguous().view(num_filts, 1, 1, crop_h, crop_w)
        x1 = x1.view(1, 1, num_filts, y_max, x_max)
        x2 = F.conv3d(x1, x1_filts, groups=1)
        x2 = x2.squeeze().view(1, num_filts, num_filts, 2 * corr_y + 1, 2 * corr_x + 1)
        # Average over the kernel's pixels. The width term must use corr_x,
        # matching the crop above; it previously used corr_y.
        x2 = torch.div(x2, crop_h * crop_w)
        return x1, x2
def train(mypath='/home/dvoina/ramsmatlabprogram/BSR_2/BSDS500/data/images/train/im_dir'):
    """Run the module-level ``model`` over every regular file in ``mypath``.

    Relies on the module globals ``model`` (a ``Net``) and ``cuda`` (bool),
    and on ``ToGrayscale`` for preprocessing. Results are collected in local
    lists, so repeated calls no longer return data accumulated by earlier
    calls. Directories are skipped. Any other non-image file in the folder
    will still make ``mpimg.imread`` fail.

    Args:
        mypath: directory of input images (defaults to the original path).

    Returns:
        (fn, fnn)
        fn: array of shape (n_images, num_filts). Spatial mean of each
            filter's normalised response ``x1``.
        fnn: array of shape (n_images, num_filts, num_filts, 43, 43). The
            central 43x43 window (21 either side of zero displacement) of the
            pairwise correlations ``x2``.
    """
    model.eval()
    half = 21        # half-width of the displacement window kept from x2
    grid_space = 1   # stride inside that window
    fn_list = []
    fnn_list = []
    # Sorted for a reproducible order.
    image_files = sorted(f for f in listdir(mypath) if isfile(join(mypath, f)))
    for fname in image_files:
        data = ToGrayscale(mpimg.imread(join(mypath, fname)))
        data = torch.from_numpy(data).float()
        if cuda:
            data = data.cuda()
        with torch.no_grad():
            x1, x2 = model(data)
            num_filts = x1.size(2)
            x1 = x1.view(1, num_filts, x1.size(3), x1.size(4))
            filt_avgs = torch.mean(x1.view(1, num_filts, -1), dim=2).squeeze()
            fn_list.append(filt_avgs.cpu().numpy())
            # x2's spatial axes have size 2*corr + 1, so size // 2 is zero displacement.
            cy = x2.size(3) // 2
            cx = x2.size(4) // 2
            x2_subset = x2[:, :, :,
                           cy - half:cy + half + 1:grid_space,
                           cx - half:cx + half + 1:grid_space].squeeze()
            fnn_list.append(x2_subset.cpu().numpy())
    return np.asarray(fn_list), np.asarray(fnn_list)
# --- Filter bank ------------------------------------------------------------
# Cells 0..15 of 'data_filts2' feed Net's spatio-temporal branch and cells
# 16..33 its purely spatial branch.
filters34_2 = loadmat('data_filts_px3_v2.mat')
filter_cells = filters34_2['data_filts2']

# (n, a, b, c) -> (1, n, c, a, b): add a leading singleton axis and move the
# last source axis in front of the other two (presumably time; TODO confirm).
filters34_temp = np.array(filter_cells[0, 0:16].tolist())[np.newaxis].transpose(0, 1, 4, 2, 3)
# (n, a, b) -> (1, n, 1, a, b)
filters34_notemp = np.array(filter_cells[0, 16:34].tolist())[np.newaxis, :, np.newaxis]

# Make every individual filter zero-mean (mean over all of its elements).
temp_means = filters34_temp.reshape(1, filters34_temp.shape[1], -1).mean(axis=2)
filters34_temp = filters34_temp - temp_means[:, :, np.newaxis, np.newaxis, np.newaxis]
notemp_means = filters34_notemp.reshape(1, filters34_notemp.shape[1], -1).mean(axis=2)
filters34_notemp = filters34_notemp - notemp_means[:, :, np.newaxis, np.newaxis, np.newaxis]

# --- Model and image statistics ----------------------------------------------
cuda = torch.cuda.is_available()
batch_size = 1
model = Net(filters34_temp, filters34_notemp, num_rfs=3)
if cuda:
    model.cuda()
fn_array = []
fnn_array = []
filt_avgs, fnn_avgs = train()
filt_avgs_images = filt_avgs.mean(axis=0)
fnn_avgs_images = fnn_avgs.mean(axis=0)

# W[i, j] = fnn_avgs_images[i, j] / (filt_avgs_images[i] * filt_avgs_images[j]) - 1,
# computed for every filter pair at once and stored as float64.
pair_norm = np.outer(filt_avgs_images, filt_avgs_images)[:, :, np.newaxis, np.newaxis]
W = (fnn_avgs_images / pair_norm - 1).astype(np.float64)
def construct_row4(w, dim, flag, grid=None):
    """Scatter a small square patch onto an otherwise-zero 2-D array.

    ``w[a, b]`` is written to ``out[grid[a], grid[b]]``. Every other entry
    of the ``dim[0] x dim[1]`` output is zero.

    Args:
        w: 2-D array-like of at least ``len(grid) x len(grid)``. Only its
            leading ``len(grid) x len(grid)`` corner is used.
        dim: output shape ``(Nx, Ny)``.
        flag: when equal to 1, the centre lattice point
            ``(grid[n // 2], grid[n // 2])`` is set to 0 instead of taking
            its value from ``w``.
        grid: row/column indices to scatter to. Defaults to the original
            lattice ``[4, 11, 18, 21, 24, 31, 38]``.

    Returns:
        float64 np.ndarray of shape ``(dim[0], dim[1])``.
    """
    if grid is None:
        grid = [4, 11, 18, 21, 24, 31, 38]
    n = len(grid)
    W_fine = np.zeros((dim[0], dim[1]))
    # np.ix_ builds the outer-product index, placing w[a, b] at (grid[a], grid[b]).
    W_fine[np.ix_(grid, grid)] = np.asarray(w)[:n, :n]
    if flag == 1 and n:
        mid = grid[n // 2]
        W_fine[mid, mid] = 0
    return W_fine
# Sample W on a 7x7 lattice of displacements, then let construct_row4 place
# each pair's 7x7 samples back onto a 43x43 map. flag = 1 tells it to zero
# the centre lattice point.
lattice = [4, 11, 18, 21, 24, 31, 38]
W_stat2 = W[:, :, lattice, :]
W_stat3 = W_stat2[:, :, :, lattice]
flag = 1
dim = [43, 43]
NF = 34
W1_stat = np.zeros((NF, NF, dim[0], dim[1]))
for f1, f2 in np.ndindex(NF, NF):
    W1_stat[f1, f2, :, :] = construct_row4(W_stat3[f1, f2, :, :], dim, flag)
W_stat = W1_stat

np.save('/home/dvoina/simple_vids/results/W_43x43_34filters_static_simple_3px_ReviewComplete2.npy', W)
np.save('/home/dvoina/simple_vids/results/W_43x43_34filters_static_simple_3px_ReviewSparse2.npy', W_stat)