import torch
import torch.nn as nn
import torch.nn.functional as F
from os import listdir
from os.path import isfile, join
from scipy.io import loadmat # for loading mat files
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.image as mpimg
import math
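# This script convolves 34 fixed spatio-temporal filters with natural images
# (BSDS500), accumulates each filter's mean activation (f_n) and the pairwise
# activation products at spatial offsets (f_nn), and converts them into a
# normalized correlation tensor W plus a sparse 43x43 version W_stat.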
def ToGrayscale(sample):
    # Collapse the color channels by summing over the last axis (a simple
    # grayscale approximation), then zero-mean and max-normalize the image
    sample = np.asarray(sample)
    sample = sample.sum(axis=-1)  # sum over last (channel) axis
    sample = sample - np.mean(sample)
    sample = sample / np.max(sample)  # divide by max over the image
    sample = np.pad(sample, (7, 7), 'symmetric')  # pad with symmetric borders (assumes 15x15 filters)
    # sample = np.expand_dims(sample, axis=-1)  # add extra channel dimension
    return sample
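# Note: for a 321x481x3 BSDS image the channel sum gives a 321x481 array and
# the symmetric 7-px pad yields 335x495, so a 'valid' 15x15 convolution in the
# network below returns a response map the size of the original image.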
class Net(nn.Module):
    def __init__(self, filters_temp, filters_notemp, num_rfs):
        super(Net, self).__init__()
        # Convert the numpy filter banks into fixed (non-trainable) PyTorch
        # conv3d weights of shape (out_channels, in_channels, T, H, W)
        self.filts_temp = nn.Parameter(torch.from_numpy(filters_temp).permute(1, 0, 2, 3, 4).float(),
                                       requires_grad=False)
        self.filts_notemp = nn.Parameter(torch.from_numpy(filters_notemp).permute(1, 0, 2, 3, 4).float(),
                                         requires_grad=False)
        self.num_rfs = num_rfs  # number of RFs over which to look for correlations
    def forward(self, x):
        # Define the spatial extent of the correlations
        corr_x = self.num_rfs * self.filts_temp.shape[3]
        corr_y = self.num_rfs * self.filts_temp.shape[4]
        num_filts = self.filts_temp.shape[0] + self.filts_notemp.shape[0]
        # Duplicate the static image into two identical temporal frames and
        # reshape to (batch, channels, T, H, W) for conv3d
        x = torch.squeeze(x)
        x_new = x.expand(2, x.size()[0], x.size()[1])
        x = x_new.view(1, 1, x_new.size()[0], x_new.size()[1], x_new.size()[2])
        x = x.float()
        # Convolve the filter banks with the input: the temporal filters see
        # both frames (hence the /2); the non-temporal ones only the last frame
        x_temp = F.relu(F.conv3d(x, self.filts_temp) / 2)
        x_notemp = F.relu(F.conv3d(x[:, :, -1, :, :].unsqueeze(2), self.filts_notemp))
        x = torch.cat((x_temp, x_notemp), dim=1).float()
        # Divisive normalization with an added eps in the denominator
        x1 = torch.div(x, torch.sum(x, dim=1).unsqueeze(1) + np.finfo(float).eps)
        x_max = x1.size()[4]
        y_max = x1.size()[3]
        # Crop the borders and reuse each filter's response map as a conv3d
        # kernel, so the convolution below yields all pairwise products
        x1_filts = x1[:, :, :, corr_y:y_max - corr_y, corr_x:x_max - corr_x].contiguous().view(
            num_filts, 1, 1, y_max - 2 * corr_y, x_max - 2 * corr_x)
        x1 = x1.view(1, 1, num_filts, y_max, x_max)
        x2 = F.conv3d(x1, x1_filts, groups=1)
        x2 = x2.squeeze().view(1, num_filts, num_filts, 2 * corr_y + 1, 2 * corr_x + 1)
        # Normalize by the size of the cropped response map used as the kernel
        # (e.g. 231x391 for a 481x321 BSDS image)
        x2 = torch.div(x2, (y_max - 2 * corr_y) * (x_max - 2 * corr_x))
        return x1, x2
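# Implementation note: using each filter's cropped response map as a conv3d
# kernel over the full maps computes, for every filter pair (i, j) and every
# offset (dy, dx) up to +/- corr, the sum of products f_i(p) * f_j(p + (dy, dx))
# over positions p in a single convolution; the final division turns that sum
# into an average.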
def train():
    # Despite its name, this runs inference only: the filters are fixed, and we
    # accumulate mean activations (f_n) and pairwise products (f_nn) into the
    # module-level lists fn_array / fnn_array
    model.eval()
    mypath = '/home/dvoina/ramsmatlabprogram/BSR_2/BSDS500/data/images/train/im_dir'
    # mypath = '/home/dvoina/vip_project/dir_for_videos'
    onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]
    for batch_idx in range(len(onlyfiles)):
        data = mpimg.imread(join(mypath, onlyfiles[batch_idx]))
        data = ToGrayscale(data)
        data = torch.from_numpy(data).float()
        if cuda:
            data = data.cuda()
        with torch.no_grad():  # no gradients are needed for inference
            x1, x2 = model(data)  # forward pass
        x1 = x1.view(1, 34, x1.size()[3], x1.size()[4])
        # f_n's: mean activation of each filter over the image
        filt_avgs = torch.mean(x1.data.view(1, x1.shape[1], -1), dim=2).squeeze()
        fn_array.append(filt_avgs.cpu().numpy())  # move back to cpu and convert to numpy
        # f_nn's: correlations on a window around the zero-offset center
        grid_space = 1  # can also choose 7 (original)
        x2_subset = x2[:, :, :, (45 - 21):(45 + 21 + 1):grid_space,
                       (45 - 21):(45 + 21 + 1):grid_space].data.squeeze()  # Python slices exclude the end, so add 1
        fnn_array.append(x2_subset.cpu().numpy())
    return np.asarray(fn_array), np.asarray(fnn_array)
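# Note: with num_rfs=3 and 15x15 filters, corr_x = corr_y = 45, so x2 spans
# offsets -45..+45 (91x91) and index 45 is the zero-offset center; the
# 45 +/- 21 window in train() keeps offsets within three 7-px grid steps.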
# First, load the filters (34 spatio-temporal filters)
filters34_2 = loadmat('data_filts_px3_v2.mat')
filters34_temp = np.array(filters34_2['data_filts2'][0, 0:16].tolist())
filters34_temp = np.expand_dims(filters34_temp, axis=0)
filters34_temp = np.transpose(filters34_temp, (0, 1, 4, 2, 3))
filters34_notemp = np.array(filters34_2['data_filts2'][0, 16:34].tolist())
filters34_notemp = np.expand_dims(filters34_notemp, axis=1)
filters34_notemp = np.expand_dims(filters34_notemp, axis=0)
# Zero-mean each filter (using numpy broadcasting: move the filter axis last,
# subtract each filter's mean, then move the axes back)
filters34_temp = np.transpose(
    np.transpose(filters34_temp, (0, 2, 3, 4, 1))
    - filters34_temp.reshape(1, filters34_temp.shape[1], -1).mean(axis=2),
    (0, 4, 1, 2, 3))
filters34_notemp = np.transpose(
    np.transpose(filters34_notemp, (0, 2, 3, 4, 1))
    - filters34_notemp.reshape(1, filters34_notemp.shape[1], -1).mean(axis=2),
    (0, 4, 1, 2, 3))
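# Expected shapes at this point (assuming 15x15 filters and 2 temporal frames
# in the .mat file): filters34_temp is (1, 16, 2, 15, 15) and filters34_notemp
# is (1, 18, 1, 15, 15); Net permutes these into conv3d weight layout.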
# Runtime settings (no learning takes place; the filters are fixed)
cuda = torch.cuda.is_available()  # use the GPU and CUDA if available
batch_size = 1  # input batch size (TODO: figure out how to group images with similar orientation)
# Create a new instance of the network
model = Net(filters34_temp, filters34_notemp, num_rfs=3)
if cuda:
    model.cuda()
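# A minimal sanity check on a hypothetical random 100x100 test image (an
# assumption, not part of the pipeline): given 15x15 filters and num_rfs=3,
# x1 should come out (1, 1, 34, 100, 100) and x2 (1, 34, 34, 91, 91)
_test = torch.from_numpy(ToGrayscale(np.random.rand(100, 100, 3))).float()
if cuda:
    _test = _test.cuda()
with torch.no_grad():
    _x1, _x2 = model(_test)
print(_x1.shape, _x2.shape)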
# Use a list for saving the activations for each image
fn_array = []
fnn_array = []
filt_avgs, fnn_avgs = train()
# Average f_n and f_nn over all images
filt_avgs_images = np.mean(filt_avgs, axis=0)
fnn_avgs_images = np.mean(fnn_avgs, axis=0)
# Normalize each pairwise product by the product of the mean activations
W = np.empty(fnn_avgs_images.shape)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        W[i, j, :] = fnn_avgs_images[i, j, :].squeeze() / (filt_avgs_images[i] * filt_avgs_images[j]) - 1
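# Equivalently, W[i, j, dy, dx] = <f_i(p) f_j(p + (dy, dx))> / (<f_i><f_j>) - 1,
# so W is zero where filter responses are independent and positive where they
# co-occur more often than chance.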
def construct_row4(w, dim, flag):
    # Embed a 7x7 grid of coarse weights into a fine (Nx x Ny) grid, placing
    # each weight at the corresponding RF-center location
    Nx = dim[0]
    Ny = dim[1]
    center2 = int(math.floor(Ny / 2))
    # grid1 = np.concatenate((np.array(range(center2 - 3*7, center2, 7)), np.array(range(center2, center2 + 4*7, 7))))
    # grid2 = np.concatenate((np.array(range(center2 - 3*7, center2, 7)), np.array(range(center2, center2 + 4*7, 7))))
    grid1 = [4, 11, 18, 21, 24, 31, 38]
    grid2 = grid1
    W_fine = np.zeros((Nx, Ny))
    for nx in range(7):
        for ny in range(7):
            W_fine[grid1[nx], grid2[ny]] = w[nx, ny]
            # Optionally zero out the self (center) connection
            if (nx == 3) and (ny == 3) and (flag == 1):
                W_fine[grid1[nx], grid2[ny]] = 0
    return W_fine
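# The grid [4, 11, 18, 21, 24, 31, 38] places RF centers 7 px apart on the
# 43x43 fine grid, with finer 3-px steps around the center (18, 21, 24),
# presumably the '3px' referenced in the filter and output file names.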
# W_stat2 = W[:, :, range(0, 43, 7), :]
# W_stat3 = W_stat2[:, :, :, range(0, 43, 7)]
W_stat2 = W[:, :, [4, 11, 18, 21, 24, 31, 38], :]
W_stat3 = W_stat2[:, :, :, [4, 11, 18, 21, 24, 31, 38]]
flag = 1
dim = [43, 43]
NF = 34
W1_stat = np.zeros((NF, NF, dim[0], dim[1]))
for f1 in range(NF):
    for f2 in range(NF):
        W1_stat[f1, f2, :, :] = construct_row4(W_stat3[f1, f2, :, :], dim, flag)
W_stat = W1_stat
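# W keeps the full 43x43 grid of offsets for every filter pair, while W_stat
# keeps only the 7x7 RF-center entries, re-embedded into 43x43 with the center
# (self) connection zeroed out (flag = 1).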
np.save('/home/dvoina/simple_vids/results/W_43x43_34filters_static_simple_3px_ReviewComplete2.npy', W)
np.save('/home/dvoina/simple_vids/results/W_43x43_34filters_static_simple_3px_ReviewSparse2.npy', W_stat)