A Light Recipe to Train Robust Vision Transformers¶

Presentation by: Pranjal Gulati

Transformers for robustness¶

  • non-adversarially trained transformers show better robustness than CNNs to non-adversarial distribution shifts (common corruptions, semantic shifts, OOD samples) (Paul et al.)
  • BUT a recent study shows transformers are no more adversarially robust than CNNs (Bai et al.)
  • the current paper claims this is just a result of sub-par training recipes

Searching for a better training recipe¶

  • data augmentation policy
  • warmup for perturbation budget $\epsilon$
  • weight decay

Training parameters¶

  • training time: XCiT-S12 on a pod with 64 TPUv4 cores, batch size 256, 110 epochs => 19h 30min (need more analysis!)
  • training attack: FGSM for adversarial training (AT) with $\epsilon = 4/255$ (a 2-step attack is used for some variants); a training step is sketched below
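A minimal sketch of one FGSM adversarial-training step at $\epsilon = 4/255$ (not the paper's exact training loop; model, x, y, and optimizer are placeholders):
In [ ]:
import torch
import torch.nn.functional as F

epsilon = 4 / 255  # l-infinity perturbation budget

def fgsm_at_step(model, x, y, optimizer):
    # Inner maximization: single-step FGSM perturbation of the inputs
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    grad = torch.autograd.grad(loss, x)[0]
    x_adv = (x + epsilon * grad.sign()).clamp(0, 1).detach()

    # Outer minimization: standard gradient step on the adversarial batch
    optimizer.zero_grad()
    adv_loss = F.cross_entropy(model(x_adv), y)
    adv_loss.backward()
    optimizer.step()
    return adv_loss.item()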

Canonical training¶

  • to achieve good performance with ViTs without pretraining, strong data augmentations are applied
  • mixup, cutmix, and random erasing help compensate for the vision priors missing in transformers (see the sketch below)
  • 28.7% AutoAttack accuracy (vs. 35.51% for ResNet-50 (Bai et al.), which uses a setup that does not differ from the standard ResNet setup)
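For reference, a rough sketch of such a strong-augmentation setup using torchvision and timm; the hyperparameter values are illustrative, not the paper's exact configuration:
In [ ]:
import torchvision.transforms as T
from timm.data import Mixup

# Per-image augmentations, ending with random erasing
strong_augs = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.RandomErasing(p=0.25),
])

# Batch-level MixUp / CutMix, applied to (images, labels) inside the training loop
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=1000)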

Weak augmentations¶

  • robust accuracy improves by 3.84% when using basic augmentations (flips, resize, color jitter) instead of the heavy ones listed above
  • AT already introduces regularization (through the inner maximization), so strong augmentations make the optimization harder
In [ ]:
import torchvision.transforms as T

# Weak augmentation pipeline: resized crop, horizontal flip, color jitter
weak_augs = T.Compose([
    T.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.33)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.4),
])

image.png

More optimization setup¶

  • $\epsilon$ warmup: 20-epoch linear warmup of the adversarial perturbation budget (see the sketch below)
    • Bai et al. failed to adversarially train DeiT
    • the authors successfully trained an XCiT-N12, but training struggled for the first few epochs
  • weight decay: a 10x higher decay of 0.5 compared to the 0.05 used in standard XCiT training
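A minimal sketch of the linear $\epsilon$ warmup described above (the helper name and exact ramp are assumptions):
In [ ]:
def epsilon_at_epoch(epoch, warmup_epochs=20, eps_max=4 / 255):
    # Linearly ramp the perturbation budget from 0 to eps_max over the
    # first `warmup_epochs` epochs, then keep it fixed at eps_max.
    if epoch >= warmup_epochs:
        return eps_max
    return eps_max * (epoch + 1) / warmup_epochs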

image.png

image.png

Comparison of training recipes¶

| Canonical | Proposed |
| --- | --- |
| no warmup | 20-epoch linear warmup |
| strong data augs (mixup, cutmix, random erasing) | weak data augs (flips, resize, color jitter) |
| weight decay 0.05 | weight decay 0.5 |

Step-by-step improvements¶

image.png

Cross-architecture generalization¶

image.png

Current state of RobustBench leaderboard¶

Setup: ImageNet, $\ell_\infty$ with $\epsilon = 4/255$, untargeted attack

image.png

Comparisons with the DeACL recipe¶

  • DeACL uses a much lower weight decay (5e-4)
    • but that is required, since a higher weight decay makes ResNet-18 training unstable
    • the current paper's authors also use a weight decay of 0.05 (instead of 0.5) for PoolFormer-M12, which has fewer parameters than ResNet-50
    • but the weight decay in the AT stage is still 10x that used in DeACL's SSL stage
  • DeACL also uses weak data augmentations for the AT stage
    • since DeACL uses CIFAR-10, they use even weaker augmentations (color jitter is considered strong there)
  • DeACL does not mention $\epsilon$ warmup
    • Priyam has already conducted successful experiments with varying $\epsilon$

Understanding the success of ViTs in AT¶

  • attack effectiveness: the key question is not whether the inner maximization (attack) is easy to solve, but whether it is easy to solve in a few steps (see the sketch below)
    • the choice of network architecture impacts optimization success
    • a single-step attack is more effective against their trained XCiT than against a ResNet-50
  • the way we perform the (outer) minimization of adversarial training influences the ease of the (inner) maximization
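A minimal sketch for probing attack effectiveness by varying the number of attack steps (a hand-rolled $\ell_\infty$ PGD, not the paper's evaluation code):
In [ ]:
import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, eps=4 / 255, alpha=1 / 255, steps=1):
    # l-infinity PGD; with steps=1 and alpha=eps this reduces to FGSM.
    x_adv = x.clone().detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + alpha * grad.sign()
        x_adv = x + (x_adv - x).clamp(-eps, eps)  # project back into the eps-ball
        x_adv = x_adv.clamp(0, 1)
    return x_adv.detach()

# Comparing robust accuracy under 1-step vs. multi-step attacks indicates
# how easy the inner maximization is for a given architecture.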

image.png