Optimize inference using `torch.compile()`
This guide benchmarks the inference speed-ups that `torch.compile()` brings to computer vision models in 🌍 Transformers.
Benefits of torch.compile
Depending on the model and the GPU, `torch.compile()` yields up to 30% speed-up during inference. To use `torch.compile()`, install `torch` version 2.0 or later.
Compiling a model takes time, so it pays off when you compile the model once and reuse it, rather than compiling every time you run inference. To compile any computer vision model of your choice, call `torch.compile()` on the model.
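The call itself is a one-liner. The following is a minimal sketch, not the guide's original snippet: the small `nn.Sequential` network is a hypothetical stand-in so the example runs without downloading a checkpoint. A Transformers vision model returned by `from_pretrained` is also an `nn.Module` and would be wrapped in the same way.

```python
import torch
from torch import nn

# Hypothetical stand-in for a vision model (not one of the benchmarked models).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

# Wrap the model once, right after loading it. torch.compile() returns
# immediately; the actual compilation is deferred to the first forward pass,
# which is why that first call is noticeably slower than later ones.
compiled_model = torch.compile(model)

# Reuse `compiled_model` for every subsequent inference call instead of
# calling torch.compile() again.
```

Keeping a single compiled wrapper around for the lifetime of the process is what amortizes the compilation cost mentioned above.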
`compile()` comes with multiple modes for compiling, which essentially differ in compilation time and inference overhead. `max-autotune` takes longer to compile than `reduce-overhead` but results in faster inference. The default mode is the fastest to compile but is less efficient than `reduce-overhead` at inference time. In this guide, we used the default mode. You can learn more about the modes in the PyTorch 2.0 documentation.
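The mode is selected with the `mode` argument of `torch.compile()`. A short sketch, assuming a hypothetical stand-in model (`nn.Linear`) in place of a real vision model:

```python
import torch
from torch import nn

# Hypothetical stand-in model; any nn.Module is handled the same way.
model = nn.Linear(16, 4).eval()

# One compiled wrapper per mode named above. Compilation itself only
# happens on the first forward pass through each wrapper.
modes = ["default", "reduce-overhead", "max-autotune"]
compiled_by_mode = {mode: torch.compile(model, mode=mode) for mode in modes}
```

In practice you would pick one mode and keep only that wrapper; building all three here is only meant to show the accepted values side by side.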
We benchmarked `torch.compile` with different computer vision models, tasks, types of hardware, and batch sizes on `torch` version 2.0.1.
Benchmarking code
Below you can find the benchmarking code for each task. We warm up the GPU before inference and take the mean time of 300 inferences, using the same image each time.
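The measurement procedure described above (warm up, then average over 300 runs) can be sketched as a small framework-agnostic helper. The helper name, the warm-up count of 10, and the `sync` hook are assumptions made for illustration; the source only states the 300 timed runs.

```python
import time


def mean_latency_ms(fn, n_warmup=10, n_runs=300, sync=None):
    """Return the mean wall-clock latency of `fn()` in milliseconds.

    `sync` is an optional zero-argument callable invoked before each clock
    read; on a GPU you would pass `torch.cuda.synchronize` so that queued
    kernels have finished before the time is taken.
    """
    # Warm-up runs: these absorb one-off costs such as compilation.
    for _ in range(n_warmup):
        fn()
    if sync is not None:
        sync()

    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return elapsed * 1000.0 / n_runs


# Example with a pure-Python workload standing in for a model call.
latency = mean_latency_ms(lambda: sum(range(10_000)))
print(f"mean latency: {latency:.3f} ms")
```

For a model on a GPU, `fn` would be a closure that runs one forward pass on the same preprocessed image under `torch.no_grad()`.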
Image Classification with ViT
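As an illustration of the image classification case, here is a hedged sketch rather than the guide's original code. It assumes a randomly initialised ViT built from `ViTConfig` so that it runs offline; the checkpoints used in the benchmark are not reproduced here. The helper names `build_vit` and `classify` are hypothetical.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification


def build_vit(device="cpu", compile_model=True, **config_kwargs):
    """Build a randomly initialised ViT classifier, optionally compiled."""
    # Random weights are an assumption made to keep the sketch offline.
    config = ViTConfig(**config_kwargs)
    model = ViTForImageClassification(config).to(device).eval()
    return torch.compile(model) if compile_model else model


@torch.no_grad()
def classify(model, pixel_values):
    """Run one forward pass and return the classification logits."""
    return model(pixel_values=pixel_values).logits
```

On a GPU, calling `build_vit(device="cuda")` once and then running `classify` repeatedly on the same `(batch_size, 3, 224, 224)` tensor mirrors the setup described above, with the first call paying the compilation cost.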
Object Detection with DETR
Image Segmentation with Segformer
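A comparable sketch for semantic segmentation, again under stated assumptions: the model is a Segformer with the library's default configuration and random weights, and the input is a dummy tensor rather than a preprocessed image. This is not the guide's original snippet.

```python
import torch
from transformers import SegformerConfig, SegformerForSemanticSegmentation

# Default configuration with random weights (assumption: avoids a download).
model = SegformerForSemanticSegmentation(SegformerConfig()).eval()
compiled_model = torch.compile(model)  # compiled lazily on the first call

# Dummy batch standing in for a preprocessed image.
pixel_values = torch.rand(1, 3, 64, 64)

with torch.no_grad():
    # Eager forward pass; swap in `compiled_model` to time the compiled variant.
    logits = model(pixel_values=pixel_values).logits

print(tuple(logits.shape))  # (batch, num_labels, height / 4, width / 4)
```

Timing `model` and `compiled_model` on the same `pixel_values` gives the two columns reported in the tables for this task.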
Below you can find the list of the models we benchmarked.
Image Classification
Image Segmentation
Object Detection
Below you can find visualizations of inference durations with and without `torch.compile()`, and the percentage improvements for each model on different hardware and batch sizes.
Below you can find inference durations in milliseconds for each model with and without `compile()`. Note that OwlViT runs out of memory (OOM) at larger batch sizes.
A100 (batch size: 1)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 9.325 | 7.584 |
| Image Segmentation/Segformer | 11.759 | 10.500 |
| Object Detection/OwlViT | 24.978 | 18.420 |
| Image Classification/BeiT | 11.282 | 8.448 |
| Object Detection/DETR | 34.619 | 19.040 |
| Image Classification/ConvNeXT | 10.410 | 10.208 |
| Image Classification/ResNet | 6.531 | 4.124 |
| Image Segmentation/Mask2former | 60.188 | 49.117 |
| Image Segmentation/Maskformer | 75.764 | 59.487 |
| Image Segmentation/MobileNet | 8.583 | 3.974 |
| Object Detection/Resnet-101 | 36.276 | 18.197 |
| Object Detection/Conditional-DETR | 31.219 | 17.993 |

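The percentage improvement for a row follows directly from its two latencies. A small sketch using three rows copied from the A100 (batch size: 1) results; the function name `improvement_pct` is introduced here for illustration only.

```python
def improvement_pct(no_compile_ms, compile_ms):
    """Percentage reduction in latency achieved by compiling."""
    return (no_compile_ms - compile_ms) / no_compile_ms * 100.0


# (no compile, compile) latencies in ms from the A100 batch-size-1 results.
a100_bs1 = {
    "Image Classification/ViT": (9.325, 7.584),
    "Object Detection/DETR": (34.619, 19.040),
    "Image Classification/ConvNeXT": (10.410, 10.208),
}

for name, (before, after) in a100_bs1.items():
    print(f"{name}: {improvement_pct(before, after):.1f}% lower latency")
```

The same function applies to every table in this guide, which makes it easy to see how the gain varies with model, hardware, and batch size.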
A100 (batch size: 4)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 14.832 | 14.499 |
| Image Segmentation/Segformer | 18.838 | 16.476 |
| Image Classification/BeiT | 13.205 | 13.048 |
| Object Detection/DETR | 48.657 | 32.418 |
| Image Classification/ConvNeXT | 22.940 | 21.631 |
| Image Classification/ResNet | 6.657 | 4.268 |
| Image Segmentation/Mask2former | 74.277 | 61.781 |
| Image Segmentation/Maskformer | 180.700 | 159.116 |
| Image Segmentation/MobileNet | 14.174 | 8.515 |
| Object Detection/Resnet-101 | 68.101 | 44.998 |
| Object Detection/Conditional-DETR | 56.470 | 35.552 |

A100 (batch size: 16)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 40.944 | 40.010 |
| Image Segmentation/Segformer | 37.005 | 31.144 |
| Image Classification/BeiT | 41.854 | 41.048 |
| Object Detection/DETR | 164.382 | 161.902 |
| Image Classification/ConvNeXT | 82.258 | 75.561 |
| Image Classification/ResNet | 7.018 | 5.024 |
| Image Segmentation/Mask2former | 178.945 | 154.814 |
| Image Segmentation/Maskformer | 638.570 | 579.826 |
| Image Segmentation/MobileNet | 51.693 | 30.310 |
| Object Detection/Resnet-101 | 232.887 | 155.021 |
| Object Detection/Conditional-DETR | 180.491 | 124.032 |

V100 (batch size: 1)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 10.495 | 6.00 |
| Image Segmentation/Segformer | 13.321 | 5.862 |
| Object Detection/OwlViT | 25.769 | 22.395 |
| Image Classification/BeiT | 11.347 | 7.234 |
| Object Detection/DETR | 33.951 | 19.388 |
| Image Classification/ConvNeXT | 11.623 | 10.412 |
| Image Classification/ResNet | 6.484 | 3.820 |
| Image Segmentation/Mask2former | 64.640 | 49.873 |
| Image Segmentation/Maskformer | 95.532 | 72.207 |
| Image Segmentation/MobileNet | 9.217 | 4.753 |
| Object Detection/Resnet-101 | 52.818 | 28.367 |
| Object Detection/Conditional-DETR | 39.512 | 20.816 |

V100 (batch size: 4)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 15.181 | 14.501 |
| Image Segmentation/Segformer | 16.787 | 16.188 |
| Image Classification/BeiT | 15.171 | 14.753 |
| Object Detection/DETR | 88.529 | 64.195 |
| Image Classification/ConvNeXT | 29.574 | 27.085 |
| Image Classification/ResNet | 6.109 | 4.731 |
| Image Segmentation/Mask2former | 90.402 | 76.926 |
| Image Segmentation/Maskformer | 234.261 | 205.456 |
| Image Segmentation/MobileNet | 24.623 | 14.816 |
| Object Detection/Resnet-101 | 134.672 | 101.304 |
| Object Detection/Conditional-DETR | 97.464 | 69.739 |

V100 (batch size: 16)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 52.209 | 51.633 |
| Image Segmentation/Segformer | 61.013 | 55.499 |
| Image Classification/BeiT | 53.938 | 53.581 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 109.682 | 100.771 |
| Image Classification/ResNet | 14.857 | 12.089 |
| Image Segmentation/Mask2former | 249.605 | 222.801 |
| Image Segmentation/Maskformer | 831.142 | 743.645 |
| Image Segmentation/MobileNet | 93.129 | 55.365 |
| Object Detection/Resnet-101 | 482.425 | 361.843 |
| Object Detection/Conditional-DETR | 344.661 | 255.298 |

T4 (batch size: 1)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 16.520 | 15.786 |
| Image Segmentation/Segformer | 16.116 | 14.205 |
| Object Detection/OwlViT | 53.634 | 51.105 |
| Image Classification/BeiT | 16.464 | 15.710 |
| Object Detection/DETR | 73.100 | 53.99 |
| Image Classification/ConvNeXT | 32.932 | 30.845 |
| Image Classification/ResNet | 6.031 | 4.321 |
| Image Segmentation/Mask2former | 79.192 | 66.815 |
| Image Segmentation/Maskformer | 200.026 | 188.268 |
| Image Segmentation/MobileNet | 18.908 | 11.997 |
| Object Detection/Resnet-101 | 106.622 | 82.566 |
| Object Detection/Conditional-DETR | 77.594 | 56.984 |

T4 (batch size: 4)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 43.653 | 43.626 |
| Image Segmentation/Segformer | 45.327 | 42.445 |
| Image Classification/BeiT | 52.007 | 51.354 |
| Object Detection/DETR | 277.850 | 268.003 |
| Image Classification/ConvNeXT | 119.259 | 105.580 |
| Image Classification/ResNet | 13.039 | 11.388 |
| Image Segmentation/Mask2former | 201.540 | 184.670 |
| Image Segmentation/Maskformer | 764.052 | 711.280 |
| Image Segmentation/MobileNet | 74.289 | 48.677 |
| Object Detection/Resnet-101 | 421.859 | 357.614 |
| Object Detection/Conditional-DETR | 289.002 | 226.945 |

T4 (batch size: 16)

| Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|
| Image Classification/ViT | 163.914 | 160.907 |
| Image Segmentation/Segformer | 192.412 | 163.620 |
| Image Classification/BeiT | 188.978 | 187.976 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 422.886 | 388.078 |
| Image Classification/ResNet | 44.114 | 37.604 |
| Image Segmentation/Mask2former | 756.337 | 695.291 |
| Image Segmentation/Maskformer | 2842.940 | 2656.88 |
| Image Segmentation/MobileNet | 299.003 | 201.942 |
| Object Detection/Resnet-101 | 1619.505 | 1262.758 |
| Object Detection/Conditional-DETR | 1137.513 | 897.390 |

PyTorch Nightly
We also benchmarked on PyTorch nightly (2.1.0dev) and observed improvements in latency for both uncompiled and compiled models.
A100

| Task/Model | Batch Size | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 12.462 | 6.954 |
| Image Classification/BeiT | 4 | 14.109 | 12.851 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 30.484 | 15.221 |
| Object Detection/DETR | 4 | 46.816 | 30.942 |
| Object Detection/DETR | 16 | 163.749 | 163.706 |

T4

| Task/Model | Batch Size | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 14.408 | 14.052 |
| Image Classification/BeiT | 4 | 47.381 | 46.604 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 68.382 | 53.481 |
| Object Detection/DETR | 4 | 269.615 | 204.785 |
| Object Detection/DETR | 16 | OOM | OOM |

V100

| Task/Model | Batch Size | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 13.477 | 7.926 |
| Image Classification/BeiT | 4 | 15.103 | 14.378 |
| Image Classification/BeiT | 16 | 52.517 | 51.691 |
| Object Detection/DETR | Unbatched | 28.706 | 19.077 |
| Object Detection/DETR | 4 | 88.402 | 62.949 |
| Object Detection/DETR | 16 | OOM | OOM |

Reduce Overhead
We benchmarked the `reduce-overhead` compilation mode for A100 and T4 on PyTorch nightly.
A100

| Task/Model | Batch Size | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|---|
| Image Classification/ConvNeXT | Unbatched | 11.758 | 7.335 |
| Image Classification/ConvNeXT | 4 | 23.171 | 21.490 |
| Image Classification/ResNet | Unbatched | 7.435 | 3.801 |
| Image Classification/ResNet | 4 | 7.261 | 2.187 |
| Object Detection/Conditional-DETR | Unbatched | 32.823 | 11.627 |
| Object Detection/Conditional-DETR | 4 | 50.622 | 33.831 |
| Image Segmentation/MobileNet | Unbatched | 9.869 | 4.244 |
| Image Segmentation/MobileNet | 4 | 14.385 | 7.946 |

T4

| Task/Model | Batch Size | torch 2.0 - no compile | torch 2.0 - compile |
|---|---|---|---|
| Image Classification/ConvNeXT | Unbatched | 32.137 | 31.84 |
| Image Classification/ConvNeXT | 4 | 120.944 | 110.209 |
| Image Classification/ResNet | Unbatched | 9.761 | 7.698 |
| Image Classification/ResNet | 4 | 15.215 | 13.871 |
| Object Detection/Conditional-DETR | Unbatched | 72.150 | 57.660 |
| Object Detection/Conditional-DETR | 4 | 301.494 | 247.543 |
| Image Segmentation/MobileNet | Unbatched | 22.266 | 19.339 |
| Image Segmentation/MobileNet | 4 | 78.311 | 50.983 |
