This commit is contained in:
weixin_43297441 2025-02-07 10:23:14 +08:00
parent 48ef1a071c
commit 43babf3e20
2 changed files with 6 additions and 5 deletions

View File

@ -1,4 +1,4 @@
name: base
name: haloscope
channels:
- conda-forge
- defaults

View File

@ -4,10 +4,11 @@ import torch.nn.functional as F
import evaluate
from datasets import load_metric
from datasets import load_dataset
import datasets
from tqdm import tqdm
import numpy as np
import pickle
from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
import llama_iti
import pickle
import argparse
@ -49,7 +50,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='llama2_chat_7B')
parser.add_argument('--dataset_name', type=str, default='tqa')
parser.add_argument('--dataset_name', type=str, default='triviaqa')
parser.add_argument('--num_gene', type=int, default=1)
parser.add_argument('--gene', type=int, default=0)
parser.add_argument('--generate_gt', type=int, default=0)
@ -240,8 +241,8 @@ def main():
elif args.generate_gt:
from bleurt_pytorch import BleurtConfig, BleurtForSequenceClassification, BleurtTokenizer
model = BleurtForSequenceClassification.from_pretrained('./models/BLEURT-20').cuda()
tokenizer = BleurtTokenizer.from_pretrained('./models/BLEURT-20')
model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20').cuda()
tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20')
model.eval()
rouge = evaluate.load('rouge')