This notebook is a demo PyTorch implementation of the deep learning model for hand gesture recognition introduced in the article Deep Learning for Hand Gesture Recognition on Skeletal Data by G. Devineau, F. Moutarde, W. Xi and J. Yang.
If you find this code useful in your research, please consider citing:
@inproceedings{devineau2018deep,
title={Deep learning for hand gesture recognition on skeletal data},
author={Devineau, Guillaume and Moutarde, Fabien and Xi, Wang and Yang, Jie},
booktitle={2018 13th IEEE International Conference on Automatic Face \& Gesture Recognition (FG 2018)},
pages={106--113},
year={2018},
organization={IEEE}
}
For a more detailed introduction to the model, please refer to this quick introduction to the model or read the original article.
For Keras users, a simplified Keras version of the current Colab notebook is also available for convenience.
#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, print_function, division
import sys
if sys.version_info.major < 3:
    print('You are using Python 2, but you should use Python 3 instead.')
    print(' If you still want to use Python 2, ensure you import:')
    print(' >> from __future__ import unicode_literals, print_function, division')
If you encounter a Python error regarding a missing module at some point, uncomment the appropriate line in the cell below and run it.
# Uncomment the following lines to install modules if needed:
# -----------------------------------------------------------
# --- required: numpy + scikit-learn + scipy:
# !{sys.executable} -m pip install numpy scipy scikit-learn
# --- required: torch
# !{sys.executable} -m pip install torch # <<------ advice: do not use this pip command; install torch via conda instead, with CUDA support: see https://pytorch.org/get-started/locally/
# --- bonus: tensorboardX:
# !{sys.executable} -m pip install tensorflow tensorflow-gpu
# !{sys.executable} -m pip install tensorboardX
# --- bonus: jupyter-lab:
# !{sys.executable} -m pip install matplotlib ipython jupyter jupyterlab pandas tqdm
# !{sys.executable} -m pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension
import itertools
import numpy
import torch
import pickle
from scipy import ndimage as ndimage
from sklearn.utils import shuffle
import time
import math
# (bonus) plot acc with tensorboard
# Command to start tensorboard if installed (requires tensorflow):
# $ tensorboard --logdir ./runs
try:
    from tensorboardX import SummaryWriter
except ImportError:
    # tensorboardX is not installed: fall back to a no-op writer
    class SummaryWriter():
        def __init__(self):
            pass
        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            pass
print('Using python {}.{}, with module versions:'.format(sys.version_info.major, sys.version_info.minor))
print('-'*40)
print('numpy == {}'.format(numpy.__version__))
print('torch == {}'.format(torch.__version__)) # please use version 1.0 or above
We define some functions we'll need later on.
To train the model on your own custom hand gesture dataset, edit the cells below. You'll need to adapt load_data(), shuffle_dataset(), preprocess_data() and convert_to_pytorch_tensors() according to your needs.
The load_data() function should return the gesture tensors x and the labels y. The shape of the x tensor should be (dataset_size, duration, n_channels), where n_channels = 3 * n_joints for 3D pose data.
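For instance, a minimal load_data() stand-in returning random data with the expected shapes could look like this (a sketch with hypothetical sizes; replace the random tensors with your own loading code):
import numpy

def load_data_example(n_train=100, n_test=50, duration=100, n_joints=22):
    """Toy stand-in for load_data(): random sequences with the expected shapes."""
    n_channels = 3 * n_joints  # x, y, z coordinates per joint
    x_train = numpy.random.randn(n_train, duration, n_channels)
    x_test = numpy.random.randn(n_test, duration, n_channels)
    # labels in 1..14 and 1..28 (they are shifted to 0-based later, in convert_to_pytorch_tensors())
    y_train_14 = numpy.random.randint(1, 15, size=n_train)
    y_train_28 = numpy.random.randint(1, 29, size=n_train)
    y_test_14 = numpy.random.randint(1, 15, size=n_test)
    y_test_28 = numpy.random.randint(1, 29, size=n_test)
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28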
If your dataset does not fit into memory, you can use the standard PyTorch Dataset class:
from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, root_dir):
        self.sequences_names = ...  # list your sequences here
        self.root_dir = root_dir
    def __len__(self):
        return len(self.sequences_names)
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = load_somehow(self.sequences_names[idx])  # replace with your own sequence-loading code
        return sample
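Such a dataset can then be wrapped in a standard PyTorch DataLoader for mini-batch iteration (a sketch, assuming MyCustomDataset above is completed; the root_dir path is hypothetical):
from torch.utils.data import DataLoader

dataset = MyCustomDataset(root_dir='/content/my_gestures')  # hypothetical path
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch_of_sequences in loader:
    pass  # feed each mini-batch to the model here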
If you don't have your own gesture dataset, you might want to download a public hand gesture dataset instead. The SHREC'17 dataset used in this notebook is available as a ready-to-use dataset, pre-loaded into a single pickle file shrec_data.pckl:
# Option 1: Custom Dataset
# ------------------------
# On the left bar of this colaboratory notebook there is a section called "Files".
# Upload your files there and use a path like "/content/each_file_you_just_uploaded" to load your data, e.g. in the load_data() function below
# Option 2: Download the SHREC17 dataset
# (either download the file manually or use the following wget command:)
!wget https://cloud.mines-paristech.fr/index.php/s/9U4bjHrvp8u2pnS/download -O ./shrec_data.pckl
def load_data(filepath='./shrec_data.pckl'):
    """
    Returns hand gesture sequences (X) and their associated labels (Y).
    Each sequence has two different labels:
    - the first label describes the gesture class out of 14 possible gestures (e.g. swiping your hand to the right);
    - the second label describes the gesture class out of 28 possible gestures (e.g. swiping your hand to the right with your index finger pointed, or not pointed).
    """
    file = open(filepath, 'rb')
    data = pickle.load(file, encoding='latin1')  # <<---- change 'latin1' to 'utf8' if the data does not load
    file.close()
    return data['x_train'], data['x_test'], data['y_train_14'], data['y_train_28'], data['y_test_14'], data['y_test_28']
def resize_sequences_length(x_train, x_test, final_length=100):
    """
    Resize the time series by interpolating them to the same length
    """
    # note: this relies on Python 3's true division; on Python 2, import division from the __future__ module (done at the top of this notebook)
    x_train = numpy.array([numpy.array([ndimage.zoom(x_i.T[j], final_length / len(x_i), mode='reflect') for j in range(numpy.size(x_i, 1))]).T for x_i in x_train])
    x_test = numpy.array([numpy.array([ndimage.zoom(x_i.T[j], final_length / len(x_i), mode='reflect') for j in range(numpy.size(x_i, 1))]).T for x_i in x_test])
    return x_train, x_test
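To see what the interpolation does to a single sequence, here is a quick sanity check on dummy data (not part of the pipeline): the zoom factor final_length / len(x_i) stretches or compresses each channel along time.
demo_sequence = numpy.random.randn(73, 66)  # 73 timesteps, 66 channels
resized = numpy.array([ndimage.zoom(demo_sequence.T[j], 100 / len(demo_sequence), mode='reflect')
                       for j in range(numpy.size(demo_sequence, 1))]).T
print(resized.shape)  # (100, 66): each channel was stretched from 73 to 100 timesteps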
def shuffle_dataset(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28):
    """Shuffle the train/test data consistently."""
    # note: add random_state=0 for reproducibility
    x_train, y_train_14, y_train_28 = shuffle(x_train, y_train_14, y_train_28)
    x_test, y_test_14, y_test_28 = shuffle(x_test, y_test_14, y_test_28)
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28
def preprocess_data(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28):
    """
    Preprocess the data as you want: update as you want!
    - possible improvement idea: make a PCA here
    """
    x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = shuffle_dataset(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28)
    x_train, x_test = resize_sequences_length(x_train, x_test, final_length=100)
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28
def convert_to_pytorch_tensors(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28):
    # as numpy
    y_train_14, y_train_28, y_test_14, y_test_28 = numpy.array(y_train_14), numpy.array(y_train_28), numpy.array(y_test_14), numpy.array(y_test_28)
    # -- REQUIRED by the pytorch loss function implementation --
    # Subtract 1 from all class labels, so they are zero-indexed (1-14 => 0-13 and 1-28 => 0-27)
    y_train_14, y_train_28, y_test_14, y_test_28 = y_train_14 - 1, y_train_28 - 1, y_test_14 - 1, y_test_28 - 1
    # as torch
    x_train, x_test = torch.from_numpy(x_train), torch.from_numpy(x_test)
    y_train_14, y_train_28, y_test_14, y_test_28 = torch.from_numpy(y_train_14), torch.from_numpy(y_train_28), torch.from_numpy(y_test_14), torch.from_numpy(y_test_28)
    # -- REQUIRED by the pytorch loss function implementation --
    # correct the data types (for the loss function used): float inputs, long targets
    x_train, x_test = x_train.type(torch.FloatTensor), x_test.type(torch.FloatTensor)
    y_train_14, y_train_28, y_test_14, y_test_28 = y_train_14.type(torch.LongTensor), y_train_28.type(torch.LongTensor), y_test_14.type(torch.LongTensor), y_test_28.type(torch.LongTensor)
    return x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28
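The dtype and indexing constraints above come from torch.nn.CrossEntropyLoss, which expects float inputs of shape (batch_size, n_classes) and zero-indexed long targets; a quick sanity check on dummy tensors:
dummy_logits = torch.randn(4, 14)            # float scores, as produced by the model
dummy_targets = torch.tensor([0, 3, 13, 7])  # long labels in 0..13
print(torch.nn.CrossEntropyLoss()(dummy_logits, dummy_targets))  # a scalar loss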
# -------------
# Misc.
# -------------
def batch(tensor, batch_size=32):
    """Return a list of (mini) batches"""
    tensor_list = []
    length = tensor.shape[0]
    i = 0
    while True:
        if (i + 1) * batch_size >= length:
            tensor_list.append(tensor[i * batch_size: length])
            return tensor_list
        tensor_list.append(tensor[i * batch_size: (i + 1) * batch_size])
        i += 1
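For example, on a dummy tensor of 10 items with batch_size=4, batch() returns two full mini-batches and one final smaller one:
demo = torch.arange(10)
print([b.tolist() for b in batch(demo, batch_size=4)])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]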
def time_since(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '{:02d}m {:02d}s'.format(int(m), int(s))
def get_accuracy(model, x, y_ref):
    """Get the accuracy of the pytorch model on a batch"""
    acc = 0.
    model.eval()
    with torch.no_grad():
        predicted = model(x)
        _, predicted = predicted.max(dim=1)
        acc = 1.0 * (predicted == y_ref).sum().item() / y_ref.shape[0]
    return acc
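Note that get_accuracy() sends the whole tensor through the model in a single forward pass; if your test set does not fit in (GPU) memory, a mini-batched variant along these lines (a sketch reusing the batch() helper above) avoids the problem:
def get_accuracy_batched(model, x, y_ref, batch_size=32):
    """Same as get_accuracy(), but computed over mini-batches to save memory."""
    model.eval()
    n_correct = 0
    with torch.no_grad():
        for x_batch, y_batch in zip(batch(x, batch_size), batch(y_ref, batch_size)):
            _, predicted = model(x_batch).max(dim=1)
            n_correct += (predicted == y_batch).sum().item()
    return n_correct / y_ref.shape[0]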
class HandGestureNet(torch.nn.Module):
    """
    [Devineau et al., 2018] Deep Learning for Hand Gesture Recognition on Skeletal Data

    Summary
    -------
    Deep learning model for hand gesture classification using pose data only (no need for RGBD).
    The model computes a succession of [convolutions and pooling] over time, independently on each of the 66 (= 22 * 3) sequence channels.
    Each of these computations is actually done at two different resolutions, which are later merged by concatenation
    with the (pooled) original sequence channel.
    Finally, a multi-layer perceptron merges all of the processed channels and outputs a classification.

    TL;DR:
    ------
    input ------------------------------------------------> split into n_channels channels [channel_i]
        channel_i ----------------------------------------> 3x [conv/pool/dropout] low_resolution_i
        channel_i ----------------------------------------> 3x [conv/pool/dropout] high_resolution_i
        channel_i ----------------------------------------> pooled_i
        low_resolution_i, high_resolution_i, pooled_i ----> output_channel_i
    MLP(n_channels x [output_channel_i]) -------------------------> classification

    Article / PDF:
    --------------
    https://ieeexplore.ieee.org/document/8373818

    Please cite:
    ------------
    @inproceedings{devineau2018deep,
      title={Deep learning for hand gesture recognition on skeletal data},
      author={Devineau, Guillaume and Moutarde, Fabien and Xi, Wang and Yang, Jie},
      booktitle={2018 13th IEEE International Conference on Automatic Face \& Gesture Recognition (FG 2018)},
      pages={106--113},
      year={2018},
      organization={IEEE}
    }
    """
    def __init__(self, n_channels=66, n_classes=14, dropout_probability=0.2):
        super(HandGestureNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.dropout_probability = dropout_probability
        # Layers ----------------------------------------------
        self.all_conv_high = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),
            torch.nn.Conv1d(in_channels=8, out_channels=4, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),
            torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=self.dropout_probability),
            torch.nn.AvgPool1d(2)
        ) for joint in range(n_channels)])
        self.all_conv_low = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),
            torch.nn.Conv1d(in_channels=8, out_channels=4, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool1d(2),
            torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=self.dropout_probability),
            torch.nn.AvgPool1d(2)
        ) for joint in range(n_channels)])
        self.all_residual = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.AvgPool1d(2),
            torch.nn.AvgPool1d(2),
            torch.nn.AvgPool1d(2)
        ) for joint in range(n_channels)])
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=9 * n_channels * 12, out_features=1936),  # <-- 12: depends on the sequence length (cf. below)
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1936, out_features=n_classes)
        )
        # Initialization --------------------------------------
        # Xavier init
        for module in itertools.chain(self.all_conv_high, self.all_conv_low, self.all_residual):
            for layer in module:
                if layer.__class__.__name__ == "Conv1d":
                    torch.nn.init.xavier_uniform_(layer.weight, gain=torch.nn.init.calculate_gain('relu'))
                    torch.nn.init.constant_(layer.bias, 0.1)
        for layer in self.fc:
            if layer.__class__.__name__ == "Linear":
                torch.nn.init.xavier_uniform_(layer.weight, gain=torch.nn.init.calculate_gain('relu'))
                torch.nn.init.constant_(layer.bias, 0.1)
    def forward(self, input):
        """
        This function performs the actual computations of the network for a forward pass.

        Arguments
        ---------
        input: a tensor of gestures of shape (batch_size, duration, n_channels)
               (where n_channels = 3 * n_joints for 3D pose data)
        """
        # Work on each channel separately
        all_features = []
        for channel in range(0, self.n_channels):
            input_channel = input[:, :, channel]
            # Add a dummy (spatial) dimension for the time convolutions
            # Conv1D format : (batch_size, n_feature_maps, duration)
            input_channel = input_channel.unsqueeze(1)
            high = self.all_conv_high[channel](input_channel)
            low = self.all_conv_low[channel](input_channel)
            ap_residual = self.all_residual[channel](input_channel)
            # Time convolutions are concatenated along the feature maps axis
            output_channel = torch.cat([
                high,
                low,
                ap_residual
            ], dim=1)
            all_features.append(output_channel)
        # Concatenate along the feature maps axis
        all_features = torch.cat(all_features, dim=1)
        # Flatten for the Linear layers
        all_features = all_features.view(-1, 9 * self.n_channels * 12)  # <-- 12: depends on the initial sequence length (100).
        # If you have shorter/longer sequences, you probably do NOT even need to modify the network architecture:
        # resampling your input gesture from T timesteps to 100 timesteps will (surprisingly) probably work just as well!
        # Fully-Connected Layers
        output = self.fc(all_features)
        return output
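As a sanity check on the flattened size: 100 timesteps pooled three times by AvgPool1d(2) give 100 -> 50 -> 25 -> 12 steps, and each channel contributes 4 (high) + 4 (low) + 1 (residual) = 9 feature maps, hence the 9 * n_channels * 12 features fed to the MLP. You can verify the end-to-end shapes on random data:
demo_model = HandGestureNet(n_channels=66, n_classes=14)
demo_input = torch.randn(2, 100, 66)  # (batch_size, duration, n_channels)
print(demo_model(demo_input).shape)   # torch.Size([2, 14]): one score per class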
If you use a custom dataset, you'll likely need to change the data loading part here.
# -------------
# Data
# -------------
# Load the dataset
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = load_data()
# Shuffle sequences and resize sequences
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = preprocess_data(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28)
# Convert to pytorch tensors
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = convert_to_pytorch_tensors(x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28)
If you use a custom dataset, you'll likely want to change n_channels and n_classes to match your values.
# -------------
# Network instantiation
# -------------
model = HandGestureNet(n_channels=66, n_classes=14)
Reduce the learning rate to get smoother accuracy curves.
# -----------------------------------------------------
# Loss function & Optimizer
# -----------------------------------------------------
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
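If you prefer to decay the learning rate during training instead of just lowering the initial lr above, PyTorch's built-in schedulers can be used (a sketch; the step size and decay factor below are arbitrary choices, not values from the paper):
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# then call scheduler.step() once per epoch inside the training loop below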
# -------------
# Training
# -------------
def train(model, criterion, optimizer,
          x_train, y_train, x_test, y_test,
          force_cpu=False, num_epochs=5):
    # use a GPU (for speed) if you have one
    device = torch.device("cuda") if torch.cuda.is_available() and not force_cpu else torch.device("cpu")
    model = model.to(device)
    x_train, y_train, x_test, y_test = x_train.to(device), y_train.to(device), x_test.to(device), y_test.to(device)
    # (bonus) log accuracy values to visualize them in tensorboard:
    writer = SummaryWriter()
    # Prepare all mini-batches
    x_train_batches = batch(x_train)
    y_train_batches = batch(y_train)
    # Training starting time
    start = time.time()
    print('[INFO] Started to train the model.')
    print('Training the model on {}.'.format('GPU' if device == torch.device('cuda') else 'CPU'))
    for ep in range(num_epochs):
        # Ensure we're still in training mode
        model.train()
        current_loss = 0.0
        for idx_batch, train_batches in enumerate(zip(x_train_batches, y_train_batches)):
            # get a mini-batch of sequences
            x_train_batch, y_train_batch = train_batches
            # zero the gradient parameters
            optimizer.zero_grad()
            # forward
            outputs = model(x_train_batch)
            # backward
            loss = criterion(outputs, y_train_batch)
            loss.backward()
            # optimize
            optimizer.step()
            # accumulate the loss over the epoch, for logging
            current_loss += loss.item()
        train_acc = get_accuracy(model, x_train, y_train)
        test_acc = get_accuracy(model, x_test, y_test)
        writer.add_scalar('data/accuracy_train', train_acc, ep)
        writer.add_scalar('data/accuracy_test', test_acc, ep)
        print('Epoch #{:03d} | Time elapsed : {} | Loss : {:.4e} | Accuracy_train : {:.4e} | Accuracy_test : {:.4e}'.format(
            ep + 1, time_since(start), current_loss, train_acc, test_acc))
    print('[INFO] Finished training the model. Total time : {}.'.format(time_since(start)))
You can now train the model on your dataset!
Note: You can use the GPUs provided by Google Colab to make the training faster: go to the “Runtime” dropdown menu, select “Change runtime type” and select GPU in the hardware accelerator drop-down menu.
# Please adjust the number of training epochs, and the other hyperparameters (lr, dropout, ...), to your own needs, so that the training does not overfit.
# tip: use tensorboard to display the accuracy (see cells above for tensorboard usage)
num_epochs = 20
train(model=model, criterion=criterion, optimizer=optimizer,
      x_train=x_train, y_train=y_train_14, x_test=x_test, y_test=y_test_14,
      num_epochs=num_epochs)
torch.save(model.state_dict(), 'gesture_pretrained_model.pt')
# Reminder: if you want to use the pretrained model elsewhere, first redefine/load the HandGestureNet class before loading the weights
model = HandGestureNet(n_channels=66, n_classes=14)
model.load_state_dict(torch.load('gesture_pretrained_model.pt'))
model.eval()
# make predictions
with torch.no_grad():
    demo_gesture_batch = torch.randn(32, 100, 66)
    predictions = model(demo_gesture_batch)
    _, predictions = predictions.max(dim=1)
    print("Predicted gesture classes: {}".format(predictions.tolist()))
# play with the model!