Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 02404a5d authored by Luigi Antelmi's avatar Luigi Antelmi
Browse files

add noisy example and plot results

parent ebab8c43
No related branches found
Tags v1.0.0
No related merge requests found
......@@ -6,6 +6,8 @@ import matplotlib.pyplot as plt
import mcvae.pytorch_modules
import mcvae.utilities
import mcvae.preprocessing
import mcvae.plot
import mcvae.diagnostics
DEVICE = mcvae.pytorch_modules.DEVICE
......@@ -16,6 +18,7 @@ n_channels = 3
n_feats = 4
true_lat_dims = 2
fit_lat_dims = 5
snr=10
np.random.seed(7)
z = np.random.randn(Nobs, true_lat_dims)
......@@ -30,9 +33,10 @@ generator = mcvae.pytorch_modules.ScenarioGenerator(
preprocpars = {'remove_mean': True, 'normalize': True, 'whitening': False}
x_ = generator(z)
x = mcvae.utilities.ltotensor(
mcvae.preprocessing.preprocess(x_, **preprocpars)
)
x, x_noisy = mcvae.utilities.preprocess_and_add_noise(x_, snr=snr)
#x = mcvae.utilities.ltotensor(
# mcvae.preprocessing.preprocess(x_, **preprocpars)
#)
# Send to GPU (if possible)
X = [c.to(DEVICE) for c in x] if torch.cuda.is_available() else x
......@@ -48,7 +52,7 @@ X = [c.to(DEVICE) for c in x] if torch.cuda.is_available() else x
init_dict = {
'n_channels': len(x),
'lat_dim': fit_lat_dims,
'n_feats': tuple([i.shape[1] for i in X])
'n_feats': tuple([i.shape[1] for i in X]),
}
adam_lr = 1e-3
......@@ -60,14 +64,14 @@ model = {}
torch.manual_seed(24)
model['mcvae'] = mcvae.pytorch_modules.MultiChannelBase(
**init_dict,
model_name_dict={**init_dict, 'adam_lr': adam_lr},
model_name_dict={**init_dict, 'adam_lr': adam_lr, 'snr': snr},
)
# Sparse Multi-Channel VAE
torch.manual_seed(24)
model['smcvae'] = mcvae.pytorch_modules.MultiChannelSparseVAE(
**init_dict,
model_name_dict={**init_dict, 'adam_lr': adam_lr},
model_name_dict={**init_dict, 'adam_lr': adam_lr, 'snr': snr},
)
for current_model in ['mcvae', 'smcvae']:
......@@ -100,12 +104,20 @@ for current_model in ['mcvae', 'smcvae']:
pred = {} # Prediction
z = {} # Latent Space
g = {} # Generative Parameters
x_hat = {} # reconstructed channels
for m in model.keys():
mcvae.diagnostics.plot_loss(model[m])
pred[m] = model[m](X)
x_hat[m] = model[m].reconstruct(pred[m])
z[m] = np.array([pred[m]['qzx'][i]['mu'].detach().numpy() for i in range(n_channels)]).reshape(-1)
g[m] = np.array([model[m].W_out[i].weight.detach().numpy() for i in range(n_channels)]).reshape(-1)
mcvae.plot.lsplom(mcvae.utilities.ltonumpy(x), title=f'Ground truth')
mcvae.plot.lsplom(mcvae.utilities.ltonumpy(x_noisy), title=f'Noisy data fitted by the models (snr={snr})')
for m in model.keys():
mcvae.plot.lsplom(mcvae.utilities.ltonumpy(x_hat[m]), title=f'Reconstructed with {m} model')
plt.figure()
plt.subplot(1,2,1)
plt.hist([z['smcvae'], z['mcvae']], bins=20, color=['k', 'gray'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment