Configuring and Training a Multi-layer Perceptron (MLP) in SciKit-Learn

(Notebook prepared by Prof. Fabien MOUTARDE, Center for Robotics, MINES ParisTech, PSL Université Paris)

1. Understand and experiment with an MLP on a VERY simple classification problem

Building, training and evaluating a simple Neural Network classifier (Multi Layer Perceptron, MLP)

The SciKit-learn class for MLP is MLPClassifier. Please first read the MLPClassifier documentation to understand all the parameters of the constructor. You can then begin by running the code block below, in which an initial set of hyper-parameter values has been chosen. YOU MAY NEED TO CHANGE AT LEAST THE NUMBER OF HIDDEN NEURONS (and probably other hyper-parameters) IN ORDER TO BE ABLE TO LEARN A CORRECT CLASSIFIER.
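As a rough sketch of what such a starting cell can look like (using make_moons as a stand-in for the notebook's own toy 2D dataset, which is loaded elsewhere in the notebook):

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Toy 2D two-class problem (stand-in for the notebook's own dataset)
X, y = make_moons(n_samples=200, noise=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)

# One hidden layer; the number of hidden neurons is the key
# hyper-parameter you may need to change to learn a correct classifier
clf = MLPClassifier(hidden_layer_sizes=(8,), activation='tanh',
                    solver='adam', learning_rate_init=0.01,
                    max_iter=2000, random_state=0)
clf.fit(X_train, y_train)
print("Train accuracy:", clf.score(X_train, y_train))
print("Test accuracy:", clf.score(X_test, y_test))
```

Comparing train and test accuracy already hints at under- or over-fitting before any plot is drawn.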

Below, we visualize the learnt boundary between classes in (2D) input space

THIS SHOULD HELP YOU UNDERSTAND WHAT HAPPENS IF THERE ARE NOT ENOUGH HIDDEN NEURONS
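A minimal sketch of such a boundary plot (again assuming a make_moons stand-in dataset; the idea is simply to evaluate the trained classifier on a dense grid covering the 2D input space):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.neural_network import MLPClassifier

X, y = make_moons(n_samples=200, noise=0.2, random_state=0)
clf = MLPClassifier(hidden_layer_sizes=(8,), max_iter=2000,
                    random_state=0).fit(X, y)

# Predict the class on every point of a dense grid over the input space
xx, yy = np.meshgrid(
    np.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, 200),
    np.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, 200))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.3)                      # class regions
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', s=20)  # training points
plt.title("MLP decision boundary")
plt.show()
```

With too few hidden neurons the contour degenerates towards a nearly straight boundary, which is exactly what this plot is meant to reveal.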

Optional: add code that visualises on the same plot the straight lines corresponding to each hidden neuron (you will need to dig into MLPClassifier documentation to find the 2 input weights and the bias of each hidden neuron). YOU SHOULD NOTICE THAT THE CLASSIFICATION BOUNDARY IS SOME INTERPOLATION BETWEEN THOSE STRAIGHT LINES.
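For the optional part, the relevant attributes are coefs_ and intercepts_: coefs_[0] holds the first-layer weights with shape (n_inputs, n_hidden) and intercepts_[0] the hidden biases. Each hidden neuron j then defines the line w1*x + w2*y + b_j = 0. A sketch (make_moons again standing in for the notebook's dataset):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.neural_network import MLPClassifier

X, y = make_moons(n_samples=200, noise=0.2, random_state=0)
clf = MLPClassifier(hidden_layer_sizes=(8,), max_iter=2000,
                    random_state=0).fit(X, y)

# First-layer parameters: W has shape (2, n_hidden), b has shape (n_hidden,)
W, b = clf.coefs_[0], clf.intercepts_[0]

xs = np.linspace(X[:, 0].min(), X[:, 0].max(), 50)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', s=20)
for j in range(W.shape[1]):
    w1, w2 = W[0, j], W[1, j]
    if abs(w2) > 1e-8:  # neuron j's line: w1*x + w2*y + b[j] = 0
        plt.plot(xs, -(w1 * xs + b[j]) / w2, '--', linewidth=1)
plt.ylim(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5)
plt.show()
```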

Now check the impact of the main learning hyper-parameters by changing the MLPClassifier parameters above and then rerunning training + evaluation + plots.

Finally, use grid-search and cross-validation to find an optimized set of learning hyper-parameters (see code below).

Because the values of learning hyper-parameters can DRASTICALLY change the outcome of training, it is ESSENTIAL THAT YOU ALWAYS TRY TO FIND OPTIMIZED VALUES FOR THE ALGORITHM HYPER-PARAMETERS. And this ABSOLUTELY NEEDS TO BE DONE USING "VALIDATION", either with a validation set separate from the training set, or using cross-validation. CROSS-VALIDATION is the MOST ROBUST WAY OF FINDING OPTIMIZED HYPER-PARAMETER VALUES, and the GridSearchCV class of SciKit-Learn makes this rather straightforward.

WARNING: GridSearchCV launches many successive training sessions, so it can take rather long to execute if you compare too many combinations.
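A minimal GridSearchCV sketch (the dataset and hyper-parameter grid below are illustrative placeholders; each combination in the grid triggers cv full trainings, hence the warning above):

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import MLPClassifier

X, y = make_moons(n_samples=200, noise=0.2, random_state=0)

# Keep the grid small: 3 x 2 = 6 combinations x 5 folds = 30 trainings
param_grid = {
    'hidden_layer_sizes': [(4,), (8,), (16,)],
    'learning_rate_init': [0.01, 0.001],
}
search = GridSearchCV(MLPClassifier(max_iter=2000, random_state=0),
                      param_grid, cv=5)
search.fit(X, y)
print("Best params:", search.best_params_)
print("Best cross-validated accuracy:", search.best_score_)
```

After fitting, search.best_estimator_ is an MLPClassifier retrained on the whole dataset with the winning hyper-parameters.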

2. WORK ON A REALISTIC DATASET: A SIMPLIFIED HANDWRITTEN DIGITS DATASET

Please FIRST READ the Digits DATASET DESCRIPTION. In this classification problem, there are 10 classes, with a total of 1797 examples (each one being a 64D vector corresponding to an 8x8 pixmap).
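The dataset ships with SciKit-learn, so checking the figures above takes only a few lines:

```python
from sklearn.datasets import load_digits

digits = load_digits()
# 1797 examples, each a flattened 8x8 pixmap -> a 64D feature vector
print(digits.data.shape)    # (1797, 64)
print(digits.target.shape)  # (1797,), labels 0..9
```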

Assignment #1: find out what learning hyper-parameters should be modified in order to obtain a satisfying MLP digits classifier

Assignment #2: modify the code below to use cross-validation, and find the best training hyper-parameters and the best MLP classifier you can for this handwritten digits classification task.

Assignment #3: compute and plot the precision-recall curve (for each class). NB: search the SciKit-learn documentation to find the function for that, and then add a code cell that uses it.
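One possible approach (a sketch, not the only solution): treat each of the 10 classes one-vs-rest and feed the class probabilities from predict_proba into sklearn.metrics.precision_recall_curve. The classifier settings below are illustrative placeholders.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.3, random_state=0)

clf = MLPClassifier(hidden_layer_sizes=(64,), max_iter=500, random_state=0)
clf.fit(X_train, y_train)
proba = clf.predict_proba(X_test)  # shape (n_test_samples, 10)

# One precision-recall curve per class, one-vs-rest
for c in range(10):
    precision, recall, _ = precision_recall_curve(y_test == c, proba[:, c])
    plt.plot(recall, precision, label=str(c))
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend(title="digit")
plt.show()
```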

Assignment #4: display the confusion matrix as a prettier and more easily understandable plot (cf. example on https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html)
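A possible sketch using sklearn.metrics.ConfusionMatrixDisplay, which renders the matrix as a colour-coded image with per-cell counts (classifier settings are again placeholders):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.3, random_state=0)
clf = MLPClassifier(hidden_layer_sizes=(64,), max_iter=500,
                    random_state=0).fit(X_train, y_train)

y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)  # 10x10 matrix of counts
# Draw the matrix as an image; off-diagonal cells are the misclassifications
ConfusionMatrixDisplay(cm, display_labels=digits.target_names).plot(cmap='Blues')
plt.show()
```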

Assignment #5 (optional): plot the first layer of weights as images (see explanations and example code at http://scikit-learn.org/stable/auto_examples/neural_networks/plot_mnist_filters.html)
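Following the idea of the linked MNIST example, each column of clf.coefs_[0] holds the 64 input weights of one hidden neuron and can be reshaped back into an 8x8 image. A sketch (16 hidden neurons chosen here purely so they fill a 4x4 grid):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier

digits = load_digits()
clf = MLPClassifier(hidden_layer_sizes=(16,), max_iter=500, random_state=0)
clf.fit(digits.data, digits.target)

# coefs_[0] has shape (64, 16): one 64D weight vector per hidden neuron
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, weights in zip(axes.ravel(), clf.coefs_[0].T):
    ax.matshow(weights.reshape(8, 8), cmap='gray')  # back to an 8x8 image
    ax.set_xticks(())
    ax.set_yticks(())
plt.show()
```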