diff --git a/cifar10.py b/cifar10.py index 33ec69d..7a0f8a1 100644 --- a/cifar10.py +++ b/cifar10.py @@ -21,13 +21,15 @@ from art.defences.trainer import AdversarialTrainer model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False) model.head = nn.Linear(model.head.in_features, 10) -model.load_state_dict( - torch.hub.load_state_dict_from_url( - "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin", - map_location="cuda", - file_name="vit_base_patch16_224_in21k_ft_cifar10.pth", - ) -) +state_dict = torch.load('/home/leewlving/.cache/torch/hub/checkpoints/vit_base_patch16_224_in21k_ft_cifar10.pth') +model.load_state_dict(state_dict) +# model.load_state_dict( +# torch.hub.load_state_dict_from_url( +# "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin", +# map_location="cuda", +# file_name="vit_base_patch16_224_in21k_ft_cifar10.pth", +# ) +# ) model.eval() DEFAULT_MEAN = (0.485, 0.456, 0.406) @@ -98,7 +100,7 @@ trainer = AdversarialTrainer( ) art_datagen = PyTorchDataGenerator(iterator=dataloader, size=x_train.shape[0], batch_size=64) -trainer.fit_generator(art_datagen, nb_epochs=100) +trainer.fit_generator(art_datagen, nb_epochs=1) for i, data in enumerate(dataloader): x, y = data @@ -111,6 +113,7 @@ for i, data in enumerate(dataloader): "Accuracy on benign test samples after adversarial training: %.2f%%" % (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100) ) -trainer.classifier.save('AT-cifar10.pth') +# trainer.classifier.save('AT-cifar10.pth') +torch.save(trainer.classifier.model.state_dict(), 'AT-cifar10.pth')