Deep Learning for Hand Gesture Recognition with PyTorch: Quickstart

This notebook is a demo PyTorch implementation of the deep learning model for hand gesture recognition introduced in the article Deep Learning for Hand Gesture Recognition on Skeletal Data by G. Devineau, F. Moutarde, W. Xi and J. Yang.

If you find this code useful in your research, please consider citing:

@inproceedings{devineau2018deep,
  title={Deep learning for hand gesture recognition on skeletal data},
  author={Devineau, Guillaume and Moutarde, Fabien and Xi, Wang and Yang, Jie},
  booktitle={2018 13th IEEE International Conference on Automatic Face \& Gesture Recognition (FG 2018)},
  pages={106--113},
  year={2018},
  organization={IEEE}
}

For a more detailed introduction to the model, please refer to this quick introduction to the model or read the original article.

For Keras users, a simplified Keras version of this Colab is also available for convenience.

Overview of the model

[Figure: model overview]

1. Imports

In [1]:
#!/usr/bin/python
# -*- coding: utf-8 -*-

from __future__ import unicode_literals, print_function, division
import sys
if sys.version_info.major < 3:
    print('You are using Python 2, but you should use Python 3 instead.')
    print('    If you still want to use Python 2, ensure you import:')
    print('    >> from __future__ import unicode_literals, print_function, division')

If you encounter a Python error about a missing module at some point, uncomment the appropriate line in the cell below and run it.

In [2]:
# Uncomment the following lines to install modules if needed:
# -----------------------------------------------------------
# --- required: numpy + scikit-learn + scipy:
# !{sys.executable} -m pip install numpy scipy scikit-learn
# --- required: torch
# !{sys.executable} -m pip install torch  # <<------ advice: do not use this pip command; rather, install torch via conda, with CUDA support: see https://pytorch.org/get-started/locally/
# --- bonus: tensorboardX:
# !{sys.executable} -m pip install tensorflow tensorflow-gpu
# !{sys.executable} -m pip install tensorboardX
# --- bonus: jupyter-lab:
# !{sys.executable} -m pip install matplotlib ipython jupyter jupyterlab pandas tqdm
# !{sys.executable} -m pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension
In [3]:
import itertools
import numpy
import torch
import pickle
from scipy import ndimage as ndimage
from sklearn.utils import shuffle
import time
import math
In [4]:
# (bonus) plot acc with tensorboard
#   Command to start tensorboard if installed (requires tensorflow):
#   $  tensorboard --logdir ./runs
try:
    from tensorboardX import SummaryWriter
except ImportError:
    # tensorboardX is not installed: fall back to a no-op writer, just fail silently
    class SummaryWriter():
        def __init__(self):
            pass
        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            pass
In [5]:
print('Using python {}.{}, with modules versions'.format(sys.version_info.major, sys.version_info.minor))
print('-'*40)
print('numpy == {}'.format(numpy.__version__))
print('torch == {}'.format(torch.__version__))  # please use version 1.0 or above
Using python 3.7, with modules versions
----------------------------------------
numpy == 1.16.4
torch == 1.2.0

2. Utils

We define some functions we'll need later on.

2.1. Gesture Dataset

Option 1: Use your own custom Dataset

To train the model on your own custom hand gesture dataset, edit the cells below. You'll need to adapt load_data(), shuffle_dataset(), preprocess_data() and convert_to_pytorch_tensors() according to your needs.

The load_data() function should return the gesture tensors x and the labels y.

The shape of the x tensor should be (dataset_size, duration, n_channels) where n_channels = 3 * n_joints for 3D pose data.
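For instance, with 22 hand joints in 3D (as in the SHREC'17 dataset used below), a toy x/y pair with the expected shapes could be built like this (a small sketch with purely illustrative sizes, not part of the original notebook):

import numpy

n_sequences, duration, n_joints = 500, 100, 22   # illustrative sizes
x = numpy.random.randn(n_sequences, duration, 3 * n_joints)   # gesture sequences
y = numpy.random.randint(1, 15, size=n_sequences)             # labels in 1..14
print(x.shape)  # (500, 100, 66)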


If your dataset does not fit into memory, you can stream it with the standard pytorch Dataset class:

from torch.utils.data.dataset import Dataset
class MyCustomDataset(Dataset):

    def __init__(self, root_dir):
        self.sequences_names = ... # list sequences
        self.root_dir = root_dir

    def __len__(self):
        return len(self.sequences_names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = ...load_somehow(self.sequences_names[idx]) # load sequence
        return sample
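Such a Dataset can then be wrapped in a standard pytorch DataLoader to stream shuffled mini-batches; a minimal sketch (the root_dir path below is a hypothetical placeholder):

from torch.utils.data import DataLoader

dataset = MyCustomDataset(root_dir='/content/my_gestures')  # hypothetical path
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

for sample_batch in loader:
    ...  # feed each mini-batch to the model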

Option 2: Use DHG14/28 or SHREC 17 Dataset

If you don't have your own gesture dataset, you might want to download one of the DHG14/28 or SHREC'17 hand gesture datasets, or use the ready-to-use version below (the dataset pre-loaded into a single pickle file, shrec_data.pckl):

In [6]:
# Option 1: Custom Dataset
# ------------------------
# In the left sidebar of this Colaboratory notebook there is a section called "Files".
# Upload your files there and use a path like "/content/each_file_you_just_uploaded" to load your data, e.g. in the load_data() function below
In [7]:
# Option 2: Download the SHREC17 dataset 
# (either download the file manually or use the following wget command:)
!wget https://cloud.mines-paristech.fr/index.php/s/9U4bjHrvp8u2pnS/download -O ./shrec_data.pckl
'wget' is not recognized as an internal or external command, an executable program, or a batch file.
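If wget is not available on your system (as in the Windows error message above), the same file can be fetched in pure Python; a small sketch using the same URL and destination:

# Pure-Python alternative to wget (same URL and destination as above):
import urllib.request
url = 'https://cloud.mines-paristech.fr/index.php/s/9U4bjHrvp8u2pnS/download'
urllib.request.urlretrieve(url, './shrec_data.pckl')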
In [8]:
def load_data(filepath='./shrec_data.pckl'):
    """
    Returns hand gesture sequences (X) and their associated labels (Y).
    Each sequence has two different labels.
    The first label  Y14 describes the gesture class out of 14 possible gestures (e.g. swiping your hand to the right).
    The second label Y28 describes the gesture class out of 28 possible gestures (e.g. swiping your hand to the right with your index finger pointed, or not pointed).
    """
    with open(filepath, 'rb') as file:
        data = pickle.load(file, encoding='latin1')  # <<---- change 'latin1' to 'utf8' if the data does not load
    return data['x_train'], data['x_test'], data['y_train_14'], data['y_train_28'], data['y_test_14'], data['y_test_28']
In [9]:
def resize_sequences_length(x_train, x_test, final_length=100):
    """
    Resize the time series by interpolating them to the same length
    """
    # note: this relies on Python 3's true division (/); on Python 2, import division from the __future__ module (as done at the top of this notebook)
    x_train = numpy.array([numpy.array([ndimage.zoom(x_i.T[j], final_length / len(x_i), mode='reflect') for j in range(numpy.size(x_i, 1))]).T for x_i in x_train])
    x_test  = numpy.array([numpy.array([ndimage.zoom(x_i.T[j], final_length / len(x_i), mode='reflect') for j in range(numpy.size(x_i, 1)) ]).T for x_i in x_test])
    return x_train, x_test
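As a quick illustration of what resize_sequences_length() does, here is a toy check (added for this write-up, not part of the original notebook) that interpolates dummy variable-length sequences to 100 timesteps:

# Toy check: interpolate dummy variable-length sequences to 100 timesteps
dummy_train = [numpy.random.randn(60, 66)]  # one sequence, 60 timesteps, 66 channels
dummy_test = [numpy.random.randn(80, 66)]   # one sequence, 80 timesteps, 66 channels
resized_train, resized_test = resize_sequences_length(dummy_train, dummy_test, final_length=100)
print(resized_train.shape, resized_test.shape)  # (1, 100, 66) (1, 100, 66)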
In [10]:
def shuffle_dataset(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28):
    """Shuffle the train/test data consistently."""
    # note: add random_state=0 for reproducibility
    x_train, y_train_14, y_train_28 = shuffle(x_train, y_train_14, y_train_28)
    x_test,  y_test_14,  y_test_28  = shuffle(x_test,  y_test_14,  y_test_28)
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28
In [11]:
def preprocess_data(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28):
    """
    Preprocess the data as you want: update as you want!
        - possible improvement idea: make a PCA here
    """
    x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = shuffle_dataset(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28)
    x_train, x_test = resize_sequences_length(x_train, x_test, final_length=100)
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28
In [12]:
def convert_to_pytorch_tensors(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28):
    # as numpy
    y_train_14, y_train_28, y_test_14, y_test_28 = numpy.array(y_train_14), numpy.array(y_train_28), numpy.array(y_test_14), numpy.array(y_test_28)
    
    # -- REQUIRED by the pytorch loss function implementation --
    # Subtract 1 from all class labels (1-14 => 0-13 and 1-28 => 0-27)
    y_train_14, y_train_28, y_test_14, y_test_28 = y_train_14 - 1, y_train_28 - 1, y_test_14 - 1, y_test_28 - 1
    
    # as torch
    x_train, x_test = torch.from_numpy(x_train), torch.from_numpy(x_test)
    y_train_14, y_train_28, y_test_14, y_test_28 = torch.from_numpy(y_train_14), torch.from_numpy(y_train_28), torch.from_numpy(y_test_14), torch.from_numpy(y_test_28)

    # -- REQUIRED by the pytorch loss function implementation --
    # correct the data type (for the loss function used)
    x_train, x_test = x_train.type(torch.FloatTensor), x_test.type(torch.FloatTensor)
    y_train_14, y_train_28, y_test_14, y_test_28 = y_train_14.type(torch.LongTensor), y_train_28.type(torch.LongTensor), y_test_14.type(torch.LongTensor), y_test_28.type(torch.LongTensor)
    
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28
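A quick sanity check on the shapes and dtypes the loss function expects (a sketch added here for convenience, on dummy data): float32 sequences, int64 labels shifted to start at 0.

# Sanity check on dummy data: float sequences, long labels starting at 0
x_tr, x_te, y14_tr, y28_tr, y14_te, y28_te = convert_to_pytorch_tensors(
    numpy.random.randn(4, 100, 66), numpy.random.randn(2, 100, 66),
    [1, 2, 3, 4], [1, 2, 3, 4], [5, 6], [5, 6])
print(x_tr.shape, x_tr.dtype)             # torch.Size([4, 100, 66]) torch.float32
print(y14_tr.dtype, y14_tr.min().item())  # torch.int64 0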
In [13]:
# -------------
# Misc.
# -------------
def batch(tensor, batch_size=32):
    """Return a list of (mini) batches"""
    tensor_list = []
    length = tensor.shape[0]
    i = 0
    while True:
        if (i + 1) * batch_size >= length:
            tensor_list.append(tensor[i * batch_size: length])
            return tensor_list
        tensor_list.append(tensor[i * batch_size: (i + 1) * batch_size])
        i += 1


def time_since(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '{:02d}m {:02d}s'.format(int(m), int(s))


def get_accuracy(model, x, y_ref):
    """Get the accuracy of the pytorch model on a batch"""
    acc = 0.
    model.eval()
    with torch.no_grad():
        predicted = model(x)
        _, predicted = predicted.max(dim=1)
        acc = 1.0 * (predicted == y_ref).sum().item() / y_ref.shape[0]

    return acc
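Note that get_accuracy() pushes the whole set through the model in a single forward pass; if that exhausts GPU memory on a larger dataset, a mini-batched variant is a straightforward drop-in replacement (a sketch reusing the batch() helper above, not part of the original notebook):

def get_accuracy_batched(model, x, y_ref, batch_size=32):
    """Mini-batched variant of get_accuracy(), for sets too large for one forward pass."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for x_b, y_b in zip(batch(x, batch_size), batch(y_ref, batch_size)):
            predicted = model(x_b).max(dim=1)[1]
            correct += (predicted == y_b).sum().item()
    return correct / y_ref.shape[0]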

3. Model

In [14]:
class HandGestureNet(torch.nn.Module):
    """
    [Devineau et al., 2018] Deep Learning for Hand Gesture Recognition on Skeletal Data

    Summary
    -------
        Deep Learning Model for Hand Gesture classification using pose data only (no need for RGBD)
        The model computes a succession of [convolutions and pooling] over time independently on each of the 66 (= 22 * 3) sequence channels.
        Each of these computations is actually done at two different resolutions, which are later merged by concatenation
        with the (pooled) original sequence channel.
        Finally, a multi-layer perceptron merges all of the processed channels and outputs a classification.
    
    TL;DR:
    ------
        input ------------------------------------------------> split into n_channels channels [channel_i]
            channel_i ----------------------------------------> 3x [conv/pool/dropout] low_resolution_i
            channel_i ----------------------------------------> 3x [conv/pool/dropout] high_resolution_i
            channel_i ----------------------------------------> pooled_i
            low_resolution_i, high_resolution_i, pooled_i ----> output_channel_i
        MLP(n_channels x [output_channel_i]) -------------------------> classification

    Article / PDF:
    --------------
        https://ieeexplore.ieee.org/document/8373818

    Please cite:
    ------------
        @inproceedings{devineau2018deep,
            title={Deep learning for hand gesture recognition on skeletal data},
            author={Devineau, Guillaume and Moutarde, Fabien and Xi, Wang and Yang, Jie},
            booktitle={2018 13th IEEE International Conference on Automatic Face \& Gesture Recognition (FG 2018)},
            pages={106--113},
            year={2018},
            organization={IEEE}
        }
    """
    
    def __init__(self, n_channels=66, n_classes=14, dropout_probability=0.2):

        super(HandGestureNet, self).__init__()
        
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.dropout_probability = dropout_probability

        # Layers ----------------------------------------------
        self.all_conv_high = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),

            torch.nn.Conv1d(in_channels=8, out_channels=4, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),

            torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=self.dropout_probability),
            torch.nn.AvgPool1d(2)
        ) for joint in range(n_channels)])

        self.all_conv_low = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),

            torch.nn.Conv1d(in_channels=8, out_channels=4, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),

            torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=self.dropout_probability),
            torch.nn.AvgPool1d(2)
        ) for joint in range(n_channels)])

        self.all_residual = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.AvgPool1d(2),
            torch.nn.AvgPool1d(2),
            torch.nn.AvgPool1d(2)
        ) for joint in range(n_channels)])

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=9 * n_channels * 12, out_features=1936),  # <-- 12: depends on the sequence length (cf. below)
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1936, out_features=n_classes)
        )

        # Initialization --------------------------------------
        # Xavier init
        for module in itertools.chain(self.all_conv_high, self.all_conv_low, self.all_residual):
            for layer in module:
                if layer.__class__.__name__ == "Conv1d":
                    torch.nn.init.xavier_uniform_(layer.weight, gain=torch.nn.init.calculate_gain('relu'))
                    torch.nn.init.constant_(layer.bias, 0.1)

        for layer in self.fc:
            if layer.__class__.__name__ == "Linear":
                torch.nn.init.xavier_uniform_(layer.weight, gain=torch.nn.init.calculate_gain('relu'))
                torch.nn.init.constant_(layer.bias, 0.1)

    def forward(self, input):
        """
        This function performs the actual computations of the network for a forward pass.

        Arguments
        ---------
            input: a tensor of gestures of shape (batch_size, duration, n_channels)
                   (where n_channels = 3 * n_joints for 3D pose data)
        """

        # Work on each channel separately
        all_features = []

        for channel in range(0, self.n_channels):
            input_channel = input[:, :, channel]

            # Add a dummy (spatial) dimension for the time convolutions
            # Conv1D format : (batch_size, n_feature_maps, duration)
            input_channel = input_channel.unsqueeze(1)

            high = self.all_conv_high[channel](input_channel)
            low = self.all_conv_low[channel](input_channel)
            ap_residual = self.all_residual[channel](input_channel)

            # Time convolutions are concatenated along the feature maps axis
            output_channel = torch.cat([
                high,
                low,
                ap_residual
            ], dim=1)
            all_features.append(output_channel)

        # Concatenate along the feature maps axis
        all_features = torch.cat(all_features, dim=1)
        
        # Flatten for the Linear layers
        all_features = all_features.view(-1, 9 * self.n_channels * 12)  # <-- 12: depends on the initial sequence length (100).
        # If you have shorter/longer sequences, you probably do NOT even need to modify the network architecture:
        # resampling your input gesture from T timesteps to 100 timesteps will (surprisingly) probably work just as well!

        # Fully-Connected Layers
        output = self.fc(all_features)

        return output
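Where does the 9 * n_channels * 12 come from? Each channel yields 4 high-resolution + 4 low-resolution + 1 residual feature maps (= 9), and the three average-poolings halve the 100 timesteps down to 12 (100 -> 50 -> 25 -> 12). A quick shape check on random data (a sketch added here) confirms this:

# Quick shape check (dummy batch of 2 gestures, 100 timesteps, 66 channels):
net = HandGestureNet(n_channels=66, n_classes=14)
with torch.no_grad():
    out = net(torch.randn(2, 100, 66))
print(out.shape)  # torch.Size([2, 14])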

4. Data loading; loss function and optimizer; neural network model creation

If you use a custom dataset, you'll likely need to change the data loading part here.

In [15]:
# -------------
# Data
# -------------

# Load the dataset
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = load_data()

# Shuffle sequences and resize sequences
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = preprocess_data(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28)

# Convert to pytorch variables
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = convert_to_pytorch_tensors(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28)
C:\Users\fabien\Anaconda3\lib\site-packages\scipy\ndimage\interpolation.py:605: UserWarning: From scipy 0.13.0, the output shape of zoom() is calculated with round() instead of int() - for these inputs the size of the returned array has changed.
  "the returned array has changed.", UserWarning)

If you use a custom dataset, you'll likely want to change n_channels and n_classes to match your values.

In [16]:
# -------------
# Network instantiation
# -------------
model = HandGestureNet(n_channels=66, n_classes=14)

Reduce the learning rate to get smoother accuracy curves.

In [17]:
# -----------------------------------------------------
# Loss function & Optimizer
# -----------------------------------------------------
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
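One simple way to lower the learning rate over the course of training is a pytorch learning-rate scheduler; a sketch (not used in the runs below):

# Optional: halve the learning rate every 10 epochs for smoother curves
# (if used, call scheduler.step() once per epoch in the training loop)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)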

5. Training loop

In [18]:
# -------------
# Training
# -------------


def train(model, criterion, optimizer,
          x_train, y_train, x_test, y_test,
          force_cpu=False, num_epochs=5):
    
    # use a GPU (for speed) if you have one
    device = torch.device("cuda") if torch.cuda.is_available() and not force_cpu else torch.device("cpu")
    model = model.to(device)
    x_train, y_train, x_test, y_test = x_train.to(device), y_train.to(device), x_test.to(device), y_test.to(device)
    
    # (bonus) log accuracy values to visualize them in tensorboard:
    writer = SummaryWriter()
    
    # Prepare all mini-batches
    x_train_batches = batch(x_train)
    y_train_batches = batch(y_train)
    
    # Training starting time
    start = time.time()

    print('[INFO] Started to train the model.')
    print('Training the model on {}.'.format('GPU' if device == torch.device('cuda') else 'CPU'))
    
    for ep in range(num_epochs):

        # Ensure we're still in training mode
        model.train()

        current_loss = 0.0

        for idx_batch, train_batches in enumerate(zip(x_train_batches, y_train_batches)):

            # get a mini-batch of sequences
            x_train_batch, y_train_batch = train_batches

            # zero the gradient parameters
            optimizer.zero_grad()

            # forward
            outputs = model(x_train_batch)

            # backward
            loss = criterion(outputs, y_train_batch)
            loss.backward()

            # optimize
            optimizer.step()

            # accumulate the loss for logging
            current_loss += loss.item()
        
        train_acc = get_accuracy(model, x_train, y_train)
        test_acc = get_accuracy(model, x_test, y_test)
        
        writer.add_scalar('data/accuracy_train', train_acc, ep)
        writer.add_scalar('data/accuracy_test', test_acc, ep)
        print('Epoch #{:03d} | Time elapsed : {} | Loss : {:.4e} | Accuracy_train : {:.4e} | Accuracy_test : {:.4e}'.format(
                ep + 1, time_since(start), current_loss, train_acc, test_acc))

    print('[INFO] Finished training the model. Total time : {}.'.format(time_since(start)))

You can now train the model on your dataset!

Note: You can use the GPUs provided by Google Colab to make the training faster: open the "Runtime" menu, select "Change runtime type" and choose GPU in the hardware accelerator drop-down.

In [19]:
# Please adjust the number of training epochs, and the other hyperparameters (lr, dropout, ...), to avoid overfitting on your own data.
# tip: use tensorboard to display the accuracy (see cells above for tensorboard usage)

num_epochs = 20

train(model=model, criterion=criterion, optimizer=optimizer,
      x_train=x_train, y_train=y_train_14, x_test=x_test, y_test=y_test_14,
      num_epochs=num_epochs)
[INFO] Started to train the model.
Training the model on GPU.
Epoch #001 | Time elapsed : 00m 41s | Loss : 2.8282e+02 | Accuracy_train : 3.6480e-01 | Accuracy_test : 3.0227e-01
Epoch #002 | Time elapsed : 01m 20s | Loss : 9.9518e+01 | Accuracy_train : 5.7806e-01 | Accuracy_test : 5.2210e-01
Epoch #003 | Time elapsed : 01m 59s | Loss : 6.8627e+01 | Accuracy_train : 7.1173e-01 | Accuracy_test : 6.4516e-01
Epoch #004 | Time elapsed : 02m 39s | Loss : 5.1430e+01 | Accuracy_train : 7.8367e-01 | Accuracy_test : 7.1804e-01
Epoch #005 | Time elapsed : 03m 19s | Loss : 4.0985e+01 | Accuracy_train : 8.1429e-01 | Accuracy_test : 7.3238e-01
Epoch #006 | Time elapsed : 04m 01s | Loss : 3.4248e+01 | Accuracy_train : 8.6276e-01 | Accuracy_test : 7.6822e-01
Epoch #007 | Time elapsed : 04m 42s | Loss : 3.0172e+01 | Accuracy_train : 8.7398e-01 | Accuracy_test : 7.7897e-01
Epoch #008 | Time elapsed : 05m 23s | Loss : 2.5808e+01 | Accuracy_train : 8.8827e-01 | Accuracy_test : 7.7061e-01
Epoch #009 | Time elapsed : 06m 05s | Loss : 2.4689e+01 | Accuracy_train : 8.9796e-01 | Accuracy_test : 7.6941e-01
Epoch #010 | Time elapsed : 06m 45s | Loss : 2.0552e+01 | Accuracy_train : 9.0816e-01 | Accuracy_test : 7.9809e-01
Epoch #011 | Time elapsed : 07m 28s | Loss : 1.7200e+01 | Accuracy_train : 9.3673e-01 | Accuracy_test : 8.4110e-01
Epoch #012 | Time elapsed : 08m 09s | Loss : 1.5713e+01 | Accuracy_train : 9.3061e-01 | Accuracy_test : 8.2796e-01
Epoch #013 | Time elapsed : 08m 49s | Loss : 1.4629e+01 | Accuracy_train : 9.4337e-01 | Accuracy_test : 8.3035e-01
Epoch #014 | Time elapsed : 09m 29s | Loss : 1.3621e+01 | Accuracy_train : 9.3163e-01 | Accuracy_test : 8.2318e-01
Epoch #015 | Time elapsed : 10m 10s | Loss : 1.2829e+01 | Accuracy_train : 9.4235e-01 | Accuracy_test : 7.9809e-01
Epoch #016 | Time elapsed : 10m 50s | Loss : 1.1798e+01 | Accuracy_train : 9.4031e-01 | Accuracy_test : 8.1601e-01
Epoch #017 | Time elapsed : 11m 30s | Loss : 1.0827e+01 | Accuracy_train : 9.4439e-01 | Accuracy_test : 8.1720e-01
Epoch #018 | Time elapsed : 12m 10s | Loss : 9.8947e+00 | Accuracy_train : 9.4694e-01 | Accuracy_test : 8.2796e-01
Epoch #019 | Time elapsed : 12m 53s | Loss : 9.5964e+00 | Accuracy_train : 9.5765e-01 | Accuracy_test : 8.2318e-01
Epoch #020 | Time elapsed : 13m 33s | Loss : 9.1779e+00 | Accuracy_train : 9.5459e-01 | Accuracy_test : 8.4110e-01
[INFO] Finished training the model. Total time : 13m 33s.

6. Save the trained model (when you're happy with the training)

In [20]:
torch.save(model.state_dict(), 'gesture_pretrained_model.pt')
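If you plan to resume training later (rather than only run inference), you may also want to save the optimizer state alongside the model weights; a sketch (the checkpoint filename is arbitrary):

# Optional: full checkpoint, to resume training later
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'num_epochs': num_epochs,
}, 'gesture_checkpoint.pt')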

7. Get a trained model

In [0]:
# Reminder: if you want to use the model elsewhere, first redefine (or import) the HandGestureNet class before loading the weights
model = HandGestureNet(n_channels=66, n_classes=14)
model.load_state_dict(torch.load('gesture_pretrained_model.pt'))
model.eval()

# make predictions
with torch.no_grad():
    demo_gesture_batch = torch.randn(32, 100, 66)
    predictions = model(demo_gesture_batch)
    _, predictions = predictions.max(dim=1)
    print("Predicted gesture classes: {}".format(predictions.tolist()))
Predicted gesture classes: [4, 4, 4, 4, 7, 10, 4, 4, 4, 10, 7, 4, 4, 4, 12, 4, 12, 12, 4, 4, 4, 4, 4, 4, 4, 7, 4, 4, 4, 4, 4, 4]
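The model outputs raw logits (CrossEntropyLoss applies log-softmax internally during training), so if you need class probabilities rather than hard predictions, apply a softmax yourself; a sketch:

# Class probabilities from the raw logits:
with torch.no_grad():
    logits = model(demo_gesture_batch)
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    print(probabilities[0])  # distribution over the 14 classes for the first gesture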
In [0]:
# play with the model!