---
title: "MNIST hypertuning"
output:
  rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{MNIST hypertuning}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE,eval = F)
```

This tutorial shows the hyperparameter tuning for [MNIST dataset](http://yann.lecun.com/exdb/mnist/).

If you have difficulty with the following topic, [Keras tutorials](https://tensorflow.rstudio.com/articles/tutorial_basic_classification.html) would be very helpful.

## Input shape

At first, we should determine the input shape of our model.

```
inputs = tf$keras$Input(shape=list(28L, 28L, 1L))
```

## Layers

The next step is to put iterable parameters into the loop function ```for (i in 1:n) {}``` and define layers.

```
for (i in 1:3) {
tf$keras$layers$Conv2D(filters = hp$Int(paste('filters_', i, sep = ''), 4L, 32L, step=4L, default=8L), 
kernel_size = hp$Int(paste('kernel_size_', i, sep = ''), 3L, 5L),
activation ='relu',
padding='same')
}
```

# Functional API

Later, we should concatenate all the layers and create a model via [Functional API](https://tensorflow.rstudio.com/guides/keras/functional_api). Please, refer to the Functional API, if you want to get more details about it.

```
model = tf$keras$Model(inputs, outputs)
```

## Optimizers

Finally, we should determine our optimizer. Keras Tuner allows us to tune even optimizers.

```
optimizer = hp$Choice('optimizer', c('adam', 'sgd'))
```

As the ```build(hp)``` function is ready, we should build a Hyperband object, and start searching the best hyperparameters.

## Full code

Below is the full version of a tuning process for MNIST dataset.

```{r}
conv_build_model = function(hp) {
  'Builds a convolutional model.'
  inputs = tf$keras$Input(shape=list(28L, 28L, 1L))
  
  x = inputs
  
  for (i in 1:hp$Int('conv_layers', 1L, 3L, default=3L)) {
    x = tf$keras$layers$Conv2D(filters = hp$Int(paste('filters_', i, sep = ''), 
                                                4L, 32L, step=4L, default=8L),
                               kernel_size = hp$Int(paste('kernel_size_', i, sep = ''), 3L, 5L),
                               activation ='relu',
                               padding='same')(x)
    if (hp$Choice(paste('pooling', i, sep = ''), c('max', 'avg')) == 'max') {
      x = tf$keras$layers$MaxPooling2D()(x)
    } else {
      x = tf$keras$layers$AveragePooling2D()(x)
    }
    x = tf$keras$layers$BatchNormalization()(x) 
    x =  tf$keras$layers$ReLU()(x)
    
  }
  if (hp$Choice('global_pooling', c('max', 'avg')) == 'max') {
    x =  tf$keras$layers$GlobalMaxPool2D()(x)
  } else {
    x = tf$keras$layers$GlobalAveragePooling2D()(x)
  }
  
  outputs = tf$keras$layers$Dense(10L, activation='softmax')(x)
  model = tf$keras$Model(inputs, outputs)
  optimizer = hp$Choice('optimizer', c('adam', 'sgd'))
  model %>% compile(optimizer, loss='sparse_categorical_crossentropy', metrics='accuracy')
  return(model)
}

main = function() {
  tuner = Hyperband(
    hypermodel = conv_build_model,
    objective = 'val_accuracy',
    max_epochs = 8,
    factor = 2,
    hyperband_iterations = 3,
    directory = 'results_dir',
    project_name='mnist')
  
  # call keras library for downloading MNIST dataset
  library(keras)
  
  mnist_data = dataset_fashion_mnist()
  c(mnist_train, mnist_test) %<-%  mnist_data
  rm(mnist_data)
  
  # reshape data
  mnist_train$x = keras::k_reshape(mnist_train$x,shape = c(6e4,28,28,1))
  mnist_test$x = keras::k_reshape(mnist_test$x,shape = c(1e4,28,28,1))
  
  # call tfdatasets and slice dataset
  # turn data type into float 32 (features, not labels/outputs)
  library(tfdatasets)
  mnist_train =  tensor_slices_dataset(list(tf$dtypes$cast(
    mnist_train$x, 'float32') / 255., mnist_train$y)) %>% 
    dataset_shuffle(1e3) %>% dataset_batch(1e2) %>% dataset_repeat()
  
  mnist_test = tensor_slices_dataset(list(tf$dtypes$cast(
    mnist_test$x, 'float32') / 255., mnist_test$y)) %>%
    dataset_batch(1e2)
  
  # finally, begin a training with a bunch of parameters
  tuner %>% fit_tuner(x = mnist_train,
               steps_per_epoch=600,
               validation_data=mnist_test,
               validation_steps=100,
               epochs=2,
               callbacks=c(tf$keras$callbacks$EarlyStopping('val_accuracy')) 
  )
}

main()

```




