add
This commit is contained in:
parent
48ef1a071c
commit
43babf3e20
2
env.yml
2
env.yml
|
|
@ -1,4 +1,4 @@
|
||||||
name: base
|
name: haloscope
|
||||||
channels:
|
channels:
|
||||||
- conda-forge
|
- conda-forge
|
||||||
- defaults
|
- defaults
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,11 @@ import torch.nn.functional as F
|
||||||
import evaluate
|
import evaluate
|
||||||
from datasets import load_metric
|
from datasets import load_metric
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
import datasets
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
import pickle
|
||||||
from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
|
# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
|
||||||
import llama_iti
|
import llama_iti
|
||||||
import pickle
|
import pickle
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -49,7 +50,7 @@ def main():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_name', type=str, default='llama2_chat_7B')
|
parser.add_argument('--model_name', type=str, default='llama2_chat_7B')
|
||||||
parser.add_argument('--dataset_name', type=str, default='tqa')
|
parser.add_argument('--dataset_name', type=str, default='triviaqa')
|
||||||
parser.add_argument('--num_gene', type=int, default=1)
|
parser.add_argument('--num_gene', type=int, default=1)
|
||||||
parser.add_argument('--gene', type=int, default=0)
|
parser.add_argument('--gene', type=int, default=0)
|
||||||
parser.add_argument('--generate_gt', type=int, default=0)
|
parser.add_argument('--generate_gt', type=int, default=0)
|
||||||
|
|
@ -240,8 +241,8 @@ def main():
|
||||||
elif args.generate_gt:
|
elif args.generate_gt:
|
||||||
from bleurt_pytorch import BleurtConfig, BleurtForSequenceClassification, BleurtTokenizer
|
from bleurt_pytorch import BleurtConfig, BleurtForSequenceClassification, BleurtTokenizer
|
||||||
|
|
||||||
model = BleurtForSequenceClassification.from_pretrained('./models/BLEURT-20').cuda()
|
model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20').cuda()
|
||||||
tokenizer = BleurtTokenizer.from_pretrained('./models/BLEURT-20')
|
tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20')
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rouge = evaluate.load('rouge')
|
rouge = evaluate.load('rouge')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue