diff --git a/svhn.py b/svhn.py new file mode 100644 index 0000000..bf6adf8 --- /dev/null +++ b/svhn.py @@ -0,0 +1,122 @@ +import warnings + +import torchvision.datasets + +warnings.filterwarnings('ignore') + +from PIL import Image +import torch +import timm +import requests +import numpy as np +import torchvision.transforms as transforms +from torch import nn +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import Dataset, DataLoader +import copy + +from art.estimators.classification import PyTorchClassifier +from art.data_generators import PyTorchDataGenerator +from art.utils import load_cifar10 +from art.attacks.evasion import ProjectedGradientDescent +from art.defences.trainer import AdversarialTrainer + +model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False) +model.head = nn.Linear(model.head.in_features, 10) +state_dict = torch.load('/home/leewlving/.cache/torch/hub/checkpoints/vit_base_patch16_224_in21k_ft_cifar10.pth') +model.load_state_dict(state_dict) +# model.load_state_dict( +# torch.hub.load_state_dict_from_url( +# "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin", +# map_location="cuda", +# file_name="vit_base_patch16_224_in21k_ft_cifar10.pth", +# ) +# ) +model.eval() + +DEFAULT_MEAN = (0.485, 0.456, 0.406) +DEFAULT_STD = (0.229, 0.224, 0.225) + +transform = transforms.Compose([ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(DEFAULT_MEAN, DEFAULT_STD), +]) + + +class CIFAR10_dataset(Dataset): + def __init__(self, data, targets, transform=None): + self.data = data + self.targets = torch.LongTensor(targets) + self.transform = transform + + def __getitem__(self, index): + x = Image.fromarray(((self.data[index] * 255).round()).astype(np.uint8).transpose(1, 2, 0)) + x = self.transform(x) + y = self.targets[index] + return x, y + + def __len__(self): + return len(self.data) + + +# (x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_cifar10() +# print(max_pixel_value) +# x_train = x_train.transpose(0, 3, 1, 2).astype("float32") +# x_test = x_test.transpose(0, 3, 1, 2).astype("float32") +train_dataset = torchvision.datasets.SVHN(root='./svhn',split='train',download=True,transform=transform) +test_dataset= torchvision.datasets.SVHN(root='./svhn',split='test',download=True,transform=transform) +# dataset = CIFAR10_dataset(x_train, y_train, transform=transform) +dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) +test_dataloader =DataLoader(test_dataset, batch_size=64, shuffle=False) + +opt = torch.optim.Adam(model.parameters(), lr=0.01) + + +criterion = nn.CrossEntropyLoss() + +classifier = PyTorchClassifier( + model=model, + clip_values=(0.0, 1.0), + loss=criterion, + optimizer=opt, + input_shape=(3, 224, 224), + nb_classes=10, +) + + +attack = ProjectedGradientDescent( + classifier, + norm=np.inf, + eps=8.0 / 255.0, + eps_step=2.0 / 255.0, + max_iter=10, + targeted=False, + num_random_init=1, + batch_size=64, + verbose=False, +) + +trainer = AdversarialTrainer( + classifier, attack +) +art_datagen = PyTorchDataGenerator(iterator=dataloader, size=len(train_dataset), batch_size=64) + +trainer.fit_generator(art_datagen, nb_epochs=1) + +# for i, data in enumerate(test_dataloader): +# x, y = data +# x = x.numpy() +# y = y.numpy() +# # print(x.shape) +# # print(y.shape) +# x_test_pred = np.argmax(classifier.predict(x), axis=1) +# print( +# "Accuracy on benign test samples after adversarial training: %.2f%%" +# % (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100) +# ) + +# trainer.classifier.save('AT-cifar10.pth') +torch.save(trainer.classifier.model.state_dict(), 'AT-svhn.pth') +