This notebook shows how to define a variational autoencoder. We will use the Keras library and the MNIST dataset. But first, let us look at a basic autoencoder.

This practical makes intensive use of the GPU, so it is important to enable GPU support from the Notebook settings option in the Edit menu (it should be enabled by default, but it is worth checking).

A basic autoencoder

MNIST is a dataset widely used to test machine learning algorithms. The task usually associated with MNIST is classifying grayscale images of handwritten digits (of size 28x28) into 10 categories (0 to 9). In our case we will not build a classifier for MNIST; instead, we will use it to show how an autoencoder can remove noise from the images.

We start by loading the necessary libraries.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt

Dataset

To train the autoencoder we will use the MNIST dataset, with pixel values normalized between 0 and 1.

In this case we are not interested in the dataset labels, only in the images.

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step
11501568/11490434 [==============================] - 1s 0us/step

We normalize the dataset and prepare it so it can be fed to the autoencoder.

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))
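
As a quick optional check, we can print the shapes of the resulting arrays; MNIST contains 60,000 training images and 10,000 test images.

# Shapes after normalization and reshaping.
print(x_train.shape)  # (60000, 28, 28, 1)
print(x_test.shape)   # (10000, 28, 28, 1)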

Autoencoder

We now define the architecture of our autoencoder. The first thing we must do is define the shape of its input data. MNIST consists of grayscale images of size $28\times 28$, so they have a single channel.

input_img = layers.Input(shape=(28,28,1))

Recall that an autoencoder consists of an encoder and a decoder. We now define our encoder, which consists of a stack of convolution and max-pooling layers. After the encoding process we reach a final representation of size (7, 7, 32): the spatial resolution has been reduced from 28x28 down to 7x7, although the number of channels has grown to 32 (we can verify this shape right after defining the layers).

x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
# Max-pooling layer with a 2x2 window and 'same' padding
x = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = layers.MaxPooling2D((2, 2), padding='same')(x)
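
As an optional check, we can inspect the shape of the encoded tensor to confirm the (7, 7, 32) representation mentioned above.

# The encoded tensor should have spatial size 7x7 with 32 channels.
print(encoded.shape)  # expected: (None, 7, 7, 32)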

The decoder is defined as a stack of convolution and upsampling layers (upsampling performs the inverse operation of pooling). Note that the input of the first decoder layer is the output of the encoder, and that the architecture mirrors that of the encoder.

x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
# Upsampling layer with a 2x2 factor (the inverse of max pooling)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = layers.UpSampling2D((2, 2))(x)
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

Finally, we define our autoencoder model and compile it. In Keras, a model must be compiled in order to set the optimizer that will be used to train it (in this case Adam, a variant of gradient descent) and the loss function (in this case binary cross-entropy).

autoencoder = keras.Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

The following instruction displays the architecture of a Keras network.

autoencoder.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 28, 28, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 14, 14, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 14, 14, 32)        9248      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 7, 7, 32)         0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 7, 7, 32)          9248      
                                                                 
 up_sampling2d (UpSampling2D  (None, 14, 14, 32)       0         
 )                                                               
                                                                 
 conv2d_3 (Conv2D)           (None, 14, 14, 32)        9248      
                                                                 
 up_sampling2d_1 (UpSampling  (None, 28, 28, 32)       0         
 2D)                                                             
                                                                 
 conv2d_4 (Conv2D)           (None, 28, 28, 1)         289       
                                                                 
=================================================================
Total params: 28,353
Trainable params: 28,353
Non-trainable params: 0
_________________________________________________________________

It can also be useful to visualize the network.

from tensorflow.keras.utils import plot_model
plot_model(autoencoder, to_file='autoencoder_plot.png', show_shapes=True, show_layer_names=True)

Let us now train our model; to do so we use the fit method, which is available for any Keras model.

autoencoder.fit(x_train, x_train,
                epochs=50,
                batch_size=128,
                validation_data=(x_test, x_test))
Epoch 1/50
469/469 [==============================] - 19s 18ms/step - loss: 0.1155 - val_loss: 0.0767
Epoch 2/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0742 - val_loss: 0.0715
Epoch 3/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0712 - val_loss: 0.0700
Epoch 4/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0697 - val_loss: 0.0686
Epoch 5/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0688 - val_loss: 0.0679
Epoch 6/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0680 - val_loss: 0.0673
Epoch 7/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0675 - val_loss: 0.0667
Epoch 8/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0669 - val_loss: 0.0663
Epoch 9/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0666 - val_loss: 0.0662
Epoch 10/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0662 - val_loss: 0.0657
Epoch 11/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0659 - val_loss: 0.0653
Epoch 12/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0657 - val_loss: 0.0653
Epoch 13/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0655 - val_loss: 0.0649
Epoch 14/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0653 - val_loss: 0.0647
Epoch 15/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0651 - val_loss: 0.0647
Epoch 16/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0649 - val_loss: 0.0645
Epoch 17/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0648 - val_loss: 0.0644
Epoch 18/50
469/469 [==============================] - 9s 19ms/step - loss: 0.0647 - val_loss: 0.0643
Epoch 19/50
469/469 [==============================] - 8s 16ms/step - loss: 0.0645 - val_loss: 0.0645
Epoch 20/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0644 - val_loss: 0.0640
Epoch 21/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0643 - val_loss: 0.0641
Epoch 22/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0642 - val_loss: 0.0638
Epoch 23/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0641 - val_loss: 0.0638
Epoch 24/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0640 - val_loss: 0.0636
Epoch 25/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0640 - val_loss: 0.0635
Epoch 26/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0639 - val_loss: 0.0635
Epoch 27/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0638 - val_loss: 0.0635
Epoch 28/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0638 - val_loss: 0.0634
Epoch 29/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0637 - val_loss: 0.0634
Epoch 30/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0637 - val_loss: 0.0633
Epoch 31/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0636 - val_loss: 0.0632
Epoch 32/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0636 - val_loss: 0.0632
Epoch 33/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0635 - val_loss: 0.0631
Epoch 34/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0635 - val_loss: 0.0631
Epoch 35/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0634 - val_loss: 0.0630
Epoch 36/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0634 - val_loss: 0.0631
Epoch 37/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0634 - val_loss: 0.0630
Epoch 38/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0633 - val_loss: 0.0630
Epoch 39/50
469/469 [==============================] - 7s 15ms/step - loss: 0.0633 - val_loss: 0.0629
Epoch 40/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0633 - val_loss: 0.0629
Epoch 41/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0632 - val_loss: 0.0629
Epoch 42/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0632 - val_loss: 0.0629
Epoch 43/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0632 - val_loss: 0.0629
Epoch 44/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0632 - val_loss: 0.0628
Epoch 45/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0631 - val_loss: 0.0627
Epoch 46/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0631 - val_loss: 0.0628
Epoch 47/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0631 - val_loss: 0.0627
Epoch 48/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0631 - val_loss: 0.0628
Epoch 49/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0630 - val_loss: 0.0626
Epoch 50/50
469/469 [==============================] - 7s 14ms/step - loss: 0.0630 - val_loss: 0.0627
<keras.callbacks.History at 0x7f0a900f1610>

Let us display the reconstruction of some of the digits. When the following cell is run, the first row shows the original digits and the second row the reconstructed ones.

decoded_imgs = autoencoder.predict(x_test)

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # Original digits in the first row.
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Reconstructed digits in the second row.
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
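
Beyond visual inspection, an optional quantitative check is the mean squared error between the original test images and their reconstructions:

# Mean squared error between originals and reconstructions (lower is better).
mse = np.mean((x_test - decoded_imgs) ** 2)
print(f"Reconstruction MSE on the test set: {mse:.4f}")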

Application to noise removal

Let us now put our autoencoder to work on the problem of noise removal. This will be as simple as training the autoencoder to map noisy digits to clean images.

To do this, the first step is to build a noisy dataset by adding Gaussian noise to the images and then clipping the values to the range 0 to 1.

We first add noise to the images of the MNIST dataset.

noise_factor = 0.5
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) 
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) 

x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

Below we can see some of the images after noise has been added.

n = 10
plt.figure(figsize=(20, 2))
for i in range(1, n + 1):
    ax = plt.subplot(1, n, i)
    plt.imshow(x_test_noisy[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

Autoencoder

We now define the architecture of this autoencoder. As before, the first thing we must do is define the shape of its input data: MNIST consists of grayscale images of size $28\times 28$, so they have a single channel.

input_img = layers.Input(shape=(28,28,1))

Recall that an autoencoder consists of an encoder and a decoder. We now define our encoder, which consists of a stack of convolution and max-pooling layers.

x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = layers.MaxPooling2D((2, 2), padding='same')(x)

The decoder is defined as a stack of convolution and upsampling layers (upsampling performs the inverse operation of pooling). Note that the input of the first decoder layer is the output of the encoder and, as before, the architecture mirrors that of the encoder.

x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = layers.UpSampling2D((2, 2))(x)
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

Finally, we define our autoencoder model and compile it.

autoencoder_noise = keras.Model(input_img, decoded)
autoencoder_noise.compile(optimizer='adam', loss='binary_crossentropy')

The following instruction displays the architecture of the new autoencoder.

autoencoder_noise.summary()
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_5 (Conv2D)           (None, 28, 28, 32)        320       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 14, 14, 32)       0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 14, 14, 32)        9248      
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 7, 7, 32)         0         
 2D)                                                             
                                                                 
 conv2d_7 (Conv2D)           (None, 7, 7, 32)          9248      
                                                                 
 up_sampling2d_2 (UpSampling  (None, 14, 14, 32)       0         
 2D)                                                             
                                                                 
 conv2d_8 (Conv2D)           (None, 14, 14, 32)        9248      
                                                                 
 up_sampling2d_3 (UpSampling  (None, 28, 28, 32)       0         
 2D)                                                             
                                                                 
 conv2d_9 (Conv2D)           (None, 28, 28, 1)         289       
                                                                 
=================================================================
Total params: 28,353
Trainable params: 28,353
Non-trainable params: 0
_________________________________________________________________

It can also be useful to visualize the network.

plot_model(autoencoder_noise, to_file='autoencoder_noise_plot.png', show_shapes=True, show_layer_names=True)

Let us now train our model; to do so we use the fit method, which is available for any Keras model.

autoencoder_noise.fit(x_train_noisy, x_train,
                epochs=100,
                batch_size=128,
                validation_data=(x_test_noisy, x_test))
Epoch 1/100
469/469 [==============================] - 8s 15ms/step - loss: 0.1726 - val_loss: 0.1164
Epoch 2/100
469/469 [==============================] - 7s 14ms/step - loss: 0.1134 - val_loss: 0.1092
Epoch 3/100
469/469 [==============================] - 6s 14ms/step - loss: 0.1083 - val_loss: 0.1059
Epoch 4/100
469/469 [==============================] - 7s 14ms/step - loss: 0.1056 - val_loss: 0.1041
Epoch 5/100
469/469 [==============================] - 7s 14ms/step - loss: 0.1037 - val_loss: 0.1021
Epoch 6/100
469/469 [==============================] - 7s 14ms/step - loss: 0.1023 - val_loss: 0.1009
Epoch 7/100
469/469 [==============================] - 7s 14ms/step - loss: 0.1011 - val_loss: 0.0998
Epoch 8/100
469/469 [==============================] - 7s 14ms/step - loss: 0.1001 - val_loss: 0.0990
Epoch 9/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0993 - val_loss: 0.0982
Epoch 10/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0987 - val_loss: 0.0979
Epoch 11/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0983 - val_loss: 0.0976
Epoch 12/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0979 - val_loss: 0.0973
Epoch 13/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0975 - val_loss: 0.0967
Epoch 14/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0971 - val_loss: 0.0972
Epoch 15/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0969 - val_loss: 0.0963
Epoch 16/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0966 - val_loss: 0.0964
Epoch 17/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0964 - val_loss: 0.0958
Epoch 18/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0962 - val_loss: 0.0955
Epoch 19/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0959 - val_loss: 0.0957
Epoch 20/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0958 - val_loss: 0.0954
Epoch 21/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0957 - val_loss: 0.0952
Epoch 22/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0956 - val_loss: 0.0954
Epoch 23/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0954 - val_loss: 0.0951
Epoch 24/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0953 - val_loss: 0.0948
Epoch 25/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0952 - val_loss: 0.0948
Epoch 26/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0951 - val_loss: 0.0953
Epoch 27/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0950 - val_loss: 0.0948
Epoch 28/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0949 - val_loss: 0.0947
Epoch 29/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0949 - val_loss: 0.0945
Epoch 30/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0948 - val_loss: 0.0944
Epoch 31/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0947 - val_loss: 0.0944
Epoch 32/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0947 - val_loss: 0.0945
Epoch 33/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0946 - val_loss: 0.0943
Epoch 34/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0945 - val_loss: 0.0945
Epoch 35/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0945 - val_loss: 0.0944
Epoch 36/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0945 - val_loss: 0.0943
Epoch 37/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0944 - val_loss: 0.0941
Epoch 38/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0944 - val_loss: 0.0941
Epoch 39/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0943 - val_loss: 0.0942
Epoch 40/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0942 - val_loss: 0.0941
Epoch 41/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0943 - val_loss: 0.0942
Epoch 42/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0942 - val_loss: 0.0940
Epoch 43/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0942 - val_loss: 0.0939
Epoch 44/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0941 - val_loss: 0.0941
Epoch 45/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0941 - val_loss: 0.0938
Epoch 46/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0941 - val_loss: 0.0941
Epoch 47/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0941 - val_loss: 0.0941
Epoch 48/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0940 - val_loss: 0.0939
Epoch 49/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0940 - val_loss: 0.0938
Epoch 50/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0939 - val_loss: 0.0939
Epoch 51/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0940 - val_loss: 0.0938
Epoch 52/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0939 - val_loss: 0.0939
Epoch 53/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0938 - val_loss: 0.0939
Epoch 54/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0938 - val_loss: 0.0938
Epoch 55/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0939 - val_loss: 0.0937
Epoch 56/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0938 - val_loss: 0.0938
Epoch 57/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0938 - val_loss: 0.0939
Epoch 58/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0938 - val_loss: 0.0939
Epoch 59/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0937 - val_loss: 0.0937
Epoch 60/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0937 - val_loss: 0.0937
Epoch 61/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0937 - val_loss: 0.0937
Epoch 62/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0937 - val_loss: 0.0936
Epoch 63/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0937 - val_loss: 0.0936
Epoch 64/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0937 - val_loss: 0.0938
Epoch 65/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0937 - val_loss: 0.0937
Epoch 66/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0936 - val_loss: 0.0937
Epoch 67/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0936 - val_loss: 0.0939
Epoch 68/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0936 - val_loss: 0.0936
Epoch 69/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0936 - val_loss: 0.0936
Epoch 70/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0936 - val_loss: 0.0936
Epoch 71/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0936 - val_loss: 0.0938
Epoch 72/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0936 - val_loss: 0.0938
Epoch 73/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0936
Epoch 74/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0936
Epoch 75/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0937
Epoch 76/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0938
Epoch 77/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0937
Epoch 78/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0935
Epoch 79/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0937
Epoch 80/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0935 - val_loss: 0.0936
Epoch 81/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0934 - val_loss: 0.0934
Epoch 82/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0935 - val_loss: 0.0940
Epoch 83/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0942
Epoch 84/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0936
Epoch 85/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0934 - val_loss: 0.0934
Epoch 86/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0934 - val_loss: 0.0936
Epoch 87/100
469/469 [==============================] - 7s 15ms/step - loss: 0.0934 - val_loss: 0.0934
Epoch 88/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0934
Epoch 89/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0935
Epoch 90/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0934
Epoch 91/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0935
Epoch 92/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 93/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 94/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0934 - val_loss: 0.0938
Epoch 95/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 96/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 97/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 98/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 99/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0935
Epoch 100/100
469/469 [==============================] - 7s 14ms/step - loss: 0.0933 - val_loss: 0.0934
<keras.callbacks.History at 0x7f0a234423d0>

Finally, we can run the trained model on the test images:

decoded_imgs = autoencoder_noise.predict(x_test_noisy)

And then display the result.

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # Noisy inputs in the first row.
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test_noisy[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Denoised reconstructions in the second row.
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
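
As an optional check, we can also quantify the denoising effect by comparing the error of the noisy inputs and the error of the reconstructions, both measured against the clean test images; the reconstructions should be much closer to the clean digits than the noisy inputs are.

# Error of the noisy inputs vs. error of the denoised reconstructions,
# both with respect to the clean test images.
print("Noisy vs clean MSE:   ", np.mean((x_test - x_test_noisy) ** 2))
print("Denoised vs clean MSE:", np.mean((x_test - decoded_imgs) ** 2))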

Variational autoencoder

As we saw in the theory, an autoencoder takes an image, maps it to a latent vector space through an encoder, and then decodes it back to an output with the same size as the original image. In practice, plain autoencoders are not especially useful, and they have largely been replaced by variational autoencoders (VAEs).

Instead of compressing the image into a fixed vector of the latent space, a VAE turns it into the parameters of a statistical distribution (represented by a mean and a variance). Essentially, this means that we assume the original image was generated by a statistical process, and that the randomness of this process must be taken into account during encoding and decoding.

A VAE uses the mean and variance parameters to randomly sample an element of the distribution, and decodes that element back. This improves robustness and forces the latent space to learn meaningful representations (note that any sample from the distribution must decode to a valid output).

From a technical point of view, a VAE works as follows:

  1. An encoder converts the input into two parameters of a latent representation space, which we will denote z_mean and z_log_var.
  2. We draw a random sample z from the normal distribution that we assume generates the input image, using the formula z = z_mean + exp(0.5 * z_log_var) * epsilon, where epsilon is a random tensor sampled from a standard normal distribution.
  3. The sample z is decoded. Because epsilon is random, this process ensures that every point close to the latent location of the image can be decoded to something similar to the input image.

To train a VAE we use two loss functions: a reconstruction loss, which forces the decoded samples to match the original inputs, and a regularization loss, which helps shape a well-formed latent space and prevents overfitting.
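
For reference, the regularization term is the Kullback-Leibler divergence between the distribution produced by the encoder and a standard normal prior; for a diagonal Gaussian it has the closed form

$$\mathrm{KL} = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right),$$

which, with z_mean $=\mu$ and z_log_var $=\log\sigma^2$, is exactly the expression computed in the train_step method defined later in this notebook.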

Sampling layer

The first thing we define is a new layer in charge of drawing a random sample from the values of z_mean and z_log_var. To do this we define a new class that inherits from the Keras Layer class and implements the call method.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Draws a sample z = z_mean + exp(0.5 * z_log_var) * epsilon from the latent distribution."""
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # epsilon is drawn from a standard normal distribution.
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
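
As a small illustrative check (using dummy tensors, not part of the model): when z_mean and z_log_var are both zero, the layer simply returns a draw from a standard normal distribution.

# With zero mean and zero log-variance, the output is a sample from N(0, I).
sample = Sampling()([tf.zeros((1, 2)), tf.zeros((1, 2))])
print(sample.shape)  # (1, 2)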

Encoder

The encoder is similar to the one we used in the autoencoder; the main difference is that it produces two output vectors, z_mean and z_log_var.

latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d_10 (Conv2D)             (None, 14, 14, 32)   320         ['input_3[0][0]']                
                                                                                                  
 conv2d_11 (Conv2D)             (None, 7, 7, 64)     18496       ['conv2d_10[0][0]']              
                                                                                                  
 flatten (Flatten)              (None, 3136)         0           ['conv2d_11[0][0]']              
                                                                                                  
 dense (Dense)                  (None, 16)           50192       ['flatten[0][0]']                
                                                                                                  
 z_mean (Dense)                 (None, 2)            34          ['dense[0][0]']                  
                                                                                                  
 z_log_var (Dense)              (None, 2)            34          ['dense[0][0]']                  
                                                                                                  
 sampling (Sampling)            (None, 2)            0           ['z_mean[0][0]',                 
                                                                  'z_log_var[0][0]']              
                                                                                                  
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________

Decoder

We can now define the decoder, using an architecture symmetric to that of the encoder.

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_4 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_1 (Dense)             (None, 3136)              9408      
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 64)       36928     
 nspose)                                                         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 32)       18464     
 ranspose)                                                       
                                                                 
 conv2d_transpose_2 (Conv2DT  (None, 28, 28, 1)        289       
 ranspose)                                                       
                                                                 
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________

Finally, we define a new model that joins the encoder and decoder defined above, and we define a loss function that combines the reconstruction loss and the regularization loss. To do this we define a new class that inherits from Model.

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Reconstruction loss: binary cross-entropy summed over the pixels of each image.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            # Regularization loss: KL divergence to a standard normal prior.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

Training

Finally, we instantiate the model and compile it. Note that it is not necessary to pass a loss function explicitly, since it is already defined inside the model. This also means that when training the model we will not need to provide the expected output: it is the same as the input, and the train_step method defined above takes care of it.

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())

We can now train the model; to do so we merge the MNIST training and test sets.

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

vae.fit(mnist_digits, epochs=30, batch_size=128)
Epoch 1/30
547/547 [==============================] - 12s 18ms/step - loss: 258.3222 - reconstruction_loss: 207.3263 - kl_loss: 3.3780
Epoch 2/30
547/547 [==============================] - 10s 18ms/step - loss: 175.8255 - reconstruction_loss: 167.4281 - kl_loss: 5.5007
Epoch 3/30
547/547 [==============================] - 10s 18ms/step - loss: 167.3953 - reconstruction_loss: 160.0686 - kl_loss: 5.8272
Epoch 4/30
547/547 [==============================] - 10s 18ms/step - loss: 162.9637 - reconstruction_loss: 155.8265 - kl_loss: 6.1042
Epoch 5/30
547/547 [==============================] - 10s 18ms/step - loss: 159.7077 - reconstruction_loss: 153.3320 - kl_loss: 6.2538
Epoch 6/30
547/547 [==============================] - 10s 18ms/step - loss: 158.1543 - reconstruction_loss: 151.6510 - kl_loss: 6.3135
Epoch 7/30
547/547 [==============================] - 10s 18ms/step - loss: 156.9544 - reconstruction_loss: 150.2694 - kl_loss: 6.3696
Epoch 8/30
547/547 [==============================] - 10s 18ms/step - loss: 155.9744 - reconstruction_loss: 149.1816 - kl_loss: 6.4029
Epoch 9/30
547/547 [==============================] - 10s 18ms/step - loss: 154.6804 - reconstruction_loss: 148.4061 - kl_loss: 6.4388
Epoch 10/30
547/547 [==============================] - 10s 18ms/step - loss: 154.5413 - reconstruction_loss: 147.5961 - kl_loss: 6.4629
Epoch 11/30
547/547 [==============================] - 10s 18ms/step - loss: 153.7180 - reconstruction_loss: 147.1163 - kl_loss: 6.4786
Epoch 12/30
547/547 [==============================] - 10s 18ms/step - loss: 153.2216 - reconstruction_loss: 146.5846 - kl_loss: 6.4871
Epoch 13/30
547/547 [==============================] - 10s 18ms/step - loss: 152.5861 - reconstruction_loss: 146.1002 - kl_loss: 6.4937
Epoch 14/30
547/547 [==============================] - 10s 18ms/step - loss: 152.2902 - reconstruction_loss: 145.7991 - kl_loss: 6.5250
Epoch 15/30
547/547 [==============================] - 10s 18ms/step - loss: 151.8982 - reconstruction_loss: 145.4320 - kl_loss: 6.5198
Epoch 16/30
547/547 [==============================] - 10s 18ms/step - loss: 151.8144 - reconstruction_loss: 145.0116 - kl_loss: 6.5122
Epoch 17/30
547/547 [==============================] - 10s 18ms/step - loss: 151.3430 - reconstruction_loss: 144.8974 - kl_loss: 6.5174
Epoch 18/30
547/547 [==============================] - 10s 18ms/step - loss: 151.0552 - reconstruction_loss: 144.5603 - kl_loss: 6.5250
Epoch 19/30
547/547 [==============================] - 10s 18ms/step - loss: 150.9264 - reconstruction_loss: 144.3094 - kl_loss: 6.5371
Epoch 20/30
547/547 [==============================] - 10s 18ms/step - loss: 149.9879 - reconstruction_loss: 143.9915 - kl_loss: 6.5339
Epoch 21/30
547/547 [==============================] - 10s 18ms/step - loss: 150.5420 - reconstruction_loss: 143.8894 - kl_loss: 6.5456
Epoch 22/30
547/547 [==============================] - 10s 18ms/step - loss: 149.8986 - reconstruction_loss: 143.6472 - kl_loss: 6.5409
Epoch 23/30
547/547 [==============================] - 10s 18ms/step - loss: 149.7602 - reconstruction_loss: 143.5481 - kl_loss: 6.5314
Epoch 24/30
547/547 [==============================] - 10s 18ms/step - loss: 150.0828 - reconstruction_loss: 143.3013 - kl_loss: 6.5392
Epoch 25/30
547/547 [==============================] - 10s 18ms/step - loss: 149.4021 - reconstruction_loss: 143.1110 - kl_loss: 6.5467
Epoch 26/30
547/547 [==============================] - 10s 18ms/step - loss: 149.6587 - reconstruction_loss: 142.8784 - kl_loss: 6.5507
Epoch 27/30
547/547 [==============================] - 10s 18ms/step - loss: 149.4968 - reconstruction_loss: 142.8175 - kl_loss: 6.5572
Epoch 28/30
547/547 [==============================] - 10s 18ms/step - loss: 148.9915 - reconstruction_loss: 142.6800 - kl_loss: 6.5604
Epoch 29/30
547/547 [==============================] - 10s 18ms/step - loss: 149.2337 - reconstruction_loss: 142.4957 - kl_loss: 6.5624
Epoch 30/30
547/547 [==============================] - 10s 18ms/step - loss: 149.0440 - reconstruction_loss: 142.3415 - kl_loss: 6.5585
<keras.callbacks.History at 0x7f09ae6fe910>

Once the model has been trained, we can use the decoder to turn points of the latent space into images.

import matplotlib.pyplot as plt

def plot_latent_space(vae, n=30, figsize=15):
    # display a n*n 2D manifold of digits
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(vae)

The grid above shows a continuous distribution of the different digit classes, and we can see how one digit morphs into another as we follow a path through the latent space. Note that some directions of the latent space have a meaning of their own (such as "four-ness" or "one-ness").
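
Besides the grid, a minimal sketch for generating a single new digit is to sample one random point from the prior and decode it (latent_dim and vae are the objects defined above):

# Sample one point from the standard normal prior and decode it into an image.
z_random = np.random.normal(size=(1, latent_dim))
digit = vae.decoder.predict(z_random).reshape(28, 28)
plt.imshow(digit, cmap="Greys_r")
plt.axis("off")
plt.show()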

Finally, we can look at the clusters of the latent space associated with each class.

def plot_label_clusters(vae, data, labels):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.set_cmap('viridis')
    plt.show()


(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype("float32") / 255

plot_label_clusters(vae, x_train, y_train)