更新 cifar10.py

This commit is contained in:
liwenyun 2024-10-15 10:32:11 +08:00
parent 79f0ecad93
commit 45ae150791
1 changed files with 116 additions and 0 deletions

View File

@ -0,0 +1,116 @@
import warnings
warnings.filterwarnings('ignore')
from PIL import Image
import torch
import timm
import requests
import numpy as np
import torchvision.transforms as transforms
from torch import nn
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import Dataset, DataLoader
import copy
from art.estimators.classification import PyTorchClassifier
from art.data_generators import PyTorchDataGenerator
from art.utils import load_cifar10
from art.attacks.evasion import ProjectedGradientDescent
from art.defences.trainer import AdversarialTrainer
model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False)
model.head = nn.Linear(model.head.in_features, 10)
model.load_state_dict(
torch.hub.load_state_dict_from_url(
"https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin",
map_location="cuda",
file_name="vit_base_patch16_224_in21k_ft_cifar10.pth",
)
)
model.eval()
DEFAULT_MEAN = (0.485, 0.456, 0.406)
DEFAULT_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(DEFAULT_MEAN, DEFAULT_STD),
])
class CIFAR10_dataset(Dataset):
def __init__(self, data, targets, transform=None):
self.data = data
self.targets = torch.LongTensor(targets)
self.transform = transform
def __getitem__(self, index):
x = Image.fromarray(((self.data[index] * 255).round()).astype(np.uint8).transpose(1, 2, 0))
x = self.transform(x)
y = self.targets[index]
return x, y
def __len__(self):
return len(self.data)
(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_cifar10()
# print(max_pixel_value)
x_train = x_train.transpose(0, 3, 1, 2).astype("float32")
x_test = x_test.transpose(0, 3, 1, 2).astype("float32")
dataset = CIFAR10_dataset(x_train, y_train, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
classifier = PyTorchClassifier(
model=model,
clip_values=(min_pixel_value, max_pixel_value),
loss=criterion,
optimizer=opt,
input_shape=(3, 224, 224),
nb_classes=10,
)
attack = ProjectedGradientDescent(
classifier,
norm=np.inf,
eps=8.0 / 255.0,
eps_step=2.0 / 255.0,
max_iter=10,
targeted=False,
num_random_init=1,
batch_size=64,
verbose=False,
)
trainer = AdversarialTrainer(
classifier, attack
)
art_datagen = PyTorchDataGenerator(iterator=dataloader, size=x_train.shape[0], batch_size=64)
trainer.fit_generator(art_datagen, nb_epochs=100)
for i, data in enumerate(dataloader):
x, y = data
x = x.numpy()
y = y.numpy()
# print(x.shape)
# print(y.shape)
x_test_pred = np.argmax(classifier.predict(x), axis=1)
print(
"Accuracy on benign test samples after adversarial training: %.2f%%"
% (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100)
)
trainer.classifier.save('AT-cifar10.pth')