Image Classification

Introduction

In this tutorial, we show how to use PyTorch models in EvaDB to classify images. In particular, we focus on classifying images of handwritten digits from the MNIST dataset. EvaDB makes image classification straightforward through its built-in support for PyTorch models.

Besides classifying images, this tutorial also showcases a query that uses the model's output to retrieve images of the digit 6.

Prerequisites

To follow along, you will need a local installation of EvaDB, which you can set up via pip (pip install evadb).

Connect to EvaDB

After installing EvaDB, use the following Python code to establish a connection and obtain a cursor for running EvaQL queries.

import evadb
cursor = evadb.connect().cursor()

We will assume that the input MNIST video is loaded into EvaDB. To download the video and load it into EvaDB, see the complete image classification notebook on Colab.
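If the MNIST video is already on disk, loading it comes down to a single EvaQL LOAD VIDEO statement. The sketch below assembles that statement in Python. The helper `load_video_stmt` and the file name `mnist.mp4` are hypothetical; the table name `mnist_video` matches the queries in this tutorial. Paths containing a single quote are rejected rather than escaped, because the sketch does not assume how EvaQL escapes quotes inside string literals.

```python
def load_video_stmt(path: str, table: str = "mnist_video") -> str:
    """Build an EvaQL LOAD VIDEO statement (hypothetical helper, not part of EvaDB)."""
    # Reject quotes instead of escaping them: this sketch does not assume
    # how EvaQL escapes quotes inside string literals.
    if "'" in path:
        raise ValueError("video path must not contain single quotes")
    return f"LOAD VIDEO '{path}' INTO {table};"


# mnist.mp4 is a hypothetical local file name for the downloaded video.
stmt = load_video_stmt("mnist.mp4")
print(stmt)  # LOAD VIDEO 'mnist.mp4' INTO mnist_video;

# With EvaDB installed, run it on the cursor from the connection step:
# cursor.query(stmt).df()
```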

Create Image Classification Function

To create a custom MnistImageClassifier function, use the CREATE FUNCTION statement. The code for the custom classification model is available here.

We will assume that the file is downloaded and stored as mnist_image_classifier.py. Now, run the following query to register the AI function:

CREATE FUNCTION
IF NOT EXISTS MnistImageClassifier
IMPL 'mnist_image_classifier.py';
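The same registration can be issued from Python through the EvaDB cursor. The minimal sketch below only assembles the CREATE FUNCTION statement shown above so that the function name and implementation file can be parameterized; the helper `create_function_stmt` is hypothetical. Because of IF NOT EXISTS, re-running the registration should leave an already registered function with the same name in place.

```python
def create_function_stmt(name: str, impl_path: str) -> str:
    """Assemble a CREATE FUNCTION IF NOT EXISTS statement (hypothetical helper).

    IF NOT EXISTS makes re-running the registration harmless: an existing
    function with the same name is left in place.
    """
    # Reject quotes rather than guessing EvaQL's string-escaping rules.
    if "'" in impl_path:
        raise ValueError("implementation path must not contain single quotes")
    return (
        f"CREATE FUNCTION IF NOT EXISTS {name} "
        f"IMPL '{impl_path}';"
    )


stmt = create_function_stmt("MnistImageClassifier", "mnist_image_classifier.py")
print(stmt)

# With EvaDB installed, run it on the cursor from the connection step:
# cursor.query(stmt).df()
```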

Image Classification Queries

After the function is registered in EvaDB, you can use it in subsequent SQL queries in different ways.

In the following query, we call the classifier on every image in the video. The output of the function is stored in the label column (i.e., the digit associated with the given frame) of the output DataFrame.

SELECT MnistImageClassifier(data).label
FROM mnist_video;

This query returns the label of all the images:

+------------------------------+
|   mnistimageclassifier.label |
|------------------------------|
|                            6 |
|                            6 |
|                          ... |
|                          ... |
|                          ... |
|                          ... |
|                            4 |
|                            4 |
+------------------------------+
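When a query is run from Python with `cursor.query(...).df()`, the result comes back as a pandas DataFrame whose column is named `mnistimageclassifier.label`, as in the output above. The sketch below summarizes that column by counting how many frames were assigned each digit. The helper `label_histogram` and the sample labels are hypothetical placeholders, not results from the MNIST video. Labels are normalized to strings so the count does not depend on whether the column holds integers or strings.

```python
from collections import Counter
from typing import Dict, Iterable


def label_histogram(labels: Iterable) -> Dict[str, int]:
    """Count how many frames were assigned each digit (hypothetical helper).

    Labels are normalized to str so that 6 and '6' are counted together.
    """
    counts = Counter(str(label) for label in labels)
    # Sort by digit so the summary reads in ascending order.
    return dict(sorted(counts.items()))


# Hypothetical labels standing in for df["mnistimageclassifier.label"].
sample = [6, 6, "4", 4, 6]
print(label_histogram(sample))  # {'4': 2, '6': 3}
```

With pandas available, `df["mnistimageclassifier.label"].value_counts()` yields the same counts directly on the returned DataFrame.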

Filtering Based on AI Function

In the following query, we use the output of the classifier to retrieve a subset of images that contain a particular digit (e.g., 6).

SELECT id, MnistImageClassifier(data).label
    FROM mnist_video
    WHERE MnistImageClassifier(data).label = '6';

Now, the DataFrame contains only the images classified as the digit 6 (the output below shows just the label column):

+------------------------------+
|   mnistimageclassifier.label |
|------------------------------|
|                            6 |
|                            6 |
+------------------------------+
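Conceptually, the WHERE clause evaluates the classifier on each frame and keeps only the rows whose predicted label equals '6'. The plain-Python sketch below mirrors that row-by-row logic; it does not model how EvaDB actually plans or optimizes the query. `fake_classify` is a hypothetical stand-in for MnistImageClassifier, and the frames are made-up placeholders rather than images.

```python
from typing import Callable, Iterable, List, Tuple

Frame = Tuple[int, object]  # (id, data) pair, mirroring the id and data columns


def filter_by_label(
    frames: Iterable[Frame],
    classify: Callable[[object], str],
    digit: str,
) -> List[Tuple[int, str]]:
    """Return (id, label) for frames whose predicted label equals `digit`.

    Mirrors: SELECT id, Classifier(data).label FROM t WHERE Classifier(data).label = digit
    """
    rows = []
    for frame_id, data in frames:
        label = classify(data)  # run the stand-in model once per frame
        if label == digit:  # string comparison, like = '6' in the query
            rows.append((frame_id, label))
    return rows


# Hypothetical stand-in for MnistImageClassifier: each "frame" already stores its digit.
def fake_classify(data: object) -> str:
    return str(data)


frames = [(0, 6), (1, 6), (2, 4), (3, 4)]
print(filter_by_label(frames, fake_classify, "6"))  # [(0, '6'), (1, '6')]
```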

What's Next?

👋 If you are excited about our vision of bringing AI inside databases, consider:

Language Models (🦙) and Databases