Comparing Variational Methods and Deep Neural Networks for image reconstruction problems
Summer school at the Alan Turing Institute
Project: How do convolutional neural networks perform for image reconstruction from indirect measurements?
Specific questions: dense networks versus convolutional networks; generalizability of the learned network; comparison with model-based compressed-sensing TV or wavelet reconstruction.
Training model
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This is not my code, I give credit to Clarice Poon
Copyright (c) 2017 Clarice Poon
"""
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import os
from ellipse_data import create_dataset
from masks import lines_mask, rand_mask
from networks import get_unet
# Image geometry: both sides are 64 pixels.
img_rows = img_cols = 64
# Patch size forwarded to create_dataset (equal to the image side here).
patch = 64

# Subsampling ratio handed to the mask generator.
s = 0.50

# Sampling mask produced by lines_mask for this image size and ratio.
maskindx = lines_mask(img_rows, img_cols, s)

# Training dataset: create_dataset returns a pair that is unpacked into the
# reference images (imgs_exact) and the network inputs (imgs_fbp).
# NOTE(review): presumably 100 is the number of samples and "fbp" stands for
# filtered back-projection; the meaning of the trailing 10, 10 arguments is
# not visible here -- confirm against ellipse_data.create_dataset.
imgs_exact, imgs_fbp = create_dataset(
    100, img_rows, img_cols, patch, maskindx, 10, 10
)
# Fit model: build the network returned by get_unet and train it to map the
# imgs_fbp inputs onto the imgs_exact targets with a mean-squared-error loss.
model = get_unet()
# NOTE(review): newer Keras releases name this argument `learning_rate`;
# `lr` is kept for the Keras version this script was written for -- confirm
# against the installed version.
model.compile(optimizer=Adam(lr=1e-4), loss='mean_squared_error')

# Checkpoint directory. exist_ok=True makes reruns safe and avoids the
# check-then-create race of testing for the directory before creating it.
savedir = 'results'
os.makedirs(savedir, exist_ok=True)

# save_best_only keeps just the weights with the best monitored val_loss.
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join(savedir, 'weights.hdf5'),
    monitor='val_loss',
    save_best_only=True,
)

history = model.fit(
    imgs_fbp,
    imgs_exact,
    batch_size=1,
    epochs=100,
    verbose=1,
    shuffle=True,
    validation_split=0.1,  # provides the val_loss the checkpoint monitors
    callbacks=[model_checkpoint],
)
# Summarize the training history: training and validation loss per epoch.
# The plotted series are the 'loss' and 'val_loss' entries, so the title and
# y-axis are labelled as loss rather than accuracy.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss', 'val_loss'], loc='upper left')
plt.show()
# Evaluate the model on the training data and report the result so the
# computed score is not silently dropped.
scores = model.evaluate(imgs_fbp, imgs_exact)
print('evaluation score on training data:', scores)

# Prediction on the training data; display the sample at index 1, taking
# index 0 along the last axis.
X2 = model.predict(imgs_fbp)
plt.figure(1)
plt.imshow(X2[1, :, :, 0])

# Prediction on newly generated test data.
# NOTE(review): presumably the leading 2 is the number of test samples and
# img_rows is passed as the patch size; the trailing 10, 0 arguments are
# opaque here -- confirm against ellipse_data.create_dataset.
testimgs_exact, testimgs_fbp = create_dataset(
    2, img_rows, img_cols, img_rows, maskindx, 10, 0
)
output = model.predict(testimgs_fbp)
# TODO: quantify the test error, e.g. a norm of testimgs_exact - output.
plt.figure(2)
plt.imshow(output[1, :, 0:img_cols, 0])
plt.show()