From 873bfd0462a3b83e6ccc4d069ef66492263fdd41 Mon Sep 17 00:00:00 2001 From: Li Wenyun Date: Tue, 9 Jul 2024 18:52:41 +0800 Subject: [PATCH] a --- dataset/make_nuswide.py | 107 ---------------------------------------- main.py | 6 +-- run.sh | 10 ++++ utils/get_args.py | 4 +- 4 files changed, 15 insertions(+), 112 deletions(-) delete mode 100644 dataset/make_nuswide.py create mode 100644 run.sh diff --git a/dataset/make_nuswide.py b/dataset/make_nuswide.py deleted file mode 100644 index aa6ec32..0000000 --- a/dataset/make_nuswide.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -import scipy.io as scio -import numpy as np - -# mkdir mat -# mv make_nuswide.py mat -# python make_nuswide.py -root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/" - - -imageListFile = os.path.join(root_dir, "/Low-Level-Features/ImageList/Imagelist.txt") -labelPath = os.path.join(root_dir, "/nuswide/Groundtruth/AllLabels") -textFile = os.path.join(root_dir, "/Low-Level-Features/NUS_WID_Tags/All_Tags.txt") -classIndexFile = os.path.join(root_dir, "/Low-Level-Features/Concepts81.txt") - -# you can use the image urls to download images -imagePath = os.path.join(root_dir, "nuswide/Flickr") - -with open(imageListFile, "r") as f: - indexs = f.readlines() - -indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs] -print("indexs length:", len(indexs)) - -#class_index = {} -#with open(classIndexFile, "r") as f: -# data = f.readlines() -# -#for i, item in enumerate(data): -# class_index.update({item.strip(): i}) - -captions = [] -with open(textFile, "r") as f: - for line in f: - if len(line.strip()) == 0: - print("some line empty!") - continue - caption = line.split()[1:] - caption = " ".join(caption).strip() - if len(caption) == 0: - caption = "123456" - captions.append(caption) - -print("captions length:", len(captions)) - -#labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8) -# label_lists = os.listdir(labelPath) -with open(os.path.join(root_dir, "/nuswide/Groundtruth/used_label.txt")) as f: - label_lists = f.readlines() -label_lists = [item.strip() for item in label_lists] - -class_index = {} -for i, item in enumerate(label_lists): - class_index.update({item: i}) - -labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8) - -for item in label_lists: - path = os.path.join(labelPath, item) - class_label = item# .split(".")[0].split("_")[-1] - - with open(path, "r") as f: - data = f.readlines() - for i, val in enumerate(data): - labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0 -print("labels sum:", labels.sum()) - -not_used_id = [] -with open(os.path.join(root_dir, "/nuswide/Groundtruth/not_used_id.txt")) as f: - not_used_id = f.readlines() -not_used_id = [int(item.strip()) for item in not_used_id] - -# for item in not_used_id: -# indexs.pop(item) -# captions.pop(item) -# labels = np.delete(labels, item, 0) -ind = list(range(len(indexs))) -for item in not_used_id: - ind.remove(item) - indexs[item] = "" - captions[item] = "" -indexs = [item for item in indexs if item != ""] -captions = [item for item in captions if item != ""] -ind = np.asarray(ind) -labels = labels[ind] -# ind = range(len(indexs)) - -print("indexs length:", len(indexs)) -print("captions length:", len(captions)) -print("labels shape:", labels.shape) - -indexs = {"index": indexs} -captions = {"caption": captions} -labels = {"category": labels} - -scio.savemat(os.path.join(root_dir, "/mat/index.mat"), indexs) -# scio.savemat("caption.mat", captions) -scio.savemat(os.path.join(root_dir, "/mat/label.mat"), labels) - - -captions = [item + "\n" for item in captions["caption"]] - -with open(os.path.join(root_dir, "/mat/caption.txt"), "w") as f: - f.writelines(captions) - -print("finished!") - diff --git a/main.py b/main.py index 7192f94..5699186 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,11 @@ -# from train.text_train import Trainer -from train.hash_train import Trainer +from train.text_train import Trainer +# from train.hash_train import Trainer if __name__ == "__main__": engine=Trainer() engine.test() - # engine.train_epoch() + engine.train_epoch() # engine.train() diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..5e9a45c --- /dev/null +++ b/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -e + + +export https_proxy=http://127.0.0.1:7897 http_proxy=http://127.0.0.1:7897 all_proxy=socks5://127.0.0.1:7897 +# CUDA_VISIBLE_DEVICES=0 python main.py --method CSQ --bit 32 +# CUDA_VISIBLE_DEVICES=0 python main.py --method CSQ --bit 64 +CUDA_VISIBLE_DEVICES=0 python main.py --victim ViT-B/16 --output-dim 512 +CUDA_VISIBLE_DEVICES=0 python main.py --victim ViT-B/32 --output-dim 512 +CUDA_VISIBLE_DEVICES=0 python main.py --victim RN101 --output-dim 512 \ No newline at end of file diff --git a/utils/get_args.py b/utils/get_args.py index 3456cbb..1c03abe 100644 --- a/utils/get_args.py +++ b/utils/get_args.py @@ -15,12 +15,12 @@ def get_args(): parser.add_argument("--label-file", type=str, default="label.mat") parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]") parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]") - parser.add_argument('--victim', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101']) + parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101']) parser.add_argument("--text_encoder", type=str, default="bert-base-uncased") parser.add_argument("--topk", type=int, default=10) parser.add_argument("--num-perturbation", type=int, default=3) parser.add_argument("--txt-dim", type=int, default=1024) - parser.add_argument("--output-dim", type=int, default=1024) + parser.add_argument("--output-dim", type=int, default=512) parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--max-words", type=int, default=77) parser.add_argument("--max-candidate", type=int, default=7)