95 lines
3.0 KiB
Python
95 lines
3.0 KiB
Python
import os
|
|
from tqdm import tqdm
|
|
import torch
|
|
|
|
from torch import distributed as dist
|
|
from utils import get_logger, get_summary_writer
|
|
|
|
|
|
class TrainBase(object):
    """Base class that wires together logging, data, model and bookkeeping.

    The constructor creates ``args.save_dir``, sets up the logger and the
    summary writer, and then calls the ``_init_dataset`` / ``_init_model``
    hooks. Subclasses are expected to override those hooks as well as
    ``train``, ``valid``, ``test``, ``compute_loss`` and ``similarity_loss``,
    which all raise ``NotImplementedError`` here.
    """

    def __init__(self,
                 args,
                 rank=0):
        """Set up output directory, logging, data, model and counters.

        Args:
            args: Configuration object. This class reads ``save_dir``,
                ``is_train`` and ``output_dim`` from it.
            rank: Device identifier passed to ``Tensor.to`` / ``device=``
                when moving data in ``get_code``.
        """
        self.args = args
        os.makedirs(args.save_dir, exist_ok=True)
        # The writer/logger must exist before anything is logged below.
        self._init_writer()
        self.logger.info(self.args)
        self.rank = rank

        self._init_dataset()
        self._init_model()

        self.global_step = 0
        # Best-score bookkeeping, all starting at zero.
        # NOTE(review): names suggest image->text / text->image mAP and the
        # epochs at which they were reached -- confirm against subclasses.
        self.max_mapi2t = 0
        self.max_mapt2i = 0
        self.best_epoch_i = 0
        self.best_epoch_t = 0

    def _init_dataset(self):
        """Hook: create the data loaders (placeholders set to ``None`` here)."""
        self.train_loader = None
        self.query_loader = None
        self.retrieval_loader = None

    def _init_model(self):
        """Hook: create the model objects (placeholders set to ``None`` here)."""
        self.model = None
        self.model_ddp = None

    def _init_writer(self):
        """Create the file logger and the summary writer under ``save_dir``.

        The log file is ``train.log`` when ``args.is_train`` is truthy and
        ``test.log`` otherwise; the summary writer is pointed at the
        ``tensorboard`` sub-directory.
        """
        log_name = "train.log" if self.args.is_train else "test.log"
        self.logger = get_logger(os.path.join(self.args.save_dir, log_name))
        self.writer = get_summary_writer(os.path.join(self.args.save_dir, "tensorboard"))

    def run(self):
        """Dispatch to ``train()`` or ``test()`` depending on ``args.is_train``."""
        if self.args.is_train:
            self.train()
        else:
            self.test()

    def change_state(self, mode):
        """Switch ``self.model`` between training and evaluation mode.

        Args:
            mode: ``"train"`` calls ``model.train()``, ``"valid"`` calls
                ``model.eval()``. Any other value is ignored.
        """
        if mode == "train":
            self.model.train()
        elif mode == "valid":
            self.model.eval()

    def get_code(self, data_loader, length: int):
        """Encode every batch of ``data_loader`` into two code buffers.

        Args:
            data_loader: Iterable yielding ``(image, text, label, index)``
                batches; ``index`` must support ``.numpy()`` and gives the
                destination rows for the batch.
            length: Number of rows to allocate in each buffer.

        Returns:
            ``(img_buffer, text_buffer)``: two float tensors of shape
            ``(length, args.output_dim)`` located on ``self.rank``. Rows whose
            index is never yielded by the loader stay zero.
        """
        # Allocate directly on the target device and zero-fill, so rows the
        # loader never touches are well defined instead of uninitialised.
        img_buffer = torch.zeros(length, self.args.output_dim,
                                 dtype=torch.float, device=self.rank)
        text_buffer = torch.zeros(length, self.args.output_dim,
                                  dtype=torch.float, device=self.rank)

        # The results are only stored, never back-propagated through, so skip
        # building the autograd graph for the encoder forward passes.
        with torch.no_grad():
            for image, text, _, index in tqdm(data_loader):
                image = image.to(self.rank, non_blocking=True)
                text = text.to(self.rank, non_blocking=True)
                index = index.numpy()
                image_hash = self.model.encode_image(image)
                text_hash = self.model.encode_text(text)

                img_buffer[index, :] = image_hash
                text_buffer[index, :] = text_hash

        return img_buffer, text_buffer

    def hash_loss(self, a: torch.Tensor):
        """Return half the mean row-wise L2 distance between ``a`` and ``sign(a)``.

        ``torch.norm`` is used instead of ``sqrt(sum(pow(...)))`` because the
        derivative of ``sqrt`` is infinite at zero: a single row that already
        equals its sign would otherwise turn the whole gradient into NaN.
        The norm's backward pass yields a zero gradient for such rows.

        Args:
            a: Tensor whose dimension 1 is reduced by the norm.

        Returns:
            Scalar tensor with the loss value.
        """
        residual = torch.sign(a) - a
        return torch.mean(torch.norm(residual, p=2, dim=1)) * 0.5

    def similarity_loss(self):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Function of 'similarity_loss' doesn't implement.")

    def save_model(self, epoch):
        """Save ``self.model.state_dict()`` to ``save_dir/model-<epoch>.pth``.

        Args:
            epoch: Value embedded (via ``str``) in the checkpoint file name.
        """
        checkpoint_path = os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        self.logger.info("save mode to {}".format(checkpoint_path))

    def train(self):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Function of 'train' doesn't implement.")

    def valid(self):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Function of 'valid' doesn't implement.")

    def test(self):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Function of 'test' doesn't implement.")

    def compute_loss(self):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Function of 'compute_loss' doesn't implement.")
|