130 lines
4.3 KiB
Python
import random
|
|
import openai
|
|
import time
|
|
import json
|
|
import argparse
|
|
from openai import OpenAI
|
|
from openai import OpenAIError
|
|
|
|
|
|
# OpenAI-compatible client pointed at the DeepSeek endpoint.
# NOTE(review): the original (Chinese) comment said to paste a MOONSHOT_API_KEY
# obtained from the Kimi open platform, but base_url targets DeepSeek --
# confirm which provider is actually intended.
client = OpenAI(
    api_key="sk-",  # placeholder -- replace with your real API key
    base_url="https://api.deepseek.com/v1",
)
|
|
|
|
def get_qa_response(model, question, answer):
    """Ask a judge model whether `answer` to `question` is factual.

    Parameters
    ----------
    model : str
        ``"gpt-3.5-turbo"`` selects that model; any other value falls back
        to ``"deepseek-chat"`` (same routing as the original code).
    question : str
        The question being judged.
    answer : str
        The candidate answer to judge.

    Returns
    -------
    str
        The raw model reply; the caller checks it for the substring "YES".
    """
    # System prompt kept verbatim (Chinese): instructs the model to act as a
    # hallucination detector and reply "YES"/"NO" with a reason.
    message = [
        {"role": "system", "content": "你是一个幻觉检测器。你必须根据世界知识确定问题的答案是否符合事实。你提供的答案必须是 \"YES\" or \"NO\" 并且给出你的理由"},
        {"role": "user", "content":
            "\n\n#Question#: " + question +
            "\n#Answer#: " + answer +
            "\n#Your Judgement#: "},
    ]
    # FIX: the original mixed the pre-1.0 `openai.ChatCompletion.create` API
    # with the >=1.0 `OpenAI` client, and its first `except OpenAIError`
    # clause (the base class of every openai error) shadowed all later
    # `except openai.error.*` clauses -- which themselves reference a module
    # attribute removed in openai>=1.0.  Both models now go through the one
    # v1 client, with v1 exception classes.
    model_name = "gpt-3.5-turbo" if model == "gpt-3.5-turbo" else "deepseek-chat"
    while True:  # retry until a response is obtained
        try:
            res = client.chat.completions.create(
                model=model_name,
                messages=message,
                temperature=0.0,  # deterministic judging
                stream=False,
            )
            return res.choices[0].message.content
        except openai.RateLimitError:
            # Rate limits get the longer back-off, as in the original.
            print('openai.RateLimitError\nRetrying...')
            time.sleep(60)
        except OpenAIError as err:
            # Covers APIError, APIConnectionError, APITimeoutError, etc.
            print('{}\nRetrying...'.format(type(err).__name__))
            time.sleep(20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluation_qa_dataset(model, file, output_path):
    """Judge every (Question, Answer) pair in a JSON file and dump predictions.

    Parameters
    ----------
    model : str
        Judge model name, forwarded to ``get_qa_response``.
    file : str
        Path to a JSON file containing a list of dicts with "Question" and
        "Answer" keys.
    output_path : str
        Destination path, written via ``dump_jsonl`` (overwrite mode).
    """
    with open(file, 'r', encoding="utf-8") as f:
        # FIX: the original copied the loaded list element-by-element with
        # `for i in range(len(...)): data.append(...)` -- the parsed list is
        # used directly instead.
        data = json.load(f)

    result = []
    for i, sample in enumerate(data):
        question = sample["Question"]
        answer = sample["Answer"]

        output_samples = get_qa_response(model, question, answer)
        print('sample {} success......'.format(i))
        result.append({
            "Question": question,
            "Answer": answer,
            # Any reply containing "YES" counts as a YES judgement.
            "Prediction": "YES" if "YES" in output_samples else "NO",
        })

    dump_jsonl(result, output_path, append=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dump_jsonl(data, output_path, append=False):
    """Write a list of objects to a JSON Lines file, one object per line.

    FIX: the original serialized the *entire* list as a single JSON array on
    one line, contradicting both the function name and this docstring; each
    record now gets its own line, which is what JSONL consumers expect.

    Parameters
    ----------
    data : list
        JSON-serializable objects to write.
    output_path : str
        Destination file path.
    append : bool, optional
        If True, append to the file instead of overwriting (default False).
    """
    mode = 'a+' if append else 'w'
    with open(output_path, mode, encoding='utf-8') as f:
        for record in data:
            # ensure_ascii=False keeps CJK text readable in the output file.
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
|
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: pick a task and a judge model, then run the evaluation.
    cli = argparse.ArgumentParser(description="Hallucination Generation")
    cli.add_argument("--task", default="qa", help="qa, dialogue, or summarization")
    cli.add_argument("--model", default="deepseek", help="model name")
    args = cli.parse_args()

    model = args.model
    # Results land in a per-task/per-model file, e.g. qa/deepseek_results.json.
    output_path = "{}/{}_results.json".format(args.task, args.model)

    # NOTE(review): the input path is hard-coded to one machine -- consider
    # promoting it to a CLI argument.
    data = "/home/leewlving/PycharmProjects/xianxing_cup3/factuality_predict.json"

    if args.task != "qa":
        raise ValueError("The task must be qa, dialogue, or summarization!")
    evaluation_qa_dataset(model, data, output_path)
|