Accelerating the MediaPipe models with Vitis-AI 3.5

albertabeef, 26 Aug 2024

MediaPipe Series - Part 3 - Accelerating the MediaPipe models with Vitis-AI 3.5

An exploration of accelerating the MediaPipe models with Vitis-AI 3.5.

Introduction

This project is part of a series on the subject of deploying the MediaPipe models to the edge on embedded platforms.

If you have not already read part 1 of this series, I urge you to start here:

  • [Hackster] Blazing Fast Models

In this project, I start by giving a recap of the challenges that can be expected when deploying the MediaPipe models, specifically with Vitis-AI 3.5.

Then I will address these challenges one by one, before deploying the models with Vitis-AI 3.5.

Finally, I will perform profiling to determine if our goal of acceleration was achieved.

Vitis-AI 3.5 Flow Overview

AMD’s Vitis-AI 3.5 workflow allows users to deploy models to AMD’s SoC embedded devices, such as Zynq UltraScale+ and Versal AI Edge.

[Image: Vitis AI 3.5 – 1000 Foot View (Camera: AMD)]

This workflow supports the following frameworks:

  • TensorFlow
  • TensorFlow 2
  • PyTorch

The deployment involves the following tasks:

  • Model Inspection (not shown in flow diagram)
  • AI Quantization
  • AI Optimization (optional)
  • AI Compilation

The Model Inspection task allows the user to identify unsupported layers, or sequences of layers, that are not supported by the compiler. This step is crucial when training our own custom model, since we can adapt the model architecture to use layers that are supported by the target compiler prior to training, thus saving hours (or days) in our deployment flow.
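
To make this step concrete, here is a minimal sketch of what a model inspection pass might look like with the Vitis-AI 3.5 PyTorch flow (pytorch_nndct). The model import, the weights path and the output directory are placeholders, not the exact code from my notebook:

# Hedged sketch of a Vitis-AI 3.5 model inspection step (pytorch_nndct).
# The model import, weights path and output directory are placeholders.
import torch
from pytorch_nndct.apis import Inspector

from blazepalm import BlazePalm                  # hypothetical import of the PyTorch model

inspector = Inspector("DPUCZDX8G_ISA1_B4096")    # target DPU architecture (see compile log below)

model = BlazePalm()
model.load_weights("blazepalm.pth")              # pre-trained weights (placeholder path)
model.eval()

dummy_input = torch.randn(1, 3, 256, 256)        # palm detector input: 256x256 RGB image
inspector.inspect(model, (dummy_input,),
                  device=torch.device("cpu"),
                  output_dir="inspect",
                  image_format="png")             # writes the inspection report under ./inspect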

The AI Quantization task analyzes the dynamic range of each layer in the model and converts the data types from floating point to fixed point. The DPU engine we are targeting supports 8-bit integer arithmetic. In order to perform this analysis and conversion, a subset of the training dataset is required. The size of the required calibration data is on the order of several hundred to a few thousand samples.
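
As an illustration, a post-training quantization (calibration) pass with the Vitis-AI PyTorch quantizer looks roughly like the sketch below. The calibration file name follows the convention used later in this post; the model import, paths and calibration array layout are assumptions:

# Hedged sketch of post-training quantization (calibration pass) with pytorch_nndct.
# Model import, paths and the calibration array layout are assumptions.
import numpy as np
import torch
from pytorch_nndct.apis import torch_quantizer

from blazepalm import BlazePalm                  # hypothetical import

model = BlazePalm()
model.load_weights("blazepalm.pth")
model.eval()

dummy_input = torch.randn(1, 3, 256, 256)
quantizer = torch_quantizer("calib", model, (dummy_input,), output_dir="quantize_result")
quant_model = quantizer.quant_model

# Run a few hundred representative samples through the quantized model.
calib_data = np.load("calib_palm_detection_256_dataset.npy")   # assumed shape (N, 256, 256, 3)
with torch.no_grad():
    for img in calib_data:
        x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)  # NHWC -> NCHW
        quant_model(x)

quantizer.export_quant_config()   # a second pass in "test" mode exports the deployable xmodel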

[Image: Vitis AI 3.5 — Quantization Workflow (Camera: AMD)]

As illustrated in the previous workflow graph, in certain cases (one of which I ran into was the MobileNet-V2 classification model), post-training quantization may not yield sufficient accuracy and requires quantization-aware training (QAT). Fortunately, this was not the case for the MediaPipe models.

The AI Optimization task prunes weights that have minimal impact on the model’s accuracy, thus reducing the compute requirements of the model. This task is strongly recommended for large models, such as ResNet, that were architected solely for accuracy at the expense of redundancy. For smaller compact models, such as MobileNet, that were architected to be resource efficient, the Optimizer will not provide much gain, and is thus not required. For this project, we will not be using the AI Optimizer.

The AI Compilation task converts the quantized model to micro-code that can be run on the DPU engine.

Challenges of deploying MediaPipe with Vitis-AI 3.5

The first challenge that I encountered, in part 1, was the reality that the performance of the MediaPipe models significantly degrades when run on embedded platforms, compared to modern computers. This is the reason I am attempting to accelerate the models with Vitis-AI.

The second challenge is the fact that Google does not provide the dataset that was used to train the MediaPipe models. Since quantization requires a subset of this training data, this presents us with the challenge of coming up with this data ourselves.

The third challenge encountered when deploying the MediaPipe models to Vitis-AI 3.5 is that TFLite is not supported. This forces us to consider the use of an alternate framework, and the additional problem of model conversion.

Converting the TFLite models to PyTorch

The choice of converting the models to PyTorch was motivated mainly by existing efforts in the open-source community.

[Image: TFLite to PyTorch conversion (Camera: AlbertaBeef)]

I was fortunate to find the following resources which convert the TFLite models to PyTorch:

  • [Vidur Satija] BlazePalm : vidursatija/BlazePalm
  • [Matthijs Hollemans] BlazeFace-PyTorch : hollance/BlazeFace-PyTorch
  • [Zak Murez] MediaPipePyTorch : zmurez/MediaPipePytorch

The three GitHub repositories seem to be inter-related, and build on top of each other.

The BlazePalm repository provides the converted models for the palm and hand pipeline only, as well as the conversion script as a Jupyter notebook.

Similarly, the BlazeFace-PyTorch repository provides the converted models for the face models, as well as the conversion script.

The MediaPipePyTorch repository provides the converted models for each of the hands, face, and pose models, but does not provide any of the conversion scripts.
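
The general idea behind those conversion notebooks is to read the constant tensors (weights and biases) out of the TFLite flatbuffer and transpose them into PyTorch's layout before loading them into an equivalent model definition. The sketch below illustrates the idea only; the tensor names and the mapping to specific PyTorch layers are placeholders:

# Hedged sketch of the TFLite-to-PyTorch weight transfer idea (illustration only;
# tensor names and the mapping to specific PyTorch layers are placeholders).
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="palm_detection.tflite")
interpreter.allocate_tensors()

# Index the tensor details by name, so individual weight tensors can be looked up.
details = {d["name"]: d for d in interpreter.get_tensor_details()}

def get_weight(name):
    return interpreter.get_tensor(details[name]["index"])

# Regular conv kernels: TFLite stores (out, H, W, in), PyTorch expects (out, in, H, W).
def conv_to_torch(w):
    return np.ascontiguousarray(np.transpose(w, (0, 3, 1, 2)))

# Depthwise conv kernels: TFLite stores (1, H, W, channels), PyTorch expects (channels, 1, H, W).
def depthwise_to_torch(w):
    return np.ascontiguousarray(np.transpose(w, (3, 0, 1, 2)))

# Each converted kernel is then copied into the matching layer of the PyTorch
# model definition; this hand-written mapping is the bulk of those notebooks.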

The disadvantage of reusing this work is that it only works with the 0.07 version of the models (the latest versions are 0.10). I tried to convert a 0.10 version of the models with the provided conversion scripts, but this did not work. More investigation and work would be needed to make this happen.

We already saw in part 1 that differences are to be expected between version 0.07 and 0.10 of the models:

[Image: hand landmarks — 0.07 versus 0.10 lite versus 0.10 full (Camera: AlbertaBeef)]

Nevertheless, I decided to go ahead with the PyTorch version of the 0.07 models for this project.

Creating a Calibration Dataset for Quantization

As described previously in the “Vitis-AI 3.5 Flow Overview” section, the quantization phase requires several hundreds to thousands of data samples, ideally a subset from the training data. Since we do not have access to the training dataset, we need to come up with this data ourselves.

We can generate the calibration dataset using a modified version of the blaze_app_python.py script, as follows:

[Image: gen_calib_hand_dataset.py (Camera: AlbertaBeef)]

For each input image that contains at least one hand, we want to generate the following (a preprocessing sketch follows the list):

  • palm detection input images : the full image, resized and padded to the model's input size
  • hand landmarks input images : a cropped image of each hand, resized to the model's input size
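
A minimal sketch of this preprocessing is shown below. The input directory, the normalization and the output file name are assumptions; the actual gen_calib_hand_dataset.py script also runs the palm detector to generate the hand crops:

# Hedged sketch of generating palm detection calibration samples (paths, normalization
# and output name are assumptions; hand-crop generation is omitted).
import glob
import cv2
import numpy as np

def resize_pad(img, size):
    # Resize keeping the aspect ratio, then pad to a size x size square.
    h, w, _ = img.shape
    scale = size / max(h, w)
    resized = cv2.resize(img, (int(w * scale), int(h * scale)))
    pad_h = size - resized.shape[0]
    pad_w = size - resized.shape[1]
    return cv2.copyMakeBorder(resized, 0, pad_h, 0, pad_w,
                              cv2.BORDER_CONSTANT, value=(0, 0, 0))

palm_samples = []
for path in glob.glob("calib_images/*.jpg"):
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    # Palm detection input: full frame, resized and padded to the model's 256x256 input.
    sample = resize_pad(img, 256).astype(np.float32) / 255.0   # scaling must match the model's input range
    palm_samples.append(sample)
    # Hand landmark inputs would be the detector's hand crops, resized to 256x256 (omitted).

np.save("calib_palm_detection_256_dataset.npy", np.array(palm_samples))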

Two possible sources for input images are the following:

  • Kaggle : many datasets exist, and may be reused
  • Pixabay : contains several interesting videos, from which images can be extracted

For specific examples for these two use cases, refer to my full write-up on Hackster:

  • [Hackster] Accelerating the MediaPipe models with Vitis-AI 3.5

If you know of other sources that can be used for the calibration dataset, please share your insight in the comments.

A Deeper Dive into the Palm Detection model

Before we tackle the deployment flow with Vitis-AI, it is worth taking a deeper dive into the models we will be working with. For this purpose, I will highlight the architecture of the palm detection model (0.07 version).

At a very high level, there are three convolutional neural network backbones that are used to extract features at three different scales. The outputs of these three backbones are combined together to feed two different heads : classifiers (containing score) and regressors (containing bounding box and additional keypoints).

[Image: Palm Detection (0.07) — Block Diagram 1 (Camera: AlbertaBeef)]

The input to this model is a 256x256 RGB image, while the outputs of the model are 2944 candidate results, each containing:

  • score
  • bounding box (normalized to pre-determined anchor boxes)
  • keypoints (7 keypoints for palm detector)
[Image: Palm Detection (0.07) — Block Diagram 2 (Camera: AlbertaBeef)]
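
For reference, decoding these candidates against the pre-determined SSD-style anchors looks roughly like the sketch below. This mirrors the approach used in the PyTorch ports listed earlier; the scale constant and threshold are simplified placeholders:

# Simplified sketch of decoding the 2944 candidates against the anchors
# (mirrors the PyTorch ports; scale and threshold values are simplified).
import numpy as np

def decode_palm_detections(raw_scores, raw_boxes, anchors, score_thresh=0.7, scale=256.0):
    # raw_scores: (2944, 1), raw_boxes: (2944, 18), anchors: (2944, 4) as (x, y, w, h)
    scores = 1.0 / (1.0 + np.exp(-raw_scores[:, 0]))              # sigmoid on the classifier output

    # Bounding box center and size, normalized relative to the anchor boxes.
    x_center = raw_boxes[:, 0] / scale * anchors[:, 2] + anchors[:, 0]
    y_center = raw_boxes[:, 1] / scale * anchors[:, 3] + anchors[:, 1]
    w = raw_boxes[:, 2] / scale * anchors[:, 2]
    h = raw_boxes[:, 3] / scale * anchors[:, 3]

    boxes = np.stack([x_center - w / 2, y_center - h / 2,
                      x_center + w / 2, y_center + h / 2], axis=-1)

    # The remaining 14 values per candidate are the 7 (x, y) palm keypoints,
    # decoded the same way; non-max suppression follows (omitted here).
    keep = scores > score_thresh
    return boxes[keep], scores[keep]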

The following block diagram illustrates details of the layers for the model. I have grouped together repeating patterns as “BLAZE BLOCK A”, “BLAZE BLOCK B”, and “BLAZE BLOCK C”, showing the details only for the first occurrence.

[Image: Palm Detection (0.07) — Block Diagram 3 (Camera: AlbertaBeef)]

The following block diagram is the same as the previous one, but this time showing details of the “BLAZE BLOCK B” patterns, which will require further discussion during the deployment phase.

[Image: Palm Detection (0.07) — Block Diagram 4 (Camera: AlbertaBeef)]

Model Inspection

As we saw previously in the “Vitis-AI 3.5 Flow Overview” section, the deployment phase starts with an inspection of the model in order to determine if its layers are supported by the Vitis-AI compiler.

The following jupyter notebook contains the results for my initial model inspection:

  • https://github.com/AlbertaBeef/blaze_tutorial/blob/2023.1/vitis-ai/blazepalm_exploration11.ipynb

This notebook, executed within the Vitis-AI 3.5 docker container for PyTorch, produced the following results, which indicate that three specific instances of the PAD layer are not supported:

**************************************************
* VITIS_AI Compilation - Xilinx Inc.
**************************************************
[UNILOG][INFO] Compile mode: dpu
[UNILOG][INFO] Debug mode: null
[UNILOG][INFO] Target architecture: DPUCZDX8G_ISA1_B4096
[UNILOG][INFO] Graph name: BlazePalm, with op num: 730
[UNILOG][INFO] Begin to compile...
[UNILOG][WARNING] xir::Op{name = BlazePalm__BlazePalm_Sequential_backbone1__BlazeBlock_9__ret_51, type = pad-fix} has been assigned to CPU: [DPU does not support CONSTANT mode. (only support SYMMETRIC mode for all devices and CONSTANT mode for some DPUv4e devices)].
[UNILOG][WARNING] xir::Op{name = BlazePalm__BlazePalm_Sequential_backbone1__BlazeBlock_17__ret_103, type = pad-fix} has been assigned to CPU: [DPU does not support CONSTANT mode. (only support SYMMETRIC mode for all devices and CONSTANT mode for some DPUv4e devices)].
[UNILOG][WARNING] xir::Op{name = BlazePalm__BlazePalm_Sequential_backbone2__BlazeBlock_0__ret_155, type = pad-fix} has been assigned to CPU: [DPU does not support CONSTANT mode. (only support SYMMETRIC mode for all devices and CONSTANT mode for some DPUv4e devices)].
[UNILOG][INFO] Total device subgraph number 10, DPU subgraph number 4
[UNILOG][INFO] Compile done.
...

If we trace this back to our block diagram, these are PAD layers contained in our “BLAZE BLOCK B” patterns, as shown in red below:

[Image: Palm Detection (0.07) — Unsupported PAD Layers (Camera: AlbertaBeef)]

The model contains several other PAD layers which have supported configurations, with padding in the first (top/bottom) and second (left/right) dimensions. The PAD layers which are not supported are those which are performing the padding in the third (front/back) dimension, as shown below:

[Image: Unsupported PAD Layer (Camera: AlbertaBeef)]

These three unsupported layers result in the generation of a model implemented with 7 sub-graphs : 4 running on the DPU, and the 3 PAD layers running on the CPU. This is not ideal, since the transfer of results between the CPU and DPU will affect performance.

[Image: Palm Detection (0.07) — Multiple Sub-Graphs (Camera: AlbertaBeef)]
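
On the target, this partitioning can be confirmed by walking the compiled .xmodel with the xir Python bindings, as in the short sketch below (the model path is a placeholder):

# List the sub-graphs of a compiled xmodel and the device each one runs on.
import xir

graph = xir.Graph.deserialize("models/BlazePalm/B512/BlazePalm.xmodel")
root = graph.get_root_subgraph()

for subgraph in root.toposort_child_subgraph():
    device = subgraph.get_attr("device") if subgraph.has_attr("device") else "?"
    print("{:60s} -> {}".format(subgraph.get_name(), device))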

Before continuing with the deployment, I attempted to change the model to be fully supported by the Vitis-AI compiler. The main challenge in doing this is that I cannot retrain the model (which is the recommended way forward when unsupported layers need to be handled). The modifications I make must preserve the weights of the trained model.

If you are interested in my alternate implementations for the unsupported PAD layer, refer to my full write-up on Hackster:

  • [Hackster] Accelerating the MediaPipe models with Vitis-AI 3.5

Ultimately, I abandoned my attempts to avoid the unsupported PAD layer, and moved on with the original PyTorch model.

Fortunately, this is one of the options supported by the Vitis-AI workflow, as shown in the following flow diagram:

[Image: Vitis AI 3.5 — Model Inspection Decision Tree (Camera: AMD)]

We will use the Custom OP flow, which allows certain layers to be executed on the CPU.

Model Deployment

The model deployed during the exploration in the previous section did not yield good accuracy, because we used random values as calibration data for the quantization phase. This was fine for exploring the model architecture, but not for our final models.

Now that we are satisfied with our model architectures, we can deploy them with scripting, using the calibration data we have prepared.

I have prepared a script for this purpose:

  • vitisai_pytorch_flow.py

This script takes three (3) arguments when invoked:

  • name : BlazePalm, BlazeHandLandmark, etc …
  • resolution : input size (e.g. 256)
  • process : inspect, quantize, all

The name argument indicates which model we are deploying, such as BlazePalm for the palm detector or BlazeHandLandmark for the hand landmark models. The resolution indicates the input size to the model.

These two arguments will determine which calibration dataset to use for the quantization. For example:

  • name=BlazePalm, size=256 => calib_palm_detection_256_dataset.npy
  • name=BlazeHandLandmark, size=256 => calib_hand_landmark_256_dataset.npy

The process argument indicates which task to run. Specify “quantize” or “all” to quantize the model. The model compilation must be invoked separately with the vai_c_xir command after quantization.

I have provided a second script which calls the vitisai_pytorch_flow.py script to quantize the models to be deployed, and then invokes the compilation command for each DPU architecture being targeted:

  • deploy_models.sh

You will want to modify the following lists before execution:

  • model_list : specify which model(s) you want to deploy
  • dpu_arch_list : specify which DPU architecture(s) you want to target

Below is a modified version of the script that will deploy the palm detection and hand landmarks models for the B512 and B128 DPU architectures.

# PyTorch models : (model name, input resolution)
model_palm_detector_v0_07=("BlazePalm", 256)
model_hand_landmark_v0_07=("BlazeHandLandmark", 256)
model_face_detector_v0_07_front=("BlazeFace", 128)
model_face_detector_v0_07_back=("BlazeFaceBack", 256)
model_face_landmark_v0_07=("BlazeFaceLandmark", 192)
model_pose_detector_v0_07=("BlazePose", 128)
model_pose_landmark_v0_07=("BlazePoseLandmark", 256)

# Models to deploy
model_list=(
   model_palm_detector_v0_07[@]
   model_hand_landmark_v0_07[@]
)

# Versal AI Edge DPU architectures
dpu_c20b14=("C20B14","./arch/C20B14/arch-c20b14.json")
dpu_c20b1=("C20B1","./arch/C20B14/arch-c20b1.json")

# Zynq-UltraScale+ DPU architectures
dpu_b4096=("B4096","./arch/B4096/arch-zcu104.json")
dpu_b3136=("B3136","./arch/B3136/arch-kv260.json")
dpu_b2304=("B2304","./arch/B2304/arch-b2304-lr.json")
dpu_b1152=("B1152","./arch/B1152/arch-b1152-hr.json")
dpu_b512=("B512","./arch/B512/arch-b512-lr.json")
dpu_b128=("B128","./arch/B128/arch-b128-lr.json")

# DPU architectures to target
dpu_arch_list=(
   dpu_b512[@]
   dpu_b128[@]
)

model_count=${#model_list[@]}
dpu_arch_count=${#dpu_arch_list[@]}

# Quantize each model, then compile it for each target DPU architecture
for ((i=0; i<$model_count; i++))
do
   # Expand the "name, resolution" pair for this model
   model=${!model_list[i]}
   model_array=(${model//,/ })
   model_name=${model_array[0]}
   input_resolution=${model_array[1]}

   echo python3 vitisai_pytorch_flow.py --name ${model_name} --resolution ${input_resolution} --process all
   python3 vitisai_pytorch_flow.py --name ${model_name} --resolution ${input_resolution} --process all | tee deploy_${model_name}_quantize.log

   # The quantizer writes the back-camera face detector under the BlazeFace name, so rename its output
   if [ ${model_name} == "BlazeFaceBack" ]
   then
      mv quantize_result/BlazeFace_int.pt     quantize_result/BlazeFaceBack_int.pt
      mv quantize_result/BlazeFace_int.onnx   quantize_result/BlazeFaceBack_int.onnx
      mv quantize_result/BlazeFace_int.xmodel quantize_result/BlazeFaceBack_int.xmodel
   fi

   for ((j=0; j<$dpu_arch_count; j++))
   do
      # Expand the "arch, arch.json" pair for this DPU target
      dpu=${!dpu_arch_list[j]}
      dpu_array=(${dpu//,/ })
      dpu_arch=${dpu_array[0]}
      dpu_json=${dpu_array[1]}

      echo vai_c_xir -x ./quantize_result/${model_name}_int.xmodel -a ${dpu_json} -o ./models/${model_name}/${dpu_arch} -n ${model_name}
      vai_c_xir -x ./quantize_result/${model_name}_int.xmodel -a ${dpu_json} -o ./models/${model_name}/${dpu_arch} -n ${model_name} | tee deploy_${model_name}_compile.log
   done
done

This script must be executed in the Vitis-AI 3.5 docker container for PyTorch. Launch the docker from the “blaze_tutorial” directory as follows:

$ ./vitis-ai/docker_run.sh xilinx/vitis-ai-pytorch-cpu

Inside the Vitis-AI 3.5 docker, launch the deploy_models.sh script as follows:

$ cd vitis-ai
$ source ./deploy_models.sh

When complete, the compiled models will be located in the following sub-directories:

  • models/BlazePalm/B128/BlazePalm.xmodel
  • models/BlazePalm/B512/BlazePalm.xmodel
  • models/BlazeHandLandmark/B128/BlazeHandLandmark.xmodel
  • models/BlazeHandLandmark/B512/BlazeHandLandmark.xmodel

For convenience, the compiled Vitis-AI models are available in the following archive:

  • Vitis-AI models : blaze_vitisai_compiled_models.zip

Model Execution

In order to execute the models, we need to consider that each one of our models has multiple sub-graphs, as shown below:

[Image: Palm Detection (0.07) — Multiple Sub-Graphs on Zynq UltraScale+ (Camera: AlbertaBeef)]

The CPU sub-graphs (implementing the unsupported PAD layer) are running on the Zynq UltraScale+ ARM processor(s) in the PS. The DPU sub-graphs are running on the DPU engine that is implemented in the PL. The transfers between the CPU and DPU sub-graphs occur on high-speed AXI interconnect.

Vitis-AI provides two methods for accomplishing the scheduling of these multiple sub-graphs:

  • manual scheduling : user invokes each of the DPU sub-graphs individually, and explicitly handles the transfers to/from each DPU sub-graph
  • graph runner : user invokes the model as a single entity, and the graph runner handles the sub-graph details, including transfers (a minimal sketch follows)
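
For orientation, the graph-runner style boils down to something like the sketch below, based on the standard Vitis-AI Python examples (pre- and post-processing are omitted, and the model path is a placeholder):

# Minimal graph-runner sketch (pre/post-processing omitted; model path is a placeholder).
import numpy as np
import vitis_ai_library
import xir

graph = xir.Graph.deserialize("models/BlazePalm/B512/BlazePalm.xmodel")
runner = vitis_ai_library.GraphRunner.create_graph_runner(graph)

input_buffers = runner.get_inputs()
output_buffers = runner.get_outputs()

# Copy the pre-processed 256x256 frame into the input tensor buffer.
input_data = np.asarray(input_buffers[0])
input_data[...] = 0   # placeholder for the actual pre-processed image data

# A single call schedules all DPU and CPU sub-graphs, including the transfers between them.
job_id = runner.execute_async(input_buffers, output_buffers)
runner.wait(job_id)

raw_outputs = [np.asarray(buffer) for buffer in output_buffers]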

I went through the exercise of implementing both versions, and now appreciate the convenience provided by the graph runner.

For reference, here are links to initial versions of model execution using manual scheduling versus using the graph runner:

  • blazepalm11a_detect_live.py (manual scheduling)
  • blazepalm11b_detect_live.py (graph runner)

The take-away is that the code for manual scheduling is much more complex.

To support the PyTorch and Vitis-AI models used in this project, the “blaze_app_python” application was augmented with the following inference targets:

[Image: blaze_app_python — support for PyTorch and Vitis-AI (Camera: AlbertaBeef)]

My final inference code for the Vitis-AI models can be found in the “blaze_app_python” repository, under the blaze_vitisai sub-directory:

  • blaze_app_python/blaze_vitisai/blazedetector.py
  • blaze_app_python/blaze_vitisai/blazelandmark.py

Note that the PyTorch inference can be run on a computer, while the Vitis-AI inference must be run on the Zynq UltraScale+ embedded platform (e.g. the ZUBoard).

Implementing Custom OPs

Since we are using the Custom OP flow, and know that certain layers will be executed on the CPU, we need to provide the code for these unsupported layers.

Our models have the following unsupported layers:

  • Palm Detector : pad-fix
  • Hand Landmarks : pad-fix, eltwise-fix

Fortunately for us, these layers already exist among the graph runner’s supported CPU tasks. They can be found here on our embedded platform:

  • pad-fix : /usr/lib/libvart_op_imp_pad-fix.so
  • eltwise-fix : /usr/lib/libvart_op_imp_eltwise-fix.so

If we try to run our model with the graph runner “as is”, however, we get the following run-time errors from these existing CPU tasks:

pad-fix

WARNING: Logging before InitGoogleLogging() is written to STDERR
F0326 20:22:02.581987 19129 pad-fix.cpp:46]
Check failed: input_shape[3] == output_shape[3] (32 vs. 64)
*** Check failure stack trace: ***
Aborted

eltwise-fix

WARNING: Logging before InitGoogleLogging() is written to STDERR
F0326 20:19:26.838560 19071 eltwise-fix.cpp:35]
unsupported eltwise type: DIV
*** Check failure stack trace: ***
Aborted

In the case of the pad-fix layer, the stock implementation does not support padding in the third dimension.
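
In numpy terms, what the custom op has to compute is simply a constant pad along the channel (last) dimension. The shapes below follow the error message above (32 to 64 channels), with the NHWC layout assumed:

# Constant padding along the channel dimension (NHWC layout assumed; shapes for illustration).
import numpy as np

x = np.zeros((1, 32, 32, 32), dtype=np.int8)      # example input tensor with 32 channels
y = np.pad(x, ((0, 0), (0, 0), (0, 0), (0, 32)),  # pad 32 extra channels at the back
           mode="constant", constant_values=0)
print(y.shape)                                    # (1, 32, 32, 64)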

I have modified the following original code:

  • https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_library/cpu_task/ops/pad-fix/pad-fix.cpp

The modified version, which supports padding in the third dimension, is available in my repo:

  • https://github.com/AlbertaBeef/blaze_app_python/blob/main/blaze_vitisai/cpu_tasks/op_pad-fix/pad-fix_custom.cpp

For the case of the eltwise-fix layer, it supports ADD and MULT operations, but not DIV.

I have modified the following original code:

  • https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_library/cpu_task/ops/eltwise-fix/eltwise-fix.cpp

The modified version, which supports the DIV operation, is available in my repo:

  • https://github.com/AlbertaBeef/blaze_app_python/blob/main/blaze_vitisai/cpu_tasks/op_eltwise-fix/eltwise-fix_custom.cpp

The face landmarks model has an additional unsupported layer called “prelu”. When attempting to run this model, the following run-time error will occur:

WARNING: Logging before InitGoogleLogging() is written to STDERR
F0821 17:15:49.156904 1126 op_imp.cpp:110] [UNILOG][FATAL][VAILIB_CPU_RUNNER_OPEN_LIB_ERROR][dlopen can not open lib!] lib=libvart_op_imp_prelu.so;error=libvart_op_imp_prelu.so: cannot open shared object file: No such file or directory;op=xir::Op{name =
BlazeFaceLandmark__BlazeFaceLandmark_Sequential_backbone1__PReLU_1__ret_9, type = prelu}
*** Check failure stack trace: ***
Aborted

Unfortunately, there is no existing code for this CPU task, so a completely new library named libvart_op_imp_prelu.so would have to be created. I did not implement this layer; let me know if you are interested in it in order to use the accelerated face pipeline.

Launching the python application on ZUBoard

Using the blaze_app_python demo application, we can launch the 0.07 version of the model, compiled for Vitis-AI, as shown below:

[Video: python3 blaze_detect_live.py --pipeline=vai_hand_v0_07 (Video camera: AlbertaBeef)]

The previous video runs in real time (it has not been sped up). It shows a frame rate of approximately 12 fps when no hands are detected (one model running: palm detection), and approximately 8 fps when one hand has been detected (two models running: palm detection and hand landmarks).

It is worth noting that this is running with a single-threaded python script. There is an opportunity for increased performance with a multi-threaded implementation. While the graph runner is waiting for transfers from one model’s sub-graphs, another (or several other) model(s) could be launched in parallel …

There is also an opportunity to accelerate the rest of the pipeline with C++ code …

Benchmarking the models on ZUBoard

Using the blaze_app_python demo application, we are able to profile the original TFLite model against the accelerated Vitis-AI implementation on ZUBoard for the DPU B512 architecture:

[Image: Palm Detection + Hand Landmarks benchmarks — TFLite versus Vitis-AI (B512) (Camera: AlbertaBeef)]

We can see that the Vitis-AI models provide the following performance increase:

  • palm detection (model only) : 14x faster
  • hand landmark (model only) : 10x faster
  • palm detection + hand landmark (full pipeline) : 8x faster

Again, it is worth noting that these benchmarks have been taken with a single-threaded python script. There is additional opportunity for acceleration with a multi-threaded implementation. While the graph runner is waiting for transfers from one model’s sub-graphs, another (or several other) model(s) could be launched in parallel …

There is also an opportunity to accelerate the rest of the pipeline with C++ code …

Benchmarking the models for various Vitis-AI 3.5 platforms

In order to get a better feel for the acceleration achieved with Vitis-AI 3.5, I decided to perform similar profiling for the following platforms:

  • ZUBoard : dual-Cortex-A53 ARM processor / DPU (B512 and B128)
  • ZCU104 : quad-Cortex-A53 ARM processors / DPU (dual-B4096)
  • VEK280 : dual-Cortex-A72 ARM processors / DPU (C20B1 and C20B14)

In order to determine the acceleration achieved on each platform, the reference TFLite models needed to be benchmarked as well. Although we will only be accelerating version 0.07 of the palm detection + hand landmarks pipeline, I have profiled both the 0.07 and 0.10 versions of the models for reference:

[Image: Palm Detection + Hand Landmarks benchmarks (0.07 & 0.10) — TFLite Reference (Camera: AlbertaBeef)]

Next, I profiled the 0.07 versions of the models deployed with Vitis-AI, and compared with the reference TFLite models:

[Image: Palm Detection + Hand Landmarks benchmarks (0.07) — TFLite versus Vitis-AI (Camera: AlbertaBeef)]

[Image: Palm Detection + Hand Landmarks benchmarks (0.07) — Vitis-AI execution times (Camera: AlbertaBeef)]

If we analyze these results per platform, we can observe the following acceleration:

[Image: Palm Detection + Hand Landmarks (0.07) — Vitis-AI Acceleration for ZUBoard (Camera: AlbertaBeef)]

[Image: Palm Detection + Hand Landmarks (0.07) — Vitis-AI Acceleration for ZCU104 (Camera: AlbertaBeef)]

[Image: Palm Detection + Hand Landmarks (0.07) — Vitis-AI Acceleration for VEK280 (Camera: AlbertaBeef)]

If we plot the acceleration ratio of the execution times for the Vitis-AI models with respect to the TFLite models for each platform, we get the following results:

[Image: Palm Detection + Hand Landmarks benchmarks (0.07) — Vitis-AI Acceleration ratios (Camera: AlbertaBeef)]

The uncontested winner is the VEK280, even with the smallest version of its DPU engine (C20B1): both its TFLite and Vitis-AI models have the fastest performance.

However, if we look at the acceleration ratio, the greatest acceleration is observed on the ZCU104 platform: specifically, 26X faster for the palm detection model.

This is not the performance I was targeting, but a significant acceleration nonetheless!

Going Further

For detailed instructions on deploying the models, and installing the python demo application, refer to my full write-up on Hackster:

  • [Hackster] Accelerating the MediaPipe models with Vitis-AI 3.5

I hope this project will inspire you to implement your own custom application.

Acknowledgements

I want to thank Google (https://www.hackster.io/google) for making the following available publicly:

  • MediaPipe

A big thanks to AMD for their Vitis-AI 3.5 framework:

  • [AMD] Vitis-AI 3.5 Documentation :
    https://xilinx.github.io/Vitis-AI/3.5/html/index.html

I want to thank the following developers for making their TFLite to PyTorch conversion work available publicly:

  • [Vidur Satija] BlazePalm : vidursatija/BlazePalm
  • [Matthijs Hollemans] BlazeFace-PyTorch : hollance/BlazeFace-PyTorch
  • [Zak Murez] MediaPipePyTorch : zmurez/MediaPipePytorch

References

  • [Google] MediaPipe Solutions Guide : https://ai.google.dev/edge/mediapipe/solutions/guide
  • [Vidur Satija] BlazePalm : vidursatija/BlazePalm
  • [Matthijs Hollemans] BlazeFace-PyTorch : hollance/BlazeFace-PyTorch
  • [Zak Murez] MediaPipePyTorch : zmurez/MediaPipePytorch
  • [AMD] Vitis-AI 3.5 Documentation :
    https://xilinx.github.io/Vitis-AI/3.5/html/index.html
  • [AlbertaBeef] blaze_app_python : AlbertaBeef/blaze_app_python
  • [AlbertaBeef] blaze_tutorial : AlbertaBeef/blaze_tutorial
  • [Hackster] Accelerating the MediaPipe models with Vitis-AI 3.5

Comments

flyingbean (10 months ago): Thanks for the detailed information and rich references about Vitis-AI 3.5 applications. I am planning to have a journey on this path soon. This blog is very informative.

DAB (10 months ago): Very interesting post.