diff --git a/env.yml b/env.yml index e924686..f51cd59 100644 --- a/env.yml +++ b/env.yml @@ -1,4 +1,4 @@ -name: base +name: haloscope channels: - conda-forge - defaults diff --git a/hal_det_llama.py b/hal_det_llama.py index adcb40e..d80bb16 100644 --- a/hal_det_llama.py +++ b/hal_det_llama.py @@ -4,10 +4,11 @@ import torch.nn.functional as F import evaluate from datasets import load_metric from datasets import load_dataset +import datasets from tqdm import tqdm import numpy as np import pickle -from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q +# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q import llama_iti import pickle import argparse @@ -49,7 +50,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--model_name', type=str, default='llama2_chat_7B') - parser.add_argument('--dataset_name', type=str, default='tqa') + parser.add_argument('--dataset_name', type=str, default='triviaqa') parser.add_argument('--num_gene', type=int, default=1) parser.add_argument('--gene', type=int, default=0) parser.add_argument('--generate_gt', type=int, default=0) @@ -240,8 +241,8 @@ def main(): elif args.generate_gt: from bleurt_pytorch import BleurtConfig, BleurtForSequenceClassification, BleurtTokenizer - model = BleurtForSequenceClassification.from_pretrained('./models/BLEURT-20').cuda() - tokenizer = BleurtTokenizer.from_pretrained('./models/BLEURT-20') + model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20').cuda() + tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20') model.eval() rouge = evaluate.load('rouge')