advclip/train/base.py

95 lines
3.0 KiB
Python

import os
from tqdm import tqdm
import torch
from torch import distributed as dist
from utils import get_logger, get_summary_writer
class TrainBase(object):
def __init__(self,
args,
rank=0):
self.args = args
os.makedirs(args.save_dir, exist_ok=True)
self._init_writer()
self.logger.info(self.args)
self.rank = rank
self._init_dataset()
self._init_model()
self.global_step = 0
# self.global_step_t = 0
self.max_mapi2t = 0
self.max_mapt2i = 0
self.best_epoch_i = 0
self.best_epoch_t = 0
def _init_dataset(self):
self.train_loader = None
self.query_loader = None
self.retrieval_loader = None
def _init_model(self):
self.model = None
self.model_ddp = None
def _init_writer(self):
self.logger = get_logger(os.path.join(self.args.save_dir, "train.log" if self.args.is_train else "test.log"))
self.writer = get_summary_writer(os.path.join(self.args.save_dir, "tensorboard"))
def run(self):
if self.args.is_train:
self.train()
else:
self.test()
def change_state(self, mode):
if mode == "train":
self.model.train()
elif mode == "valid":
self.model.eval()
def get_code(self, data_loader, length: int):
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
for image, text, label, index in tqdm(data_loader):
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
image_hash = self.model.encode_image(image)
text_hash = self.model.encode_text(text)
img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def hash_loss(self, a: torch.Tensor):
return torch.mean(torch.sqrt(torch.sum(torch.pow(torch.sign(a) - a, 2), dim=1))) * 0.5
def similarity_loss(self):
raise NotImplementedError("Function of 'similarity_loss' doesn't implement.")
def save_model(self, epoch):
torch.save(self.model.state_dict(), os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth"))
self.logger.info("save mode to {}".format(os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth")))
def train(self):
raise NotImplementedError("Function of 'train' doesn't implement.")
def valid(self):
raise NotImplementedError("Function of 'valid' doesn't implement.")
def test(self):
raise NotImplementedError("Function of 'test' doesn't implement.")
def compute_loss(self):
raise NotImplementedError("Function of 'compute_loss' doesn't implement.")