add
This commit is contained in:
parent
48ef1a071c
commit
43babf3e20
|
|
@ -4,10 +4,11 @@ import torch.nn.functional as F
|
|||
import evaluate
|
||||
from datasets import load_metric
|
||||
from datasets import load_dataset
|
||||
import datasets
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import pickle
|
||||
from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
|
||||
# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
|
||||
import llama_iti
|
||||
import pickle
|
||||
import argparse
|
||||
|
|
@ -49,7 +50,7 @@ def main():
|
|||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str, default='llama2_chat_7B')
|
||||
parser.add_argument('--dataset_name', type=str, default='tqa')
|
||||
parser.add_argument('--dataset_name', type=str, default='triviaqa')
|
||||
parser.add_argument('--num_gene', type=int, default=1)
|
||||
parser.add_argument('--gene', type=int, default=0)
|
||||
parser.add_argument('--generate_gt', type=int, default=0)
|
||||
|
|
@ -240,8 +241,8 @@ def main():
|
|||
elif args.generate_gt:
|
||||
from bleurt_pytorch import BleurtConfig, BleurtForSequenceClassification, BleurtTokenizer
|
||||
|
||||
model = BleurtForSequenceClassification.from_pretrained('./models/BLEURT-20').cuda()
|
||||
tokenizer = BleurtTokenizer.from_pretrained('./models/BLEURT-20')
|
||||
model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20').cuda()
|
||||
tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20')
|
||||
model.eval()
|
||||
|
||||
rouge = evaluate.load('rouge')
|
||||
|
|
|
|||
Loading…
Reference in New Issue