更新 cifar10.py

This commit is contained in:
liwenyun 2024-10-19 19:27:46 +08:00
parent 45ae150791
commit 75d95fea2c
1 changed files with 12 additions and 9 deletions

View File

@ -21,13 +21,15 @@ from art.defences.trainer import AdversarialTrainer
model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False)
model.head = nn.Linear(model.head.in_features, 10)
model.load_state_dict(
torch.hub.load_state_dict_from_url(
"https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin",
map_location="cuda",
file_name="vit_base_patch16_224_in21k_ft_cifar10.pth",
)
)
state_dict = torch.load('/home/leewlving/.cache/torch/hub/checkpoints/vit_base_patch16_224_in21k_ft_cifar10.pth')
model.load_state_dict(state_dict)
# model.load_state_dict(
# torch.hub.load_state_dict_from_url(
# "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin",
# map_location="cuda",
# file_name="vit_base_patch16_224_in21k_ft_cifar10.pth",
# )
# )
model.eval()
DEFAULT_MEAN = (0.485, 0.456, 0.406)
@ -98,7 +100,7 @@ trainer = AdversarialTrainer(
)
art_datagen = PyTorchDataGenerator(iterator=dataloader, size=x_train.shape[0], batch_size=64)
trainer.fit_generator(art_datagen, nb_epochs=100)
trainer.fit_generator(art_datagen, nb_epochs=1)
for i, data in enumerate(dataloader):
x, y = data
@ -111,6 +113,7 @@ for i, data in enumerate(dataloader):
"Accuracy on benign test samples after adversarial training: %.2f%%"
% (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100)
)
trainer.classifier.save('AT-cifar10.pth')
# trainer.classifier.save('AT-cifar10.pth')
torch.save(trainer.classifier.model.state_dict(), 'AT-cifar10.pth')