This commit is contained in:
Li Wenyun 2024-07-09 18:52:41 +08:00
parent ad01bb4f0e
commit 873bfd0462
4 changed files with 15 additions and 112 deletions

View File

@ -1,107 +0,0 @@
import os
import scipy.io as scio
import numpy as np
# mkdir mat
# mv make_nuswide.py mat
# python make_nuswide.py
root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/"
imageListFile = os.path.join(root_dir, "/Low-Level-Features/ImageList/Imagelist.txt")
labelPath = os.path.join(root_dir, "/nuswide/Groundtruth/AllLabels")
textFile = os.path.join(root_dir, "/Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
classIndexFile = os.path.join(root_dir, "/Low-Level-Features/Concepts81.txt")
# you can use the image urls to download images
imagePath = os.path.join(root_dir, "nuswide/Flickr")
with open(imageListFile, "r") as f:
indexs = f.readlines()
indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs]
print("indexs length:", len(indexs))
#class_index = {}
#with open(classIndexFile, "r") as f:
# data = f.readlines()
#
#for i, item in enumerate(data):
# class_index.update({item.strip(): i})
captions = []
with open(textFile, "r") as f:
for line in f:
if len(line.strip()) == 0:
print("some line empty!")
continue
caption = line.split()[1:]
caption = " ".join(caption).strip()
if len(caption) == 0:
caption = "123456"
captions.append(caption)
print("captions length:", len(captions))
#labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8)
# label_lists = os.listdir(labelPath)
with open(os.path.join(root_dir, "/nuswide/Groundtruth/used_label.txt")) as f:
label_lists = f.readlines()
label_lists = [item.strip() for item in label_lists]
class_index = {}
for i, item in enumerate(label_lists):
class_index.update({item: i})
labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8)
for item in label_lists:
path = os.path.join(labelPath, item)
class_label = item# .split(".")[0].split("_")[-1]
with open(path, "r") as f:
data = f.readlines()
for i, val in enumerate(data):
labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0
print("labels sum:", labels.sum())
not_used_id = []
with open(os.path.join(root_dir, "/nuswide/Groundtruth/not_used_id.txt")) as f:
not_used_id = f.readlines()
not_used_id = [int(item.strip()) for item in not_used_id]
# for item in not_used_id:
# indexs.pop(item)
# captions.pop(item)
# labels = np.delete(labels, item, 0)
ind = list(range(len(indexs)))
for item in not_used_id:
ind.remove(item)
indexs[item] = ""
captions[item] = ""
indexs = [item for item in indexs if item != ""]
captions = [item for item in captions if item != ""]
ind = np.asarray(ind)
labels = labels[ind]
# ind = range(len(indexs))
print("indexs length:", len(indexs))
print("captions length:", len(captions))
print("labels shape:", labels.shape)
indexs = {"index": indexs}
captions = {"caption": captions}
labels = {"category": labels}
scio.savemat(os.path.join(root_dir, "/mat/index.mat"), indexs)
# scio.savemat("caption.mat", captions)
scio.savemat(os.path.join(root_dir, "/mat/label.mat"), labels)
captions = [item + "\n" for item in captions["caption"]]
with open(os.path.join(root_dir, "/mat/caption.txt"), "w") as f:
f.writelines(captions)
print("finished!")

View File

@ -1,11 +1,11 @@
# from train.text_train import Trainer
from train.hash_train import Trainer
from train.text_train import Trainer
# from train.hash_train import Trainer
if __name__ == "__main__":
engine=Trainer()
engine.test()
# engine.train_epoch()
engine.train_epoch()
# engine.train()

10
run.sh Normal file
View File

@ -0,0 +1,10 @@
#!/bin/bash
set -e
export https_proxy=http://127.0.0.1:7897 http_proxy=http://127.0.0.1:7897 all_proxy=socks5://127.0.0.1:7897
# CUDA_VISIBLE_DEVICES=0 python main.py --method CSQ --bit 32
# CUDA_VISIBLE_DEVICES=0 python main.py --method CSQ --bit 64
CUDA_VISIBLE_DEVICES=0 python main.py --victim ViT-B/16 --output-dim 512
CUDA_VISIBLE_DEVICES=0 python main.py --victim ViT-B/32 --output-dim 512
CUDA_VISIBLE_DEVICES=0 python main.py --victim RN101 --output-dim 512

View File

@ -15,12 +15,12 @@ def get_args():
parser.add_argument("--label-file", type=str, default="label.mat")
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
parser.add_argument('--victim', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
parser.add_argument("--topk", type=int, default=10)
parser.add_argument("--num-perturbation", type=int, default=3)
parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=512)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--max-words", type=int, default=77)
parser.add_argument("--max-candidate", type=int, default=7)