# haloscope/linear_probe.py
from __future__ import print_function

import copy
import math
import sys
import time

import easydict
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from ylib.ytool import ArrayDataset

# apex (mixed precision) is optional and not used below; the import is kept
# so environments that expect it to be probed behave as before.
try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

cudnn.benchmark = True

class LinearClassifier(nn.Module):
    """Linear probe emitting a single binary logit (num_classes is unused)."""

    def __init__(self, feat_dim, num_classes=10):
        super(LinearClassifier, self).__init__()
        # One logit; trained with BCE-with-logits in train()/validate().
        self.fc = nn.Linear(feat_dim, 1)

    def forward(self, features):
        return self.fc(features)

class NonLinearClassifier(nn.Module):
    """Two-layer MLP probe emitting a single binary logit (num_classes is unused)."""

    def __init__(self, feat_dim, num_classes=10):
        super(NonLinearClassifier, self).__init__()
        self.fc1 = nn.Linear(feat_dim, 1024)
        # self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(1024, 1)

    def forward(self, features):
        x = F.relu(self.fc1(features))
        # x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class NormedLinear(nn.Module):
    """Linear layer over L2-normalized inputs and weights (cosine-similarity logits)."""

    def __init__(self, in_features, out_features, bn=False):
        super(NormedLinear, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.bn = bn
        if bn:
            self.bn_layer = nn.BatchNorm1d(out_features)

    def forward(self, x):
        out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
        if self.bn:
            out = self.bn_layer(out)
        return out
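
# Note: NormedLinear produces cosine-similarity logits in [-1, 1]; losses that
# consume such logits typically apply a temperature/scale factor. The class is
# defined for completeness but is not referenced by get_linear_acc below.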

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def adjust_learning_rate(args, optimizer, epoch):
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
            1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
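
# Worked example of the step schedule: with learning_rate=5, lr_decay_rate=0.2,
# and lr_decay_epochs=[30, 60, 90] (the defaults set up in get_linear_acc),
# adjust_learning_rate yields lr = 5 for epochs 1-30, 1.0 for 31-60,
# 0.2 for 61-90, and 0.04 from epoch 91 on.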

def set_optimizer(opt, model):
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    return optimizer

def train(train_loader, classifier, criterion, optimizer, epoch, print_freq=10):
    """One epoch of training. Note: criterion is unused; the loss is computed
    directly with F.binary_cross_entropy_with_logits."""
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (features, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        features = features.cuda(non_blocking=True).float()
        labels = labels.cuda(non_blocking=True).long()
        bsz = labels.shape[0]

        optimizer.zero_grad()

        # warm-up learning rate
        # warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        output = classifier(features)
        loss = F.binary_cross_entropy_with_logits(output.view(-1), labels.float())

        # update metrics; .item() keeps the meters on Python floats and
        # avoids integer division on the correct-count tensor
        losses.update(loss.item(), bsz)
        correct = (torch.sigmoid(output) > 0.5).long().view(-1).eq(labels.view(-1))
        top1.update(correct.sum().item() / bsz, bsz)

        # SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg

def validate(val_loader, classifier, criterion, print_freq):
    """Validation. Returns average loss, accuracy, per-sample sigmoid
    probabilities, and the corresponding labels. criterion is unused (see train)."""
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    preds = np.array([])        # sigmoid probabilities, one per sample
    labels_out = np.array([])

    with torch.no_grad():
        end = time.time()
        for idx, (features, labels) in enumerate(val_loader):
            features = features.float().cuda()
            labels_out = np.append(labels_out, labels)
            labels = labels.long().cuda()
            bsz = labels.shape[0]

            # forward
            output = classifier(features)
            loss = F.binary_cross_entropy_with_logits(output.view(-1), labels.float())

            prob = torch.sigmoid(output)
            preds = np.append(preds, prob.view(-1).cpu().numpy())

            # update metrics
            losses.update(loss.item(), bsz)
            correct = (prob > 0.5).long().view(-1).eq(labels.view(-1))
            top1.update(correct.sum().item() / bsz, bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (idx + 1) % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          idx + 1, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1))

    return losses.avg, top1.avg, preds, labels_out

def get_linear_acc(ftrain, ltrain, ftest, ltest, n_cls, epochs=10,
                   args=None, classifier=None,
                   print_ret=True, normed=False, nonlinear=False,
                   learning_rate=5,
                   weight_decay=0,
                   batch_size=512,
                   cosine=False,
                   lr_decay_epochs=(30, 60, 90)):
    """Train a (non)linear probe on precomputed features and return accuracies.

    print_ret and normed are currently unused; they are kept for backward
    compatibility with existing callers.
    """
    # Remap arbitrary label values to contiguous indices 0..n-1.
    cluster2label = np.unique(ltrain)
    label2cluster = {li: ci for ci, li in enumerate(cluster2label)}
    ctrain = [label2cluster[l] for l in ltrain]
    ctest = [label2cluster[l] for l in ltest]

    opt = easydict.EasyDict({
        "lr_decay_rate": 0.2,
        "cosine": cosine,
        "lr_decay_epochs": list(lr_decay_epochs),
        "start_epoch": 0,
        "learning_rate": learning_rate,
        "epochs": epochs,
        "print_freq": 200,
        "batch_size": batch_size,
        "momentum": 0.9,
        "weight_decay": weight_decay,
    })
    if args is not None:
        for k, v in args.items():
            opt[k] = v

    best_acc = 0
    # criterion is passed through but unused; train/validate call
    # F.binary_cross_entropy_with_logits directly.
    criterion = torch.nn.CrossEntropyLoss().cuda()

    if classifier is None:
        classifier = LinearClassifier(ftrain.shape[1], num_classes=n_cls).cuda()
        if nonlinear:
            classifier = NonLinearClassifier(ftrain.shape[1], num_classes=n_cls).cuda()

    trainset = ArrayDataset(ftrain, labels=ctrain)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True)
    valset = ArrayDataset(ftest, labels=ctest)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=opt.batch_size, shuffle=False)

    optimizer = set_optimizer(opt, classifier)

    best_preds = None
    best_state = None

    # training routine
    for epoch in range(opt.start_epoch + 1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        loss_train, acc = train(train_loader, classifier, criterion, optimizer, epoch,
                                print_freq=opt.print_freq)

        # eval for one epoch
        loss, val_acc, preds, labels_out = validate(val_loader, classifier, criterion,
                                                    print_freq=opt.print_freq)
        if val_acc > best_acc:
            best_acc = val_acc
            best_preds = preds
            best_state = copy.deepcopy(classifier.state_dict())

    return best_acc, val_acc, (classifier, best_state, best_preds, preds, labels_out), loss_train

def save_model(model, acc, save_file):
    print('==> Saving...')
    torch.save({
        'acc': acc,
        'state_dict': model.state_dict(),
    }, save_file)
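
# Minimal usage sketch (illustrative only, not part of the original module):
# probes synthetic Gaussian features with binary labels. Real callers pass
# features extracted from a frozen encoder. Requires a CUDA device, since the
# classifiers above are hard-coded to .cuda().
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    feat_dim, n_train, n_test = 64, 1024, 256

    # Two synthetic classes separated by a mean shift.
    ftrain = rng.randn(n_train, feat_dim).astype(np.float32)
    ltrain = rng.randint(0, 2, size=n_train)
    ftrain[ltrain == 1] += 1.0
    ftest = rng.randn(n_test, feat_dim).astype(np.float32)
    ltest = rng.randint(0, 2, size=n_test)
    ftest[ltest == 1] += 1.0

    best_acc, last_acc, _, _ = get_linear_acc(
        ftrain, ltrain, ftest, ltest, n_cls=2,
        epochs=5, learning_rate=0.5, batch_size=128)
    print('best acc: {:.3f}, last acc: {:.3f}'.format(best_acc, last_acc))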