Custom Model Tutorial#

Run on Google Colab | View source on GitHub | Download notebook

Start EVA server#

We reuse the start-server notebook to launch the EVA server.

# Fetch the shared start-server notebook; -nc skips the download when the file is already present.
!wget -nc "https://raw.githubusercontent.com/georgia-tech-db/eva/master/tutorials/00-start-eva-server.ipynb"
# Run the downloaded notebook from this notebook.
%run 00-start-eva-server.ipynb
# connect_to_server is not defined in this notebook; presumably it comes from the notebook run above.
# The returned cursor is what every later cell uses to send queries to the server.
cursor = connect_to_server()
File ‘00-start-eva-server.ipynb’ already there; not retrieving.
nohup eva_server > eva.log 2>&1 &
Note: you may need to restart the kernel to use updated packages.

Download custom user-defined function (UDF), model, and video#

# Download the custom UDF implementation (gender.py, registered later as GenderCNN).
# -nc on every wget skips files that already exist locally.
!wget -nc https://www.dropbox.com/s/lharq14izp08bfz/gender.py

# Download the built-in Face Detector implementation (registered later as FaceDetector).
!wget -nc https://raw.githubusercontent.com/georgia-tech-db/eva/master/eva/udfs/face_detector.py

# Download the model file (gender.pth).
!wget -nc https://www.dropbox.com/s/0y291evpqdfmv2z/gender.pth

# Download the sample video (short.mp4) that the rest of the tutorial queries.
!wget -nc https://www.dropbox.com/s/f5447euuuis1vdy/short.mp4
File ‘gender.py’ already there; not retrieving.
File ‘face_detector.py’ already there; not retrieving.
File ‘gender.pth’ already there; not retrieving.
File ‘short.mp4’ already there; not retrieving.

Load video for analysis#

# Reset the TIKTOK table, load the video into it, then list the first few
# frame ids. Each statement is sent in order and its response is printed.
setup_queries = (
    "DROP TABLE TIKTOK;",
    "LOAD VIDEO 'short.mp4' INTO TIKTOK;",
    """SELECT id FROM TIKTOK WHERE id < 5""",
)
for setup_query in setup_queries:
    cursor.execute(setup_query)
    response = cursor.fetch_all()
    print(response)
@status: ResponseStatus.SUCCESS
@batch: 
                                     0
0  Table Successfully dropped: TIKTOK
@query_time: 0.04502507415600121
@status: ResponseStatus.SUCCESS
@batch: 
                            0
0  Number of loaded VIDEO: 1
@query_time: 0.06703275698237121
@status: ResponseStatus.SUCCESS
@batch: 
    tiktok.id
0          0
1          1
2          2
3          3
4          4
@query_time: 0.13041309895925224

Visualize Video#

from IPython.display import Video
# embed=True asks IPython to embed the video in the notebook output.
# The object is the cell's last expression, so the notebook displays it.
Video("short.mp4", embed=True)

Create GenderCNN and FaceDetector UDFs#

# Drop any earlier GenderCNN registration first: the CREATE below uses
# IF NOT EXISTS, so without the drop an existing registration would be kept
# and gender.py would not be registered again.
cursor.execute("""DROP UDF GenderCNN;""")
response = cursor.fetch_all()
print(response)

# Register the custom classifier implemented in gender.py.
# Declared input: a UINT8 array of shape (3, 224, 224); declared output: a text label.
cursor.execute("""CREATE UDF IF NOT EXISTS 
                  GenderCNN
                  INPUT (data NDARRAY UINT8(3, 224, 224)) 
                  OUTPUT (label TEXT(10)) 
                  TYPE  Classification 
                  IMPL 'gender.py';
        """)
response = cursor.fetch_all()
print(response)

# Register the face detector implemented in face_detector.py.
# It is not dropped beforehand, so an existing registration is left untouched.
# Declared outputs: bboxes with 4 values per row, and one score per row.
cursor.execute("""CREATE UDF IF NOT EXISTS
                  FaceDetector
                  INPUT  (frame NDARRAY UINT8(3, ANYDIM, ANYDIM))
                  OUTPUT (bboxes NDARRAY FLOAT32(ANYDIM, 4),
                          scores NDARRAY FLOAT32(ANYDIM))
                  TYPE  FaceDetection
                  IMPL  'face_detector.py';
        """)
response = cursor.fetch_all()
print(response)
@status: ResponseStatus.SUCCESS
@batch: 
                                     0
0  UDF GenderCNN successfully dropped
@query_time: 0.0174811570905149
@status: ResponseStatus.SUCCESS
@batch: 
                                                    0
0  UDF GenderCNN successfully added to the database.
@query_time: 3.761177452048287
@status: ResponseStatus.SUCCESS
@batch: 
                                                  0
0  UDF FaceDetector already exists, nothing added.
@query_time: 0.009986212942749262

Run Face Detector on video#

# Run FaceDetector on frames with id < 10 and keep only its bboxes output.
cursor.execute("""SELECT id, FaceDetector(data).bboxes 
                  FROM TIKTOK WHERE id < 10""")
response = cursor.fetch_all()
print(response)
@status: ResponseStatus.SUCCESS
@batch: 
    tiktok.id                              facedetector.bboxes
0          0      [[ 90.70622 208.44966 281.64642 457.68872]]
1          1      [[ 91.01816 208.27583 281.0808  457.91995]]
2          2  [[ 90.358536 207.3743   283.4399   457.96234 ]]
3          3  [[ 90.694214 207.56027  284.37817  458.6282  ]]
4          4  [[ 90.684944 208.98653  282.1281   460.90894 ]]
5          5      [[ 89.47423 209.38083 283.45938 460.58548]]
6          6      [[ 88.50081 208.31546 283.29172 461.8374 ]]
7          7  [[ 89.838646 206.07619  282.93942  464.7494  ]]
8          8      [[ 90.18522 224.35588 281.29733 469.89603]]
9          9      [[ 94.34447 234.13255 279.6476  468.85303]]
@query_time: 1.7909722931217402

Composing UDFs in a query#

Detect the gender of the faces found in the video by composing a set of UDFs (GenderCNN, FaceDetector, and Crop) in a single query.

# For frames with id < 50: UNNEST the FaceDetector output into rows aliased
# Face(bbox, conf), then pass Crop(data, bbox) to GenderCNN.
# Presumably Crop cuts the bbox region out of the frame; its definition is not shown here.
cursor.execute("""SELECT id, bbox, GenderCNN(Crop(data, bbox)) 
                  FROM TIKTOK JOIN LATERAL  UNNEST(FaceDetector(data)) AS Face(bbox, conf)  
                  WHERE id < 50;""")
# Keep this response: the visualization cell reads response.batch.frames from it.
response = cursor.fetch_all()
print(response)
@status: ResponseStatus.SUCCESS
@batch: 
     tiktok.id                                     Face.bbox gendercnn.label
0           0   [90.70624, 208.44968, 281.64642, 457.68872]          female
1           1    [91.01816, 208.27583, 281.0808, 457.91992]          female
2           2   [90.358536, 207.3743, 283.43994, 457.96234]          female
3           3   [90.694214, 207.56027, 284.37817, 458.6282]          female
4           4   [90.684944, 208.98653, 282.1281, 460.90894]          female
5           5   [89.47423, 209.38083, 283.45938, 460.58545]          female
6           6   [88.50081, 208.31546, 283.29172, 461.83743]          female
7           7    [89.83865, 206.07619, 282.93942, 464.7494]          female
8           8    [90.18519, 224.35585, 281.2973, 469.89606]          female
9           9    [94.34447, 234.13254, 279.6476, 468.85303]          female
10         10   [94.53462, 231.94533, 280.37552, 469.60095]          female
11         11   [93.62811, 232.48692, 278.80774, 470.71677]          female
12         12    [94.5706, 232.88577, 280.19693, 469.20734]          female
13         13    [94.18951, 226.97621, 281.2876, 468.45206]          female
14         14  [93.782196, 225.13283, 281.57428, 469.45212]          female
15         15   [92.72016, 222.57924, 281.52145, 471.10934]          female
16         16   [91.76486, 220.04295, 282.50293, 472.32422]          female
17         17     [91.180595, 219.383, 282.56488, 472.7332]          female
18         18   [91.45817, 224.86871, 280.40808, 471.70938]          female
19         19   [91.75995, 229.18222, 278.51724, 470.76422]          female
20         20   [90.86253, 228.00526, 277.29852, 469.97522]          female
21         21    [86.87827, 220.22151, 278.28793, 474.4017]          female
22         22    [86.17063, 220.2833, 277.47998, 473.55865]          female
23         23     [87.24197, 223.13232, 276.06287, 472.329]          female
24         24      [85.91275, 221.83832, 276.25464, 473.94]          female
25         25   [86.46627, 223.12836, 276.40482, 473.91782]          female
26         26   [87.90794, 222.48033, 277.04114, 472.63095]          female
27         27   [87.26338, 222.81485, 277.85394, 472.65347]          female
28         28   [89.96093, 222.24153, 278.90247, 471.04422]          female
29         29   [92.93111, 221.20155, 279.88617, 468.67712]          female
30         30     [95.86487, 222.3673, 280.08804, 468.6138]          female
31         31   [97.352905, 222.22885, 282.08548, 470.4421]          female
32         32     [98.23183, 219.5644, 286.48532, 472.6992]          female
33         33     [99.83777, 223.5303, 286.11328, 472.0794]          female
34         34    [98.918564, 224.1231, 287.2161, 471.09912]          female
35         35     [99.63552, 223.40047, 288.5786, 471.4875]          female
36         36   [102.69069, 223.99548, 288.7781, 471.87228]          female
37         37    [101.27347, 223.9684, 290.67612, 472.4894]          female
38         38  [100.11153, 213.72807, 292.49606, 472.19666]          female
39         39   [98.218315, 210.41318, 293.1625, 473.68002]          female
40         40    [98.33731, 211.89398, 292.41443, 472.6961]          female
41         41    [97.3301, 211.76442, 291.71097, 472.21356]          female
42         42    [96.33257, 211.30165, 291.4119, 472.43082]          female
43         43  [96.392715, 212.37268, 290.87326, 471.60538]          female
44         44    [96.0622, 212.98589, 289.65698, 470.89297]          female
45         45  [94.273346, 213.00847, 289.17032, 470.96674]          female
46         46    [94.76167, 213.07986, 289.17407, 470.6863]          female
47         47   [94.89173, 214.84763, 288.04193, 470.53317]          female
48         48    [95.12451, 217.9163, 285.73044, 470.59174]          female
49         49  [95.080414, 221.98688, 285.05048, 471.19278]          female
@query_time: 2.2007903130725026

Visualize Output#

import cv2
from matplotlib import pyplot as plt

def annotate_video(detections, input_video_path, output_video_path, max_frame_id=50):
    """Draw face boxes and gender labels on a video and save the result.

    The input video is read frame by frame. For each frame, the rows of
    ``detections`` whose ``tiktok.id`` equals the frame index are looked up;
    every ``Face.bbox`` is drawn as a rectangle with its ``gendercnn.label``
    written above it. Annotated frames are appended to the output video and
    every fifth annotated frame is also shown with matplotlib.

    Frames without any matching detection row are neither written to the
    output video nor displayed.

    Args:
        detections: pandas-style DataFrame with the columns ``tiktok.id``,
            ``Face.bbox`` (four values: x1, y1, x2, y2) and
            ``gendercnn.label``.
        input_video_path: Path of the video to read.
        output_video_path: Path of the annotated video to write.
        max_frame_id: Index of the last frame to process; processing stops
            after this frame. ``None`` processes the whole video.
            Defaults to 50.

    Returns:
        None.
    """
    color = (207, 248, 64)
    thickness = 4

    vcap = cv2.VideoCapture(input_video_path)
    width = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = vcap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')  # codec
    video = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    # try/finally guarantees both handles are released on every exit path,
    # including the early stop at max_frame_id and any exception while drawing.
    try:
        frame_id = 0
        # ret is falsy once no further frame can be read.
        ret, frame = vcap.read()

        while ret:
            on_this_frame = detections['tiktok.id'] == frame_id
            rows = detections.loc[on_this_frame, ['Face.bbox', 'gendercnn.label']]

            if rows.size:
                for bbox, label in rows.values:
                    x1, y1, x2, y2 = (int(coord) for coord in bbox)
                    # Face bounding box.
                    frame = cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)
                    # Predicted label, placed just above the box.
                    cv2.putText(frame, str(label), (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, thickness - 1)

                video.write(frame)
                # Show every fifth frame. OpenCV frames are BGR while
                # matplotlib expects RGB, so convert for display only.
                if frame_id % 5 == 0:
                    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    plt.show()

            if max_frame_id is not None and frame_id >= max_frame_id:
                break

            frame_id += 1
            ret, frame = vcap.read()
    finally:
        video.release()
        vcap.release()
# Uncomment the next line if ipywidgets is not installed.
#!pip install ipywidgets
# NOTE(review): this rebinds the Video name imported earlier from IPython.display,
# and Video is not used in the rest of this cell.
from ipywidgets import Video
input_path = 'short.mp4'
output_path = 'annotated_short.mp4'

# response still holds the result of the composed GenderCNN/FaceDetector query;
# its batch.frames attribute is the table of per-frame detections.
dataframe = response.batch.frames
# Writes the annotated video to output_path and shows sample frames inline.
annotate_video(dataframe, input_path, output_path)
../../_images/e78f4c48bdbb6fd686d6bbfa32520519a2df3bd7a5d4c63c35b53318c29b90d3.png ../../_images/586c7b103200559fa76f3718e0d83bc785be00e76048c39357d15b2e20a57d13.png ../../_images/9d9dafe7e354bdd2186bb2b0cc7995164a7a9f3ed0031d33e0165729d5fe2ac8.png ../../_images/1ec83780299af661b836e7f67f6bab530696d1d3b1148416ec2aa1b287d2e345.png ../../_images/cc6dfff5f0e79363bf23f88e718819411ae00543f129219cd6294113f1d07062.png ../../_images/ffdf1cb67a492284fe712059fe56b6696c3d4c378688f750526695b578533cdc.png ../../_images/b9cc94ffc7d110d2a0acbf7e0da6999c5c6a47cbdc8e6f387c4aecfd40556070.png ../../_images/ab60550a621337216d907a0561902ebae0d788cab79fee3e7eefe04d90b76830.png ../../_images/b396ee2dfdb7d6664a4262aad8089ce58bc77f7bdaffd579ed3f977d308212c0.png ../../_images/bc9a268728399be4b80c4313c184026eb6fd918a62d2459b2ce3db90d5b44213.png