variational_autoencoder_deconv
This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers. Reference: “Auto-Encoding Variational Bayes” https://arxiv.org/abs/1312.6114
library(keras)
K <- keras::backend()
#### Parameterization ####
# input image dimensions
img_rows <- 28L
img_cols <- 28L
# color channels (1 = grayscale, 3 = RGB)
img_chns <- 1L
# number of convolutional filters to use
filters <- 64L
# convolution kernel size
num_conv <- 3L
latent_dim <- 2L
intermediate_dim <- 128L
epsilon_std <- 1.0
# training parameters
batch_size <- 100L
epochs <- 5L
#### Model Construction ####
original_img_size <- c(img_rows, img_cols, img_chns)
x <- layer_input(shape = c(original_img_size))
conv_1 <- layer_conv_2d(
x,
filters = img_chns,
kernel_size = c(2L, 2L),
strides = c(1L, 1L),
padding = "same",
activation = "relu"
)
conv_2 <- layer_conv_2d(
conv_1,
filters = filters,
kernel_size = c(2L, 2L),
strides = c(2L, 2L),
padding = "same",
activation = "relu"
)
conv_3 <- layer_conv_2d(
conv_2,
filters = filters,
kernel_size = c(num_conv, num_conv),
strides = c(1L, 1L),
padding = "same",
activation = "relu"
)
conv_4 <- layer_conv_2d(
conv_3,
filters = filters,
kernel_size = c(num_conv, num_conv),
strides = c(1L, 1L),
padding = "same",
activation = "relu"
)
flat <- layer_flatten(conv_4)
hidden <- layer_dense(flat, units = intermediate_dim, activation = "relu")
z_mean <- layer_dense(hidden, units = latent_dim)
z_log_var <- layer_dense(hidden, units = latent_dim)
sampling <- function(args) {
z_mean <- args[, 1:(latent_dim)]
z_log_var <- args[, (latent_dim + 1):(2 * latent_dim)]
epsilon <- k_random_normal(
shape = c(k_shape(z_mean)[[1]]),
mean = 0.,
stddev = epsilon_std
)
z_mean + k_exp(z_log_var) * epsilon
}
z <- layer_concatenate(list(z_mean, z_log_var)) %>% layer_lambda(sampling)
output_shape <- c(batch_size, 14L, 14L, filters)
decoder_hidden <- layer_dense(units = intermediate_dim, activation = "relu")
decoder_upsample <- layer_dense(units = prod(output_shape[-1]), activation = "relu")
decoder_reshape <- layer_reshape(target_shape = output_shape[-1])
decoder_deconv_1 <- layer_conv_2d_transpose(
filters = filters,
kernel_size = c(num_conv, num_conv),
strides = c(1L, 1L),
padding = "same",
activation = "relu"
)
decoder_deconv_2 <- layer_conv_2d_transpose(
filters = filters,
kernel_size = c(num_conv, num_conv),
strides = c(1L, 1L),
padding = "same",
activation = "relu"
)
decoder_deconv_3_upsample <- layer_conv_2d_transpose(
filters = filters,
kernel_size = c(3L, 3L),
strides = c(2L, 2L),
padding = "valid",
activation = "relu"
)
decoder_mean_squash <- layer_conv_2d(
filters = img_chns,
kernel_size = c(2L, 2L),
strides = c(1L, 1L),
padding = "valid",
activation = "sigmoid"
)
hidden_decoded <- decoder_hidden(z)
up_decoded <- decoder_upsample(hidden_decoded)
reshape_decoded <- decoder_reshape(up_decoded)
deconv_1_decoded <- decoder_deconv_1(reshape_decoded)
deconv_2_decoded <- decoder_deconv_2(deconv_1_decoded)
x_decoded_relu <- decoder_deconv_3_upsample(deconv_2_decoded)
x_decoded_mean_squash <- decoder_mean_squash(x_decoded_relu)
# custom loss function
vae_loss <- function(x, x_decoded_mean_squash) {
x <- k_flatten(x)
x_decoded_mean_squash <- k_flatten(x_decoded_mean_squash)
xent_loss <- 1.0 * img_rows * img_cols *
loss_binary_crossentropy(x, x_decoded_mean_squash)
kl_loss <- -0.5 * k_mean(1 + z_log_var - k_square(z_mean) -
k_exp(z_log_var), axis = -1L)
k_mean(xent_loss + kl_loss)
}
## variational autoencoder
vae <- keras_model(x, x_decoded_mean_squash)
vae %>% compile(optimizer = "rmsprop", loss = vae_loss)
summary(vae)
## encoder: model to project inputs on the latent space
encoder <- keras_model(x, z_mean)
## build a digit generator that can sample from the learned distribution
gen_decoder_input <- layer_input(shape = latent_dim)
gen_hidden_decoded <- decoder_hidden(gen_decoder_input)
gen_up_decoded <- decoder_upsample(gen_hidden_decoded)
gen_reshape_decoded <- decoder_reshape(gen_up_decoded)
gen_deconv_1_decoded <- decoder_deconv_1(gen_reshape_decoded)
gen_deconv_2_decoded <- decoder_deconv_2(gen_deconv_1_decoded)
gen_x_decoded_relu <- decoder_deconv_3_upsample(gen_deconv_2_decoded)
gen_x_decoded_mean_squash <- decoder_mean_squash(gen_x_decoded_relu)
generator <- keras_model(gen_decoder_input, gen_x_decoded_mean_squash)
#### Data Preparation ####
mnist <- dataset_mnist()
data <- lapply(mnist, function(m) {
array_reshape(m$x / 255, dim = c(dim(m$x)[1], original_img_size))
})
x_train <- data$train
x_test <- data$test
#### Model Fitting ####
vae %>% fit(
x_train, x_train,
shuffle = TRUE,
epochs = epochs,
batch_size = batch_size,
validation_data = list(x_test, x_test)
)
#### Visualizations ####
library(ggplot2)
library(dplyr)
## display a 2D plot of the digit classes in the latent space
x_test_encoded <- predict(encoder, x_test, batch_size = batch_size)
x_test_encoded %>%
as_data_frame() %>%
mutate(class = as.factor(mnist$test$y)) %>%
ggplot(aes(x = V1, y = V2, colour = class)) + geom_point()
## display a 2D manifold of the digits
n <- 15 # figure with 15x15 digits
digit_size <- 28
# we will sample n points within [-4, 4] standard deviations
grid_x <- seq(-4, 4, length.out = n)
grid_y <- seq(-4, 4, length.out = n)
rows <- NULL
for(i in 1:length(grid_x)){
column <- NULL
for(j in 1:length(grid_y)){
z_sample <- matrix(c(grid_x[i], grid_y[j]), ncol = 2)
column <- rbind(column, predict(generator, z_sample) %>% matrix(ncol = digit_size))
}
rows <- cbind(rows, column)
}
rows %>% as.raster() %>% plot()