From 45ae150791ba72925f12f701769a2bb2463aceeb Mon Sep 17 00:00:00 2001 From: liwenyun Date: Tue, 15 Oct 2024 10:32:11 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20cifar10.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cifar10.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/cifar10.py b/cifar10.py index e69de29..33ec69d 100644 --- a/cifar10.py +++ b/cifar10.py @@ -0,0 +1,116 @@ +import warnings + +warnings.filterwarnings('ignore') + +from PIL import Image +import torch +import timm +import requests +import numpy as np +import torchvision.transforms as transforms +from torch import nn +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import Dataset, DataLoader +import copy + +from art.estimators.classification import PyTorchClassifier +from art.data_generators import PyTorchDataGenerator +from art.utils import load_cifar10 +from art.attacks.evasion import ProjectedGradientDescent +from art.defences.trainer import AdversarialTrainer + +model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False) +model.head = nn.Linear(model.head.in_features, 10) +model.load_state_dict( + torch.hub.load_state_dict_from_url( + "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin", + map_location="cuda", + file_name="vit_base_patch16_224_in21k_ft_cifar10.pth", + ) +) +model.eval() + +DEFAULT_MEAN = (0.485, 0.456, 0.406) +DEFAULT_STD = (0.229, 0.224, 0.225) + +transform = transforms.Compose([ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(DEFAULT_MEAN, DEFAULT_STD), +]) + + +class CIFAR10_dataset(Dataset): + def __init__(self, data, targets, transform=None): + self.data = data + self.targets = torch.LongTensor(targets) + self.transform = transform + + def __getitem__(self, index): + x = Image.fromarray(((self.data[index] * 255).round()).astype(np.uint8).transpose(1, 2, 0)) + x = self.transform(x) + y = self.targets[index] + return x, y + + def __len__(self): + return len(self.data) + + +(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_cifar10() +# print(max_pixel_value) +x_train = x_train.transpose(0, 3, 1, 2).astype("float32") +x_test = x_test.transpose(0, 3, 1, 2).astype("float32") + +dataset = CIFAR10_dataset(x_train, y_train, transform=transform) +dataloader = DataLoader(dataset, batch_size=64, shuffle=True) + + +opt = torch.optim.Adam(model.parameters(), lr=0.01) + + +criterion = nn.CrossEntropyLoss() + +classifier = PyTorchClassifier( + model=model, + clip_values=(min_pixel_value, max_pixel_value), + loss=criterion, + optimizer=opt, + input_shape=(3, 224, 224), + nb_classes=10, +) + + +attack = ProjectedGradientDescent( + classifier, + norm=np.inf, + eps=8.0 / 255.0, + eps_step=2.0 / 255.0, + max_iter=10, + targeted=False, + num_random_init=1, + batch_size=64, + verbose=False, +) + +trainer = AdversarialTrainer( + classifier, attack +) +art_datagen = PyTorchDataGenerator(iterator=dataloader, size=x_train.shape[0], batch_size=64) + +trainer.fit_generator(art_datagen, nb_epochs=100) + +for i, data in enumerate(dataloader): + x, y = data + x = x.numpy() + y = y.numpy() + # print(x.shape) + # print(y.shape) + x_test_pred = np.argmax(classifier.predict(x), axis=1) + print( + "Accuracy on benign test samples after adversarial training: %.2f%%" + % (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100) + ) +trainer.classifier.save('AT-cifar10.pth') + +