From 75d95fea2c27c9cc6c1e2fa327137c8cfe64173c Mon Sep 17 00:00:00 2001 From: liwenyun Date: Sat, 19 Oct 2024 19:27:46 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20cifar10.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cifar10.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/cifar10.py b/cifar10.py index 33ec69d..7a0f8a1 100644 --- a/cifar10.py +++ b/cifar10.py @@ -21,13 +21,15 @@ from art.defences.trainer import AdversarialTrainer model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False) model.head = nn.Linear(model.head.in_features, 10) -model.load_state_dict( - torch.hub.load_state_dict_from_url( - "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin", - map_location="cuda", - file_name="vit_base_patch16_224_in21k_ft_cifar10.pth", - ) -) +state_dict = torch.load('/home/leewlving/.cache/torch/hub/checkpoints/vit_base_patch16_224_in21k_ft_cifar10.pth') +model.load_state_dict(state_dict) +# model.load_state_dict( +# torch.hub.load_state_dict_from_url( +# "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin", +# map_location="cuda", +# file_name="vit_base_patch16_224_in21k_ft_cifar10.pth", +# ) +# ) model.eval() DEFAULT_MEAN = (0.485, 0.456, 0.406) @@ -98,7 +100,7 @@ trainer = AdversarialTrainer( ) art_datagen = PyTorchDataGenerator(iterator=dataloader, size=x_train.shape[0], batch_size=64) -trainer.fit_generator(art_datagen, nb_epochs=100) +trainer.fit_generator(art_datagen, nb_epochs=1) for i, data in enumerate(dataloader): x, y = data @@ -111,6 +113,7 @@ for i, data in enumerate(dataloader): "Accuracy on benign test samples after adversarial training: %.2f%%" % (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100) ) -trainer.classifier.save('AT-cifar10.pth') +# trainer.classifier.save('AT-cifar10.pth') +torch.save(trainer.classifier.model.state_dict(), 'AT-cifar10.pth')