Image Classification#
Introduction#
In this tutorial, we show how to use PyTorch models in EvaDB to classify images. In particular, we focus on classifying images from the MNIST dataset of handwritten digits. EvaDB makes image classification easy through its built-in support for PyTorch models.
Besides classifying images, we will also showcase a query where the model's output is used to retrieve images of the digit 6.
Prerequisites#
To follow along, you will need to set up a local instance of EvaDB via pip.
Connect to EvaDB#
After installing EvaDB, use the following Python code to establish a connection and obtain a cursor for running EvaQL queries.
import evadb
cursor = evadb.connect().cursor()
We will assume that the input MNIST video is loaded into EvaDB. To download the video and load it into EvaDB, see the complete image classification notebook on Colab.
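For readers who do not open the notebook, the loading step can be sketched roughly as follows. This is a minimal sketch, not the notebook's code: the file name `mnist.mp4` and the helper names `build_load_query` and `load_video` are assumptions made for illustration, while the table name `mnist_video` is taken from the queries used later in this tutorial. The statement is run through the cursor's `query(...).df()` call.

```python
def build_load_query(video_path: str, table_name: str) -> str:
    """Build an EvaQL LOAD VIDEO statement for a caller-supplied path and table."""
    # Reject inputs that would break the quoting of the statement.
    if "'" in video_path:
        raise ValueError("video_path must not contain single quotes")
    if not table_name.isidentifier():
        raise ValueError("table_name must be a plain identifier")
    return f"LOAD VIDEO '{video_path}' INTO {table_name};"


def load_video(cursor, video_path: str = "mnist.mp4", table_name: str = "mnist_video"):
    """Run the LOAD statement on an EvaDB cursor and return the result DataFrame.

    `video_path` defaults to a hypothetical local file name.
    """
    return cursor.query(build_load_query(video_path, table_name)).df()
```

With the cursor obtained above, `load_video(cursor)` would load the hypothetical local file into the `mnist_video` table.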
Create Image Classification Function#
To create a custom MnistImageClassifier function, use the CREATE FUNCTION statement. The code for the custom classification model is available here.
We will assume that the file is downloaded and stored as mnist_image_classifier.py. Now, run the following query to register the AI function:
CREATE FUNCTION
IF NOT EXISTS MnistImageClassifier
IMPL 'mnist_image_classifier.py';
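The query above can also be issued from Python through the cursor. The sketch below assumes the implementation file sits in the current working directory; the helper names `build_create_function_query` and `register_classifier` are hypothetical, and the file-existence check is an added convenience rather than something EvaDB requires.

```python
from pathlib import Path


def build_create_function_query(impl_path: str = "mnist_image_classifier.py") -> str:
    """Return the CREATE FUNCTION statement used in this tutorial."""
    return (
        "CREATE FUNCTION IF NOT EXISTS MnistImageClassifier\n"
        f"IMPL '{impl_path}';"
    )


def register_classifier(cursor, impl_path: str = "mnist_image_classifier.py"):
    """Register the classifier, failing early if the implementation file is missing."""
    if not Path(impl_path).is_file():
        raise FileNotFoundError(f"expected the classifier code at {impl_path!r}")
    # cursor is the EvaDB cursor created earlier; .df() returns the result as a DataFrame.
    return cursor.query(build_create_function_query(impl_path)).df()
```

Checking for the file first gives a clearer error than a failed registration if the download step was skipped.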
Image Classification Queries#
After the function is registered in EvaDB, you can use it in subsequent SQL queries in different ways.
In the following query, we call the classifier on every image in the video. The output of the function (i.e., the digit associated with the given frame) is stored in the label column of the output DataFrame.
SELECT MnistImageClassifier(data).label
FROM mnist_video;
This query returns the label of every image:
+------------------------------+
| mnistimageclassifier.label |
|------------------------------|
| 6 |
| 6 |
| ... |
| ... |
| ... |
| ... |
| 4 |
| 4 |
+------------------------------+
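Once the labels are available as a DataFrame, a quick per-digit count is a convenient sanity check. The following is a minimal sketch: the column name `mnistimageclassifier.label` is assumed from the output shown above, and the helpers `label_histogram` and `classify_all` are hypothetical names introduced for illustration.

```python
from collections import Counter

# Column name assumed from the query output shown in this tutorial.
LABEL_COLUMN = "mnistimageclassifier.label"


def label_histogram(labels):
    """Count how many frames were assigned to each digit label."""
    return Counter(str(label) for label in labels)


def classify_all(cursor):
    """Run the classifier on every frame and summarize the predicted labels."""
    df = cursor.query(
        "SELECT MnistImageClassifier(data).label FROM mnist_video;"
    ).df()
    return label_histogram(df[LABEL_COLUMN])
```

Labels are converted to strings before counting so that the result is keyed the same way regardless of whether the column holds strings or integers.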
Filtering Based on AI Function#
In the following query, we use the output of the classifier to retrieve a subset of images that contain a particular digit (e.g., 6).
SELECT id, MnistImageClassifier(data).label
FROM mnist_video
WHERE MnistImageClassifier(data).label = '6';
Now, the DataFrame only contains images of the digit 6.
+------------------------------+
| mnistimageclassifier.label |
|------------------------------|
| 6 |
| 6 |
+------------------------------+
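The same filter works for any digit. As a rough sketch, the query text can be generated for a chosen digit; `build_digit_filter_query` is a hypothetical helper, and it mirrors the query above by comparing the label against a quoted string.

```python
def build_digit_filter_query(digit: int, table_name: str = "mnist_video") -> str:
    """Build the filtering query from this tutorial for a given digit (0-9)."""
    # bool is a subclass of int, so it is rejected explicitly.
    if isinstance(digit, bool) or not isinstance(digit, int) or not 0 <= digit <= 9:
        raise ValueError("digit must be an integer between 0 and 9")
    return (
        "SELECT id, MnistImageClassifier(data).label\n"
        f"FROM {table_name}\n"
        f"WHERE MnistImageClassifier(data).label = '{digit}';"
    )


def frames_with_digit(cursor, digit: int):
    """Run the filter on an EvaDB cursor and return the matching rows as a DataFrame."""
    return cursor.query(build_digit_filter_query(digit)).df()
```

For example, `frames_with_digit(cursor, 6)` would issue the query shown above.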
What's Next?#
If you are excited about our vision of bringing AI inside databases, consider:
joining our Slack: https://evadb.ai/slack
following us on GitHub: https://evadb.ai/github
following us on Twitter: https://evadb.ai/twitter
following us on Medium: https://evadb.ai/blog
contributing to EvaDB: https://evadb.ai/github

Language Models and Databases#