This commit is contained in:
parent
c37a5314a6
commit
ce7e5c711b
|
|
@ -457,9 +457,11 @@ def main():
|
||||||
if feat_loc == 3:
|
if feat_loc == 3:
|
||||||
embed_generated_wild = embed_generated[feat_indices_wild][:,1:,:]
|
embed_generated_wild = embed_generated[feat_indices_wild][:,1:,:]
|
||||||
embed_generated_eval = embed_generated[feat_indices_eval][:, 1:, :]
|
embed_generated_eval = embed_generated[feat_indices_eval][:, 1:, :]
|
||||||
|
embed_generated_hal,embed_generated_tru=embed_generated_h[feat_indices_wild][:,1:,:], embed_generated_t[feat_indices_wild][:,1:,:]
|
||||||
else:
|
else:
|
||||||
embed_generated_wild = embed_generated[feat_indices_wild]
|
embed_generated_wild = embed_generated[feat_indices_wild]
|
||||||
embed_generated_eval = embed_generated[feat_indices_eval]
|
embed_generated_eval = embed_generated[feat_indices_eval]
|
||||||
|
embed_generated_hal,embed_generated_tru=embed_generated_h[feat_indices_wild], embed_generated_t[feat_indices_wild]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -468,7 +470,7 @@ def main():
|
||||||
# returned_results = svd_embed_score(embed_generated_wild, gt_label_wild,
|
# returned_results = svd_embed_score(embed_generated_wild, gt_label_wild,
|
||||||
# 1, 11, mean=0, svd=0, weight=args.weighted_svd)
|
# 1, 11, mean=0, svd=0, weight=args.weighted_svd)
|
||||||
# get the best hyper-parameters on validation set
|
# get the best hyper-parameters on validation set
|
||||||
returned_results = svd_embed_score(embed_generated_eval, gt_label_val,
|
returned_results = svd_embed_score(embed_generated_eval, gt_label_val, embed_generated_hal,embed_generated_tru,
|
||||||
1, 11, mean=1, svd=1, wei1ght=args.weighted_svd)
|
1, 11, mean=1, svd=1, wei1ght=args.weighted_svd)
|
||||||
|
|
||||||
pca_model = PCA(n_components=returned_results['k'], whiten=False).fit(embed_generated_wild[:,returned_results['best_layer'],:])
|
pca_model = PCA(n_components=returned_results['k'], whiten=False).fit(embed_generated_wild[:,returned_results['best_layer'],:])
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
import pickle
|
||||||
# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
|
# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
|
||||||
from utils import get_hal_prompt, get_qa_prompt
|
from utils import get_hal_prompt, get_qa_prompt, get_truth_prompt
|
||||||
import llama_iti
|
import llama_iti
|
||||||
import pickle
|
import pickle
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -196,14 +196,23 @@ def main():
|
||||||
if not os.path.exists(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/hallucinations'):
|
if not os.path.exists(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/hallucinations'):
|
||||||
os.mkdir(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/hallucinations')
|
os.mkdir(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/hallucinations')
|
||||||
|
|
||||||
|
if not os.path.exists(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/truths'):
|
||||||
|
os.mkdir(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/truths')
|
||||||
|
|
||||||
|
|
||||||
for i in range(begin_index, end_index):
|
for i in range(begin_index, end_index):
|
||||||
answers = [None] * args.num_gene
|
answers = [None] * args.num_gene
|
||||||
hallucinations= [None] * args.num_gene
|
hallucinations= [None] * args.num_gene
|
||||||
|
truths = [None] * args.num_gene
|
||||||
if args.dataset_name == 'tydiqa':
|
if args.dataset_name == 'tydiqa':
|
||||||
question = dataset[int(used_indices[i])]['question']
|
question = dataset[int(used_indices[i])]['question']
|
||||||
prompt = get_qa_prompt(dataset[int(used_indices[i])]['context'],question)
|
prompt = get_qa_prompt(dataset[int(used_indices[i])]['context'],question)
|
||||||
hallucination_prompt=get_hal_prompt(dataset[int(used_indices[i])]['context'],question,instruction)
|
hallucination_prompt=get_hal_prompt(dataset[int(used_indices[i])]['context'],question,instruction)
|
||||||
|
truth_prompt=get_truth_prompt(dataset[int(used_indices[i])]['context'],question)
|
||||||
|
elif args.dataset_name == 'triviaqa':
|
||||||
|
prompt = get_qa_prompt("None",dataset[i]['prompt'])
|
||||||
|
hallucination_prompt=get_hal_prompt("None",dataset[i]['prompt'],instruction)
|
||||||
|
truth_prompt=get_truth_prompt("None",question)
|
||||||
elif args.dataset_name == 'coqa':
|
elif args.dataset_name == 'coqa':
|
||||||
prompt = get_qa_prompt("None",dataset[i]['prompt'])
|
prompt = get_qa_prompt("None",dataset[i]['prompt'])
|
||||||
hallucination_prompt=get_hal_prompt("None",dataset[i]['prompt'],instruction)
|
hallucination_prompt=get_hal_prompt("None",dataset[i]['prompt'],instruction)
|
||||||
|
|
@ -217,24 +226,33 @@ def main():
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model = args.model_name,
|
model = args.model_name,
|
||||||
messages = prompt,
|
messages = prompt,
|
||||||
max_tokens=256,
|
# max_tokens=256,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
temperature = 1,
|
temperature = 1,
|
||||||
)
|
)
|
||||||
hallucination_response = client.chat.completions.create(
|
hallucination_response = client.chat.completions.create(
|
||||||
model = args.model_name,
|
model = args.model_name,
|
||||||
messages = hallucination_prompt,
|
messages = hallucination_prompt,
|
||||||
max_tokens=256,
|
# max_tokens=256,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
temperature = 1,
|
temperature = 1,
|
||||||
)
|
)
|
||||||
|
if args.dataset_name == 'tydiqa' or args.dataset_name == 'tydiqa':
|
||||||
|
truth_response=client.chat.completions.create(
|
||||||
|
model = args.model_name,
|
||||||
|
messages = truth_prompt,
|
||||||
|
# max_tokens=256,
|
||||||
|
top_p=1,
|
||||||
|
temperature=1
|
||||||
|
)
|
||||||
|
truth_decoded=truth_response.choices[0].message.content
|
||||||
decoded=response.choices[0].message.content
|
decoded=response.choices[0].message.content
|
||||||
hallucination_decoded=hallucination_response.choices[0].message.content
|
hallucination_decoded=hallucination_response.choices[0].message.content
|
||||||
else:
|
else:
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model = args.model_name,
|
model = args.model_name,
|
||||||
messages = prompt,
|
messages = prompt,
|
||||||
max_tokens=256,
|
# max_tokens=256,
|
||||||
n=1,
|
n=1,
|
||||||
# best_of=1,
|
# best_of=1,
|
||||||
top_p=0.5,
|
top_p=0.5,
|
||||||
|
|
@ -250,6 +268,14 @@ def main():
|
||||||
top_p=0.5,
|
top_p=0.5,
|
||||||
temperature = 0.5,
|
temperature = 0.5,
|
||||||
)
|
)
|
||||||
|
if args.dataset_name == 'tydiqa' or args.dataset_name == 'tydiqa':
|
||||||
|
truth_response=client.chat.completions.create(
|
||||||
|
model = args.model_name,
|
||||||
|
messages = truth_prompt,
|
||||||
|
top_p=0.5,
|
||||||
|
temperature = 0.5,
|
||||||
|
)
|
||||||
|
truth_decoded=truth_response.choices[0].message.content
|
||||||
decoded=response.choices[0].message.content
|
decoded=response.choices[0].message.content
|
||||||
hallucination_decoded=hallucination_response.choices[0].message.content
|
hallucination_decoded=hallucination_response.choices[0].message.content
|
||||||
time.sleep(20)
|
time.sleep(20)
|
||||||
|
|
@ -270,7 +296,28 @@ def main():
|
||||||
hallucination_decoded = hallucination_decoded.split('Q:')[0]
|
hallucination_decoded = hallucination_decoded.split('Q:')[0]
|
||||||
answers[gen_iter] = decoded
|
answers[gen_iter] = decoded
|
||||||
hallucinations[gen_iter]=hallucination_decoded
|
hallucinations[gen_iter]=hallucination_decoded
|
||||||
|
if args.dataset_name == 'tydiqa' or args.dataset_name == 'tydiqa':
|
||||||
|
truths[gen_iter]=truth_decoded
|
||||||
|
|
||||||
|
|
||||||
|
# if args.dataset_name == 'tydiqa':
|
||||||
|
# pass
|
||||||
|
# elif args.dataset_name == 'triviaqa':
|
||||||
|
# pass
|
||||||
|
if args.dataset_name == 'coqa':
|
||||||
|
truths[0]=dataset[i]['answer']
|
||||||
|
if args.num_gene >1 and dataset[i]['additional_answers']>= args.num_gene-1:
|
||||||
|
left_truth=dataset[i]['additional_answers'][:args.num_gene-1]
|
||||||
|
truths=truths+left_truth
|
||||||
|
elif args.dataset_name == 'tqa':
|
||||||
|
truths[0]=dataset[i]['Best Answer']
|
||||||
|
if args.num_gene >1:
|
||||||
|
correct=dataset[i]['Correct Answers'].split(";")
|
||||||
|
if len(correct) >= args.num_gene-1:
|
||||||
|
left_truth=correct[:args.num_gene-1]
|
||||||
|
truths=truths+left_truth
|
||||||
|
else:
|
||||||
|
assert 'Not supported dataset!'
|
||||||
|
|
||||||
print('sample: ', i)
|
print('sample: ', i)
|
||||||
if args.most_likely:
|
if args.most_likely:
|
||||||
|
|
@ -283,6 +330,9 @@ def main():
|
||||||
print("Saving hallucinations")
|
print("Saving hallucinations")
|
||||||
np.save(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/hallucinations/' + info + f'hal_det_{args.model_name}_{args.dataset_name}_hallucinations_index_{i}.npy',
|
np.save(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/hallucinations/' + info + f'hal_det_{args.model_name}_{args.dataset_name}_hallucinations_index_{i}.npy',
|
||||||
hallucinations)
|
hallucinations)
|
||||||
|
print("Saving truths")
|
||||||
|
np.save(f'./save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/truths/' + info + f'hal_det_{args.model_name}_{args.dataset_name}_truths_index_{i}.npy',
|
||||||
|
truths)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
tokenizer = llama_iti.LlamaTokenizer.from_pretrained(MODEL, trust_remote_code=True)
|
tokenizer = llama_iti.LlamaTokenizer.from_pretrained(MODEL, trust_remote_code=True)
|
||||||
|
|
|
||||||
315
linear_probe.py
315
linear_probe.py
|
|
@ -1,3 +1,317 @@
|
||||||
|
# from __future__ import print_function
|
||||||
|
|
||||||
|
# import os
|
||||||
|
# import sys
|
||||||
|
# import argparse
|
||||||
|
# import time
|
||||||
|
# import math
|
||||||
|
|
||||||
|
# import easydict
|
||||||
|
# import torch
|
||||||
|
# import torch.backends.cudnn as cudnn
|
||||||
|
# import torch.optim as optim
|
||||||
|
# import torch.nn as nn
|
||||||
|
# import torch.nn.functional as F
|
||||||
|
# import numpy as np
|
||||||
|
# import copy
|
||||||
|
|
||||||
|
# from ylib.ytool import ArrayDataset
|
||||||
|
# cudnn.benchmark = True
|
||||||
|
|
||||||
|
# class LinearClassifier(nn.Module):
|
||||||
|
# """Linear classifier"""
|
||||||
|
# def __init__(self, feat_dim, num_classes=10):
|
||||||
|
# super(LinearClassifier, self).__init__()
|
||||||
|
# self.fc = nn.Linear(feat_dim, 1)
|
||||||
|
|
||||||
|
# def forward(self, features):
|
||||||
|
# return self.fc(features)
|
||||||
|
|
||||||
|
# class NonLinearClassifier(nn.Module):
|
||||||
|
# """Linear classifier"""
|
||||||
|
# def __init__(self, feat_dim, num_classes=10):
|
||||||
|
# super(NonLinearClassifier, self).__init__()
|
||||||
|
# self.fc1 = nn.Linear(feat_dim, 1024)
|
||||||
|
# # self.fc2 = nn.Linear(1024, 512)
|
||||||
|
# self.fc3 = nn.Linear(1024, 1)
|
||||||
|
|
||||||
|
# def forward(self, features):
|
||||||
|
# x = F.relu(self.fc1(features))
|
||||||
|
# # x = F.relu(self.fc2(x))
|
||||||
|
# x = self.fc3(x)
|
||||||
|
# return x
|
||||||
|
|
||||||
|
|
||||||
|
# class NormedLinear(nn.Module):
|
||||||
|
|
||||||
|
# def __init__(self, in_features, out_features, bn=False):
|
||||||
|
# super(NormedLinear, self).__init__()
|
||||||
|
# self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
|
||||||
|
# self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
|
||||||
|
# self.bn = bn
|
||||||
|
# if bn:
|
||||||
|
# self.bn_layer = nn.BatchNorm1d(out_features)
|
||||||
|
|
||||||
|
# def forward(self, x):
|
||||||
|
# out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
|
||||||
|
# if self.bn:
|
||||||
|
# out = self.bn_layer(out)
|
||||||
|
# return out
|
||||||
|
|
||||||
|
# class AverageMeter(object):
|
||||||
|
# """Computes and stores the average and current value"""
|
||||||
|
# def __init__(self):
|
||||||
|
# self.reset()
|
||||||
|
|
||||||
|
# def reset(self):
|
||||||
|
# self.val = 0
|
||||||
|
# self.avg = 0
|
||||||
|
# self.sum = 0
|
||||||
|
# self.count = 0
|
||||||
|
|
||||||
|
# def update(self, val, n=1):
|
||||||
|
# self.val = val
|
||||||
|
# self.sum += val * n
|
||||||
|
# self.count += n
|
||||||
|
# self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
|
# def accuracy(output, target, topk=(1,)):
|
||||||
|
# """Computes the accuracy over the k top predictions for the specified values of k"""
|
||||||
|
# with torch.no_grad():
|
||||||
|
# maxk = max(topk)
|
||||||
|
# batch_size = target.size(0)
|
||||||
|
|
||||||
|
# _, pred = output.topk(maxk, 1, True, True)
|
||||||
|
# pred = pred.t()
|
||||||
|
# correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||||
|
|
||||||
|
# res = []
|
||||||
|
# for k in topk:
|
||||||
|
# correct_k = correct[:k].flatten().float().sum(0, keepdim=True)
|
||||||
|
# res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
|
# return res
|
||||||
|
|
||||||
|
|
||||||
|
# def adjust_learning_rate(args, optimizer, epoch):
|
||||||
|
# lr = args.learning_rate
|
||||||
|
# if args.cosine:
|
||||||
|
# eta_min = lr * (args.lr_decay_rate ** 3)
|
||||||
|
# lr = eta_min + (lr - eta_min) * (
|
||||||
|
# 1 + math.cos(math.pi * epoch / args.epochs)) / 2
|
||||||
|
# else:
|
||||||
|
# steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
|
||||||
|
# if steps > 0:
|
||||||
|
# lr = lr * (args.lr_decay_rate ** steps)
|
||||||
|
|
||||||
|
# for param_group in optimizer.param_groups:
|
||||||
|
# param_group['lr'] = lr
|
||||||
|
|
||||||
|
|
||||||
|
# def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
|
||||||
|
# if args.warm and epoch <= args.warm_epochs:
|
||||||
|
# p = (batch_id + (epoch - 1) * total_batches) / \
|
||||||
|
# (args.warm_epochs * total_batches)
|
||||||
|
# lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
|
||||||
|
|
||||||
|
# for param_group in optimizer.param_groups:
|
||||||
|
# param_group['lr'] = lr
|
||||||
|
|
||||||
|
|
||||||
|
# def set_optimizer(opt, model):
|
||||||
|
# optimizer = optim.SGD(model.parameters(),
|
||||||
|
# lr=opt.learning_rate,
|
||||||
|
# momentum=opt.momentum,
|
||||||
|
# weight_decay=opt.weight_decay)
|
||||||
|
# return optimizer
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# import apex
|
||||||
|
# from apex import amp, optimizers
|
||||||
|
# except ImportError:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
# def train(train_loader, classifier, criterion, optimizer, epoch, print_freq=10):
|
||||||
|
# """one epoch training"""
|
||||||
|
# classifier.train()
|
||||||
|
|
||||||
|
# batch_time = AverageMeter()
|
||||||
|
# data_time = AverageMeter()
|
||||||
|
# losses = AverageMeter()
|
||||||
|
# top1 = AverageMeter()
|
||||||
|
|
||||||
|
# end = time.time()
|
||||||
|
# for idx, (features, labels) in enumerate(train_loader):
|
||||||
|
# data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
# features = features.cuda(non_blocking=True).float()
|
||||||
|
# labels = labels.cuda(non_blocking=True).long()
|
||||||
|
# bsz = labels.shape[0]
|
||||||
|
# optimizer.zero_grad()
|
||||||
|
# # warm-up learning rate
|
||||||
|
# # warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
|
||||||
|
|
||||||
|
# output = classifier(features)
|
||||||
|
# loss = F.binary_cross_entropy_with_logits(output.view(-1), labels.float())
|
||||||
|
|
||||||
|
# # update metric
|
||||||
|
# losses.update(loss.item(), bsz)
|
||||||
|
# # acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||||
|
# # breakpoint()
|
||||||
|
# correct = (torch.sigmoid(output) > 0.5).long().view(-1).eq(labels.view(-1))
|
||||||
|
|
||||||
|
# top1.update(correct.sum() / bsz, bsz)
|
||||||
|
|
||||||
|
# # SGD
|
||||||
|
|
||||||
|
# loss.backward()
|
||||||
|
# optimizer.step()
|
||||||
|
|
||||||
|
# # measure elapsed time
|
||||||
|
# batch_time.update(time.time() - end)
|
||||||
|
# end = time.time()
|
||||||
|
|
||||||
|
# #
|
||||||
|
# # # print info
|
||||||
|
# if (idx + 1) % print_freq == 0:
|
||||||
|
# print('Train: [{0}][{1}/{2}]\t'
|
||||||
|
# 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||||
|
# 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
|
||||||
|
# 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
|
||||||
|
# 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
|
||||||
|
# epoch, idx + 1, len(train_loader), batch_time=batch_time,
|
||||||
|
# data_time=data_time, loss=losses, top1=top1))
|
||||||
|
# sys.stdout.flush()
|
||||||
|
|
||||||
|
# return losses.avg, top1.avg
|
||||||
|
|
||||||
|
|
||||||
|
# def validate(val_loader, classifier, criterion, print_freq):
|
||||||
|
# """validation"""
|
||||||
|
# classifier.eval()
|
||||||
|
|
||||||
|
# batch_time = AverageMeter()
|
||||||
|
# losses = AverageMeter()
|
||||||
|
# top1 = AverageMeter()
|
||||||
|
|
||||||
|
# preds = np.array([])
|
||||||
|
|
||||||
|
# labels_out = np.array([])
|
||||||
|
# with torch.no_grad():
|
||||||
|
# end = time.time()
|
||||||
|
# for idx, (features, labels) in enumerate(val_loader):
|
||||||
|
# features = features.float().cuda()
|
||||||
|
# labels_out = np.append(labels_out, labels)
|
||||||
|
# labels = labels.long().cuda()
|
||||||
|
# bsz = labels.shape[0]
|
||||||
|
|
||||||
|
# # forward
|
||||||
|
# # output = classifier(model.encoder(images))
|
||||||
|
|
||||||
|
# output = classifier(features.detach())
|
||||||
|
# loss = F.binary_cross_entropy_with_logits(output.view(-1), labels.float())
|
||||||
|
# prob = torch.sigmoid(output)
|
||||||
|
# conf = prob
|
||||||
|
# pred = (prob>0.5).long().view(-1)
|
||||||
|
# # conf, pred = prob.max(1)
|
||||||
|
# preds = np.append(preds, conf.cpu().numpy())
|
||||||
|
|
||||||
|
# # update metric
|
||||||
|
# losses.update(loss.item(), bsz)
|
||||||
|
# correct = (torch.sigmoid(output) > 0.5).long().view(-1).eq(labels.view(-1))
|
||||||
|
# top1.update(correct.sum()/bsz, bsz)
|
||||||
|
|
||||||
|
# # measure elapsed time
|
||||||
|
# batch_time.update(time.time() - end)
|
||||||
|
# end = time.time()
|
||||||
|
|
||||||
|
# if (idx + 1) % 200 == 0:
|
||||||
|
# print('Test: [{0}/{1}]\t'
|
||||||
|
# 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||||
|
# 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||||
|
# 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
|
||||||
|
# idx, len(val_loader), batch_time=batch_time,
|
||||||
|
# loss=losses, top1=top1))
|
||||||
|
|
||||||
|
# # print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
|
||||||
|
# return losses.avg, top1.avg, preds, labels_out
|
||||||
|
|
||||||
|
# def get_linear_acc(ftrain, ltrain, ftest, ltest, n_cls, epochs=10,
|
||||||
|
# args=None, classifier=None,
|
||||||
|
# print_ret=True, normed=False, nonlinear=False,
|
||||||
|
# learning_rate=5,
|
||||||
|
# weight_decay=0,
|
||||||
|
# batch_size=512,
|
||||||
|
# cosine=False,
|
||||||
|
# lr_decay_epochs=[30,60,90]):
|
||||||
|
|
||||||
|
# cluster2label = np.unique(ltrain)
|
||||||
|
# label2cluster = {li: ci for ci, li in enumerate(cluster2label)}
|
||||||
|
# ctrain = [label2cluster[l] for l in ltrain]
|
||||||
|
# ctest = [label2cluster[l] for l in ltest]
|
||||||
|
# # breakpoint()
|
||||||
|
# opt = easydict.EasyDict({
|
||||||
|
# "lr_decay_rate": 0.2,
|
||||||
|
# "cosine": cosine,
|
||||||
|
# "lr_decay_epochs": lr_decay_epochs,
|
||||||
|
# "start_epoch": 0,
|
||||||
|
# "learning_rate": learning_rate,
|
||||||
|
# "epochs": epochs,
|
||||||
|
# "print_freq": 200,
|
||||||
|
# "batch_size": batch_size,
|
||||||
|
# "momentum": 0.9,
|
||||||
|
# "weight_decay": weight_decay,
|
||||||
|
# })
|
||||||
|
# if args is not None:
|
||||||
|
# for k, v in args.items():
|
||||||
|
# opt[k] = v
|
||||||
|
|
||||||
|
# best_acc = 0
|
||||||
|
|
||||||
|
# criterion = torch.nn.CrossEntropyLoss().cuda()
|
||||||
|
# if classifier is None:
|
||||||
|
# classifier = LinearClassifier(ftrain.shape[1], num_classes=n_cls).cuda()
|
||||||
|
# if nonlinear:
|
||||||
|
# classifier = NonLinearClassifier(ftrain.shape[1], num_classes=n_cls).cuda()
|
||||||
|
|
||||||
|
# trainset = ArrayDataset(ftrain, labels=ctrain)
|
||||||
|
# train_loader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# valset = ArrayDataset(ftest, labels=ctest)
|
||||||
|
# val_loader = torch.utils.data.DataLoader(valset, batch_size=opt.batch_size, shuffle=False)
|
||||||
|
|
||||||
|
# optimizer = set_optimizer(opt, classifier)
|
||||||
|
|
||||||
|
# best_preds = None
|
||||||
|
# best_state = None
|
||||||
|
# # training routine
|
||||||
|
# for epoch in range(opt.start_epoch + 1, opt.epochs + 1):
|
||||||
|
# adjust_learning_rate(opt, optimizer, epoch)
|
||||||
|
|
||||||
|
# # train for one epoch
|
||||||
|
# loss_train, acc = train(train_loader, classifier, criterion, optimizer, epoch, print_freq=opt.print_freq)
|
||||||
|
|
||||||
|
# # eval for one epoch
|
||||||
|
# loss, val_acc, preds, labels_out = validate(val_loader, classifier, criterion, print_freq=opt.print_freq)
|
||||||
|
# if val_acc > best_acc:
|
||||||
|
# best_acc = val_acc
|
||||||
|
# best_preds = preds
|
||||||
|
# best_state = copy.deepcopy(classifier.state_dict())
|
||||||
|
|
||||||
|
# return best_acc.item(), val_acc.item(), (classifier, best_state, best_preds, preds, labels_out), loss_train
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def save_model(model, acc, save_file):
|
||||||
|
# print('==> Saving...')
|
||||||
|
# torch.save({
|
||||||
|
# 'acc': acc,
|
||||||
|
# 'state_dict': model.state_dict(),
|
||||||
|
# }, save_file)
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -311,4 +625,3 @@ def save_model(model, acc, save_file):
|
||||||
'acc': acc,
|
'acc': acc,
|
||||||
'state_dict': model.state_dict(),
|
'state_dict': model.state_dict(),
|
||||||
}, save_file)
|
}, save_file)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue