更新 cifar10.py
This commit is contained in:
parent
45ae150791
commit
75d95fea2c
21
cifar10.py
21
cifar10.py
|
|
@ -21,13 +21,15 @@ from art.defences.trainer import AdversarialTrainer
|
|||
|
||||
model = timm.create_model("timm/vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=False)
|
||||
model.head = nn.Linear(model.head.in_features, 10)
|
||||
model.load_state_dict(
|
||||
torch.hub.load_state_dict_from_url(
|
||||
"https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin",
|
||||
map_location="cuda",
|
||||
file_name="vit_base_patch16_224_in21k_ft_cifar10.pth",
|
||||
)
|
||||
)
|
||||
state_dict = torch.load('/home/leewlving/.cache/torch/hub/checkpoints/vit_base_patch16_224_in21k_ft_cifar10.pth')
|
||||
model.load_state_dict(state_dict)
|
||||
# model.load_state_dict(
|
||||
# torch.hub.load_state_dict_from_url(
|
||||
# "https://huggingface.co/edadaltocg/vit_base_patch16_224_in21k_ft_cifar10/resolve/main/pytorch_model.bin",
|
||||
# map_location="cuda",
|
||||
# file_name="vit_base_patch16_224_in21k_ft_cifar10.pth",
|
||||
# )
|
||||
# )
|
||||
model.eval()
|
||||
|
||||
DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
|
|
@ -98,7 +100,7 @@ trainer = AdversarialTrainer(
|
|||
)
|
||||
art_datagen = PyTorchDataGenerator(iterator=dataloader, size=x_train.shape[0], batch_size=64)
|
||||
|
||||
trainer.fit_generator(art_datagen, nb_epochs=100)
|
||||
trainer.fit_generator(art_datagen, nb_epochs=1)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
x, y = data
|
||||
|
|
@ -111,6 +113,7 @@ for i, data in enumerate(dataloader):
|
|||
"Accuracy on benign test samples after adversarial training: %.2f%%"
|
||||
% (np.sum(x_test_pred == np.argmax(y, axis=1)) / x.shape[0] * 100)
|
||||
)
|
||||
trainer.classifier.save('AT-cifar10.pth')
|
||||
|
||||
# trainer.classifier.save('AT-cifar10.pth')
|
||||
torch.save(trainer.classifier.model.state_dict(), 'AT-cifar10.pth')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue