# Optimize inference using `torch.compile()`


This guide benchmarks the inference speed-ups that `torch.compile()` brings to computer vision models in 🌍 Transformers.

## Benefits of `torch.compile`

Depending on the model and the GPU, `torch.compile()` can noticeably speed up inference: in the benchmarks below, latency drops by anywhere from a negligible amount to more than 50%. To use `torch.compile()`, install torch 2.0 or later.

Compiling a model takes time, so it pays off when you compile the model once and reuse it for many inferences, rather than compiling it every time you run inference. To compile any computer vision model, call `torch.compile()` on it as shown below:


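```python
import torch
from transformers import AutoModelForImageClassification

MODEL_ID = "google/vit-base-patch16-224"  # example checkpoint; use any image classification model

model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to("cuda")
model = torch.compile(model)  # the one extra line; compilation runs on the first forward call
```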
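Because compilation is a one-time cost, one way to judge whether it pays off is to estimate how many inferences it takes for the per-inference savings to cover it. This sketch uses the ViT latencies on an A100 at batch size 1 reported in the tables below; the 30-second compile time is a hypothetical placeholder, since this guide does not report compilation times, so measure your own.

```python
import math


def break_even_inferences(compile_time_s: float, eager_ms: float, compiled_ms: float) -> int:
    """Return how many inferences it takes for compilation to pay for itself."""
    saving_ms = eager_ms - compiled_ms  # time saved per inference
    if saving_ms <= 0:
        raise ValueError("the compiled model is not faster, so compiling never pays off")
    # One-time compile cost divided by the per-inference saving, rounded up
    return math.ceil(compile_time_s * 1000.0 / saving_ms)


# ViT on an A100 at batch size 1 (from the tables below): 9.325 ms -> 7.584 ms.
# HYPOTHETICAL: a 30 s compile time; this guide does not report compile times.
n = break_even_inferences(compile_time_s=30.0, eager_ms=9.325, compiled_ms=7.584)
print(n)  # 17232
```

With these numbers, compiling only pays off for a long-running process that serves tens of thousands of requests, which is why the model should be compiled once rather than per inference.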

`torch.compile()` offers several compilation modes, which differ mainly in compilation time and inference overhead. `max-autotune` takes longer to compile than `reduce-overhead` but produces faster inference. The default mode compiles fastest but gives slower inference than `reduce-overhead`. Unless stated otherwise, this guide uses the default mode. You can learn more about the modes in the PyTorch `torch.compile` documentation.
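As a minimal sketch (assuming torch 2.0 or later), the mode is chosen with the `mode` argument of `torch.compile()`. The small `nn.Sequential` model is a hypothetical stand-in for a Transformers vision model. Because compilation is deferred until the first forward call, creating one wrapper per mode is cheap; each wrapper compiles only when it first runs.

```python
import torch
from torch import nn

# HYPOTHETICAL stand-in model; in practice, use a Transformers vision model
# such as the ViT classifier from the benchmarking code.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()

# torch.compile() only wraps the model here; compilation is deferred until
# the first forward call, so building one wrapper per mode is cheap.
compiled_by_mode = {
    mode: torch.compile(model, mode=mode)
    for mode in ("default", "reduce-overhead", "max-autotune")
}
```

To compare the modes, call each wrapper once to trigger compilation, then time it. `reduce-overhead` relies on CUDA graphs, so its benefit shows up mainly on GPU.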

We benchmarked torch.compile with different computer vision models, tasks, types of hardware, and batch sizes on torch version 2.0.1.

## Benchmarking code

Below you can find the benchmarking code for each task. We warm up the GPU before inference and take the mean time of 300 inferences, using the same image each time.
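The timing procedure described above (warm up, then take the mean over 300 inferences) can be sketched as a small helper. This is not the exact script used for these benchmarks: the number of warm-up calls (10) and the `sync` hook are assumptions. For GPU models, pass `torch.cuda.synchronize` as `sync`, because CUDA kernels run asynchronously and the timer would otherwise stop before the work finishes.

```python
import statistics
import time
from typing import Callable, Optional


def mean_latency_ms(
    fn: Callable[[], object],
    n_runs: int = 300,
    n_warmup: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Mean wall-clock latency of fn() in milliseconds, measured after a warm-up."""
    # Untimed warm-up calls; for a compiled model the first call also compiles it
    for _ in range(n_warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts

    timings_ms = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous (e.g. CUDA) work to complete
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(timings_ms)
```

With the ViT example below, you might call `mean_latency_ms(lambda: model(**processed_input), sync=torch.cuda.synchronize)` inside the `torch.no_grad()` block. Since the first warm-up call of a compiled model triggers compilation, compile time stays out of the measurement.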

### Image Classification with ViT


```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Sample image from the COCO 2017 validation set
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to("cuda")
model = torch.compile(model)

processed_input = processor(image, return_tensors="pt").to("cuda")

# Inference only: disable gradient tracking
with torch.no_grad():
    _ = model(**processed_input)
```

### Object Detection with DETR


```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")
model = torch.compile(model)

# DETR is image-only: its processor takes images, not text prompts
inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**inputs)
```

### Image Segmentation with Segformer


```python
import requests
import torch
from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to("cuda")
model = torch.compile(model)

seg_inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**seg_inputs)
```

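Below is the list of models we benchmarked, grouped by task:

- Image Classification: ViT, BeiT, ConvNeXT, ResNet
- Image Segmentation: Segformer, Mask2former, Maskformer, MobileNet
- Object Detection: OwlViT, DETR, Resnet-101, Conditional-DETR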

Below you can find visualizations of inference durations with and without `torch.compile()`, along with the percentage improvement for each model, across different hardware and batch sizes.
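This guide does not spell out how the percentage improvement is computed. A minimal sketch, assuming it is the latency reduction relative to the uncompiled run, derives it (and the equivalent speed-up factor) from the millisecond durations; the rows are copied from the A100, batch size 1 table.

```python
def improvement_pct(no_compile_ms: float, compile_ms: float) -> float:
    """Latency reduction from compiling, as a percentage of the uncompiled latency."""
    return (no_compile_ms - compile_ms) / no_compile_ms * 100.0


def speedup(no_compile_ms: float, compile_ms: float) -> float:
    """How many times faster the compiled model runs."""
    return no_compile_ms / compile_ms


# (no compile, compile) durations in ms, copied from the A100, batch size 1 table
a100_bs1 = {
    "Image Classification/ViT": (9.325, 7.584),
    "Image Classification/ConvNeXT": (10.410, 10.208),
    "Image Classification/ResNet": (6.531, 4.124),
    "Image Segmentation/MobileNet": (8.583, 3.974),
}

for name, (no_compile, compiled) in a100_bs1.items():
    print(
        f"{name}: {improvement_pct(no_compile, compiled):.1f}% lower latency, "
        f"{speedup(no_compile, compiled):.2f}x faster"
    )
```

The two metrics are not interchangeable: MobileNet's 53.7% latency reduction corresponds to a roughly 2.16x speed-up.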

The tables below list inference durations in milliseconds for each model with and without `torch.compile()`. OwlViT runs out of memory (OOM) at larger batch sizes, so it only appears in the batch size 1 tables.

### A100 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 9.325 | 7.584 |
| Image Segmentation/Segformer | 11.759 | 10.500 |
| Object Detection/OwlViT | 24.978 | 18.420 |
| Image Classification/BeiT | 11.282 | 8.448 |
| Object Detection/DETR | 34.619 | 19.040 |
| Image Classification/ConvNeXT | 10.410 | 10.208 |
| Image Classification/ResNet | 6.531 | 4.124 |
| Image Segmentation/Mask2former | 60.188 | 49.117 |
| Image Segmentation/Maskformer | 75.764 | 59.487 |
| Image Segmentation/MobileNet | 8.583 | 3.974 |
| Object Detection/Resnet-101 | 36.276 | 18.197 |
| Object Detection/Conditional-DETR | 31.219 | 17.993 |

### A100 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 14.832 | 14.499 |
| Image Segmentation/Segformer | 18.838 | 16.476 |
| Image Classification/BeiT | 13.205 | 13.048 |
| Object Detection/DETR | 48.657 | 32.418 |
| Image Classification/ConvNeXT | 22.940 | 21.631 |
| Image Classification/ResNet | 6.657 | 4.268 |
| Image Segmentation/Mask2former | 74.277 | 61.781 |
| Image Segmentation/Maskformer | 180.700 | 159.116 |
| Image Segmentation/MobileNet | 14.174 | 8.515 |
| Object Detection/Resnet-101 | 68.101 | 44.998 |
| Object Detection/Conditional-DETR | 56.470 | 35.552 |

### A100 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 40.944 | 40.010 |
| Image Segmentation/Segformer | 37.005 | 31.144 |
| Image Classification/BeiT | 41.854 | 41.048 |
| Object Detection/DETR | 164.382 | 161.902 |
| Image Classification/ConvNeXT | 82.258 | 75.561 |
| Image Classification/ResNet | 7.018 | 5.024 |
| Image Segmentation/Mask2former | 178.945 | 154.814 |
| Image Segmentation/Maskformer | 638.570 | 579.826 |
| Image Segmentation/MobileNet | 51.693 | 30.310 |
| Object Detection/Resnet-101 | 232.887 | 155.021 |
| Object Detection/Conditional-DETR | 180.491 | 124.032 |

### V100 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 10.495 | 6.00 |
| Image Segmentation/Segformer | 13.321 | 5.862 |
| Object Detection/OwlViT | 25.769 | 22.395 |
| Image Classification/BeiT | 11.347 | 7.234 |
| Object Detection/DETR | 33.951 | 19.388 |
| Image Classification/ConvNeXT | 11.623 | 10.412 |
| Image Classification/ResNet | 6.484 | 3.820 |
| Image Segmentation/Mask2former | 64.640 | 49.873 |
| Image Segmentation/Maskformer | 95.532 | 72.207 |
| Image Segmentation/MobileNet | 9.217 | 4.753 |
| Object Detection/Resnet-101 | 52.818 | 28.367 |
| Object Detection/Conditional-DETR | 39.512 | 20.816 |

### V100 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 15.181 | 14.501 |
| Image Segmentation/Segformer | 16.787 | 16.188 |
| Image Classification/BeiT | 15.171 | 14.753 |
| Object Detection/DETR | 88.529 | 64.195 |
| Image Classification/ConvNeXT | 29.574 | 27.085 |
| Image Classification/ResNet | 6.109 | 4.731 |
| Image Segmentation/Mask2former | 90.402 | 76.926 |
| Image Segmentation/Maskformer | 234.261 | 205.456 |
| Image Segmentation/MobileNet | 24.623 | 14.816 |
| Object Detection/Resnet-101 | 134.672 | 101.304 |
| Object Detection/Conditional-DETR | 97.464 | 69.739 |

### V100 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 52.209 | 51.633 |
| Image Segmentation/Segformer | 61.013 | 55.499 |
| Image Classification/BeiT | 53.938 | 53.581 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 109.682 | 100.771 |
| Image Classification/ResNet | 14.857 | 12.089 |
| Image Segmentation/Mask2former | 249.605 | 222.801 |
| Image Segmentation/Maskformer | 831.142 | 743.645 |
| Image Segmentation/MobileNet | 93.129 | 55.365 |
| Object Detection/Resnet-101 | 482.425 | 361.843 |
| Object Detection/Conditional-DETR | 344.661 | 255.298 |

### T4 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 16.520 | 15.786 |
| Image Segmentation/Segformer | 16.116 | 14.205 |
| Object Detection/OwlViT | 53.634 | 51.105 |
| Image Classification/BeiT | 16.464 | 15.710 |
| Object Detection/DETR | 73.100 | 53.99 |
| Image Classification/ConvNeXT | 32.932 | 30.845 |
| Image Classification/ResNet | 6.031 | 4.321 |
| Image Segmentation/Mask2former | 79.192 | 66.815 |
| Image Segmentation/Maskformer | 200.026 | 188.268 |
| Image Segmentation/MobileNet | 18.908 | 11.997 |
| Object Detection/Resnet-101 | 106.622 | 82.566 |
| Object Detection/Conditional-DETR | 77.594 | 56.984 |

### T4 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 43.653 | 43.626 |
| Image Segmentation/Segformer | 45.327 | 42.445 |
| Image Classification/BeiT | 52.007 | 51.354 |
| Object Detection/DETR | 277.850 | 268.003 |
| Image Classification/ConvNeXT | 119.259 | 105.580 |
| Image Classification/ResNet | 13.039 | 11.388 |
| Image Segmentation/Mask2former | 201.540 | 184.670 |
| Image Segmentation/Maskformer | 764.052 | 711.280 |
| Image Segmentation/MobileNet | 74.289 | 48.677 |
| Object Detection/Resnet-101 | 421.859 | 357.614 |
| Object Detection/Conditional-DETR | 289.002 | 226.945 |

### T4 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 163.914 | 160.907 |
| Image Segmentation/Segformer | 192.412 | 163.620 |
| Image Classification/BeiT | 188.978 | 187.976 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 422.886 | 388.078 |
| Image Classification/ResNet | 44.114 | 37.604 |
| Image Segmentation/Mask2former | 756.337 | 695.291 |
| Image Segmentation/Maskformer | 2842.940 | 2656.88 |
| Image Segmentation/MobileNet | 299.003 | 201.942 |
| Object Detection/Resnet-101 | 1619.505 | 1262.758 |
| Object Detection/Conditional-DETR | 1137.513 | 897.390 |

## PyTorch Nightly

We also benchmarked on PyTorch nightly (2.1.0dev) and observed lower latency for both uncompiled and compiled models.

### A100

| Task/Model | Batch Size | nightly - no compile (ms) | nightly - compile (ms) |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 12.462 | 6.954 |
| Image Classification/BeiT | 4 | 14.109 | 12.851 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 30.484 | 15.221 |
| Object Detection/DETR | 4 | 46.816 | 30.942 |
| Object Detection/DETR | 16 | 163.749 | 163.706 |

### T4

| Task/Model | Batch Size | nightly - no compile (ms) | nightly - compile (ms) |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 14.408 | 14.052 |
| Image Classification/BeiT | 4 | 47.381 | 46.604 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 68.382 | 53.481 |
| Object Detection/DETR | 4 | 269.615 | 204.785 |
| Object Detection/DETR | 16 | OOM | OOM |

### V100

| Task/Model | Batch Size | nightly - no compile (ms) | nightly - compile (ms) |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 13.477 | 7.926 |
| Image Classification/BeiT | 4 | 15.103 | 14.378 |
| Image Classification/BeiT | 16 | 52.517 | 51.691 |
| Object Detection/DETR | Unbatched | 28.706 | 19.077 |
| Object Detection/DETR | 4 | 88.402 | 62.949 |
| Object Detection/DETR | 16 | OOM | OOM |

## Reduce Overhead

We benchmarked the `reduce-overhead` compilation mode on A100 and T4 with PyTorch nightly.

### A100

| Task/Model | Batch Size | nightly - no compile (ms) | nightly - compile, `reduce-overhead` (ms) |
|---|---|---|---|
| Image Classification/ConvNeXT | Unbatched | 11.758 | 7.335 |
| Image Classification/ConvNeXT | 4 | 23.171 | 21.490 |
| Image Classification/ResNet | Unbatched | 7.435 | 3.801 |
| Image Classification/ResNet | 4 | 7.261 | 2.187 |
| Object Detection/Conditional-DETR | Unbatched | 32.823 | 11.627 |
| Object Detection/Conditional-DETR | 4 | 50.622 | 33.831 |
| Image Segmentation/MobileNet | Unbatched | 9.869 | 4.244 |
| Image Segmentation/MobileNet | 4 | 14.385 | 7.946 |

### T4

| Task/Model | Batch Size | nightly - no compile (ms) | nightly - compile, `reduce-overhead` (ms) |
|---|---|---|---|
| Image Classification/ConvNeXT | Unbatched | 32.137 | 31.84 |
| Image Classification/ConvNeXT | 4 | 120.944 | 110.209 |
| Image Classification/ResNet | Unbatched | 9.761 | 7.698 |
| Image Classification/ResNet | 4 | 15.215 | 13.871 |
| Object Detection/Conditional-DETR | Unbatched | 72.150 | 57.660 |
| Object Detection/Conditional-DETR | 4 | 301.494 | 247.543 |
| Image Segmentation/MobileNet | Unbatched | 22.266 | 19.339 |
| Image Segmentation/MobileNet | 4 | 78.311 | 50.983 |
