This commit is contained in:
parent
ad01bb4f0e
commit
873bfd0462
|
|
@ -1,107 +0,0 @@
|
||||||
import os

import scipy.io as scio
import numpy as np

# Build the NUS-WIDE dataset artifacts: index.mat (image paths), label.mat
# (multi-hot category matrix) and caption.txt (one tag-caption per image).
# Usage:
#   mkdir mat
#   mv make_nuswide.py mat
#   python make_nuswide.py

root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/"

# BUGFIX: the sub-paths must NOT start with "/". os.path.join discards every
# preceding component when a later component is absolute, so the original
# 'os.path.join(root_dir, "/Low-Level-Features/...")' silently ignored
# root_dir (note L71-style joins already omitted the slash).
imageListFile = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
labelPath = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
textFile = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
classIndexFile = os.path.join(root_dir, "Low-Level-Features/Concepts81.txt")

# you can use the image urls to download images
imagePath = os.path.join(root_dir, "nuswide/Flickr")

with open(imageListFile, "r") as f:
    indexs = f.readlines()

# The image list uses Windows-style separators; normalise and prepend the
# local image directory.
indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs]
print("indexs length:", len(indexs))

captions = []
with open(textFile, "r") as f:
    for line in f:
        if len(line.strip()) == 0:
            print("some line empty!")
            continue
        # First whitespace-separated token is the image id; the remaining
        # tokens form the caption.
        caption = " ".join(line.split()[1:]).strip()
        if len(caption) == 0:
            # Placeholder so every image keeps a non-empty caption slot and
            # row alignment with indexs/labels is preserved.
            caption = "123456"
        captions.append(caption)

print("captions length:", len(captions))

# Only the label files listed in used_label.txt contribute columns.
with open(os.path.join(root_dir, "nuswide/Groundtruth/used_label.txt")) as f:
    label_lists = [item.strip() for item in f.readlines()]

# Map label-file name -> column index.
class_index = {item: i for i, item in enumerate(label_lists)}

labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8)

for item in label_lists:
    path = os.path.join(labelPath, item)
    class_label = item
    # Each groundtruth file has one "0"/"1" line per image, in image order.
    with open(path, "r") as f:
        data = f.readlines()
    for i, val in enumerate(data):
        labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0
print("labels sum:", labels.sum())

with open(os.path.join(root_dir, "nuswide/Groundtruth/not_used_id.txt")) as f:
    not_used_id = [int(item.strip()) for item in f.readlines()]

# Drop the excluded ids from indexs/captions/labels while keeping the three
# structures row-aligned. A set filter replaces the original O(n^2)
# list.remove() loop with the same surviving rows in the same order.
excluded = set(not_used_id)
ind = [i for i in range(len(indexs)) if i not in excluded]
indexs = [item for i, item in enumerate(indexs) if i not in excluded]
captions = [item for i, item in enumerate(captions) if i not in excluded]
labels = labels[np.asarray(ind)]

print("indexs length:", len(indexs))
print("captions length:", len(captions))
print("labels shape:", labels.shape)

indexs = {"index": indexs}
captions = {"caption": captions}
labels = {"category": labels}

scio.savemat(os.path.join(root_dir, "mat/index.mat"), indexs)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)

# Captions go to a plain text file (one line per image) rather than .mat.
captions = [item + "\n" for item in captions["caption"]]
with open(os.path.join(root_dir, "mat/caption.txt"), "w") as f:
    f.writelines(captions)

print("finished!")
|
|
||||||
6
main.py
6
main.py
|
|
@ -1,11 +1,11 @@
|
||||||
# Entry point: runs the text-modality trainer. Swap the two imports below to
# switch to hash training instead.
from train.text_train import Trainer
# from train.hash_train import Trainer


if __name__ == "__main__":
    engine = Trainer()
    engine.test()
    engine.train_epoch()
    # engine.train()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
#!/bin/bash
# Run main.py once per victim backbone. Fails fast on the first error.
set -e

# NOTE(review): assumes a local proxy is listening on 127.0.0.1:7897 — verify
# before running on another machine.
export https_proxy=http://127.0.0.1:7897 http_proxy=http://127.0.0.1:7897 all_proxy=socks5://127.0.0.1:7897

# Earlier experiment invocations, kept for reference:
# CUDA_VISIBLE_DEVICES=0 python main.py --method CSQ --bit 32
# CUDA_VISIBLE_DEVICES=0 python main.py --method CSQ --bit 64

# Current sweep: three victim models, all with 512-d output features.
CUDA_VISIBLE_DEVICES=0 python main.py --victim ViT-B/16 --output-dim 512
CUDA_VISIBLE_DEVICES=0 python main.py --victim ViT-B/32 --output-dim 512
CUDA_VISIBLE_DEVICES=0 python main.py --victim RN101 --output-dim 512
|
||||||
|
|
@ -15,12 +15,12 @@ def get_args():
|
||||||
parser.add_argument("--label-file", type=str, default="label.mat")
|
parser.add_argument("--label-file", type=str, default="label.mat")
|
||||||
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
|
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
|
||||||
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
|
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
|
||||||
parser.add_argument('--victim', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
|
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
|
||||||
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
|
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
|
||||||
parser.add_argument("--topk", type=int, default=10)
|
parser.add_argument("--topk", type=int, default=10)
|
||||||
parser.add_argument("--num-perturbation", type=int, default=3)
|
parser.add_argument("--num-perturbation", type=int, default=3)
|
||||||
parser.add_argument("--txt-dim", type=int, default=1024)
|
parser.add_argument("--txt-dim", type=int, default=1024)
|
||||||
parser.add_argument("--output-dim", type=int, default=1024)
|
parser.add_argument("--output-dim", type=int, default=512)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
parser.add_argument("--max-words", type=int, default=77)
|
parser.add_argument("--max-words", type=int, default=77)
|
||||||
parser.add_argument("--max-candidate", type=int, default=7)
|
parser.add_argument("--max-candidate", type=int, default=7)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue