Comparing Variational Methods and Deep Neural Networks for image reconstruction problems

Summer school at the Alan Turing Institute

Project: How do convolutional neural networks perform for image reconstruction from indirect measurements?

Specific questions: dense networks versus convolutional networks; generalizability of the learned network; comparison with model-based compressed-sensing reconstruction (TV or wavelet regularization).

Variational Methods
Exact, Sampled, and Predicted Image (Left to Right). 64x64 image.

Training model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This is not my code, I give credit to Clarice Poon
Copyright (c) 2017 Clarice Poon
"""


from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import os

from ellipse_data import create_dataset
from masks import lines_mask, rand_mask
from networks import get_unet


# Image geometry: every image is img_rows x img_cols pixels, and the training
# set is built with patches of side `patch` (here the whole 64x64 image).
img_rows, img_cols = 64, 64
patch = 64

# Subsampling ratio handed to the mask generator (0.50 -> keep half).
s = 0.50

# A single sampling pattern, reused below for both the training and test sets.
maskindx = lines_mask(img_rows, img_cols, s)

# Training pairs: ground-truth images and the degraded inputs derived from the
# subsampled measurements. The leading 100 is presumably the number of images,
# and the trailing (10, 10) are create_dataset-specific options -- confirm
# against ellipse_data.create_dataset.
imgs_exact, imgs_fbp = create_dataset(
    100, img_rows, img_cols, patch, maskindx, 10, 10
)

# Build the U-Net and train it to map the degraded inputs (imgs_fbp) to the
# exact images (imgs_exact) under a mean-squared-error loss.
model = get_unet()
model.compile(optimizer=Adam(lr=1e-4), loss='mean_squared_error')

# Output directory for checkpoints. makedirs(exist_ok=True) replaces the
# racy "check isdir, then mkdir" pair.
savedir = 'results'
os.makedirs(savedir, exist_ok=True)

# Write weights only when val_loss improves, so the file always holds the
# best-validation model seen so far (not necessarily the last epoch's).
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join(savedir, 'weights.hdf5'),
    monitor='val_loss',
    save_best_only=True,
)

# validation_split=0.1 holds out the last 10% of the samples for val_loss
# (Keras selects them before shuffling); batch_size=1 is one image per update.
history = model.fit(
    imgs_fbp,
    imgs_exact,
    batch_size=1,
    epochs=100,
    verbose=1,
    shuffle=True,
    validation_split=0.1,
    callbacks=[model_checkpoint],
)



# Plot training and validation loss per epoch. The model is compiled with an
# MSE loss and no extra metrics, so history only records losses -- these
# curves are mean squared error, not accuracy.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('mean squared error')
plt.xlabel('epoch')
plt.legend(['training', 'validation'], loc='upper left')
plt.show()




## evaluate the model
# The checkpoint callback kept the weights with the lowest val_loss; reload
# them so evaluation reflects the best model rather than the last epoch.
# The file is only missing if val_loss never improved (e.g. it was NaN).
if os.path.isfile(model_checkpoint.filepath):
    model.load_weights(model_checkpoint.filepath)

# MSE over every training pair, including the 10% used for validation.
scores = model.evaluate(imgs_fbp, imgs_exact)
print('MSE on training data (incl. validation split):', scores)

# Prediction on training data. An unnumbered plt.figure() always opens a new
# figure; plt.figure(1) would reuse the loss plot if it is still open
# (interactive backends) and draw the image over it.
X2 = model.predict(imgs_fbp)
plt.figure()
plt.imshow(X2[1, :, :, 0])
plt.title('prediction: training image 1')

# Prediction on a freshly generated test set that uses the same sampling mask.
testimgs_exact, testimgs_fbp = create_dataset(
    2, img_rows, img_cols, img_rows, maskindx, 10, 0
)
output = model.predict(testimgs_fbp)

# Reconstruction error on the test set (same MSE loss as training).
test_scores = model.evaluate(testimgs_fbp, testimgs_exact, verbose=0)
print('MSE on test data:', test_scores)

plt.figure()
plt.imshow(output[1, :, :img_cols, 0])
plt.title('prediction: test image 1')
plt.show()