317 lines
14 KiB
Python
317 lines
14 KiB
Python
from typing import Dict, Optional, List
|
|
from haystack.document_stores.base import BaseDocumentStore
|
|
from haystack.schema import Document, MultiLabel
|
|
from haystack.nodes.retriever import BaseRetriever
|
|
import logging
|
|
from time import perf_counter
|
|
from tqdm import tqdm
|
|
import sys
|
|
import json
|
|
from haystack.nodes import (
|
|
SentenceTransformersRanker,
|
|
|
|
)
|
|
sys.path.append("../..")
|
|
from reranker import ReRanker
|
|
# Module-level logger used by the evaluation functions below for progress and metric reporting.
logger = logging.getLogger(__name__)
|
|
|
|
def eval(
    document_store: BaseDocumentStore,
    retriever: BaseRetriever,
    reRankerGPT: Optional[ReRanker] = None,
    rerankerPipeline: Optional[SentenceTransformersRanker] = None,
    label_index: str = "label",
    doc_index: str = "eval_document",
    label_origin: str = "gold-label",
    top_k: int = 10,
    open_domain: bool = False,
    return_preds: bool = False,
    headers: Optional[Dict[str, str]] = None,
    debug_path: Optional[str] = "debug.json",
) -> dict:
    """Evaluate ``retriever`` (optionally followed by re-rankers) against gold labels.

    Aggregated labels whose ``origin`` is ``label_origin`` are read from
    ``label_index``; each question is sent to ``retriever.retrieve`` against
    ``doc_index`` with ``top_k``. The retrieved list is optionally re-ordered by
    ``reRankerGPT`` and afterwards by ``rerankerPipeline`` (in both evaluation
    modes), and the final list is scored for recall, MAP and MRR.

    A retrieved document counts as relevant when
      * ``open_domain=True``: any gold answer string occurs in ``doc.content``;
      * ``open_domain=False``: ``str(doc.id)`` is one of the label's document ids.

    Args:
        document_store: Store holding the labels; must not be ``None``.
        retriever: Retriever under evaluation.
        reRankerGPT: Optional GPT re-ranker. Its ``rerank_documents_with_gpt35``
            output is merged back into the retrieved list via
            ``get_final_references``.
        rerankerPipeline: Optional ranker whose ``predict`` output replaces the
            document list before scoring.
        label_index: Index containing the labels.
        doc_index: Index the retriever searches.
        label_origin: Only labels with this ``origin`` are evaluated.
        top_k: Number of documents requested from the retriever.
        open_domain: Select answer-string (True) or document-id (False) matching.
        return_preds: Also return the final document list for every question.
        headers: Passed through to the document store and the retriever.
        debug_path: File that receives the per-question re-ranking debug
            records as UTF-8 JSON. Defaults to ``"debug.json"`` in the current
            working directory; pass ``None`` or ``""`` to skip writing it.

    Returns:
        A metrics dict with ``recall``, ``map``, ``mrr``, ``retrieve_time``
        (seconds, including re-ranking), ``n_questions`` and ``top_k``; or
        ``{"metrics": ..., "predictions": ...}`` when ``return_preds`` is True.
        Rate metrics are 0.0 when there are no answerable labels.

    Raises:
        ValueError: If ``document_store`` is None.
    """
    if document_store is None:
        raise ValueError(
            "This Retriever was not initialized with a Document Store. Provide one to the eval() method."
        )

    # Extract all questions for evaluation
    filters: Dict = {"origin": [label_origin]}
    labels: List[MultiLabel] = document_store.get_all_labels_aggregated(
        index=label_index,
        filters=filters,
        open_domain=open_domain,
        drop_negative_labels=True,
        drop_no_answers=False,
        headers=headers,
    )

    # (first gold document id, question) -> gold answer strings (open domain)
    # or deduplicated gold document ids (closed domain).
    question_label_dict = {}
    for label in labels:
        # document_ids are empty if no_answer == True
        if label.no_answer:
            continue
        id_question_tuple = (label.document_ids[0], label.query)
        if open_domain:
            # here are no no_answer '' included if there are other actual answers
            question_label_dict[id_question_tuple] = label.answers
        else:
            question_label_dict[id_question_tuple] = list({str(x) for x in label.document_ids})

    correct_retrievals = 0
    summed_avg_precision = 0.0
    summed_reciprocal_rank = 0.0
    time_taken = 0.0
    predictions = []
    debug = []

    logger.info("Performing eval queries...")
    for (_, question), gold in tqdm(question_label_dict.items()):
        tic = perf_counter()
        retrieved_docs = retriever.retrieve(query=question, headers=headers, index=doc_index, top_k=top_k)
        item = {"retrieved_ids": [doc.id for doc in retrieved_docs]}
        if reRankerGPT is not None:
            reranked_docs = reRankerGPT.rerank_documents_with_gpt35(query=question, documents=retrieved_docs)
            logger.debug("GPT re-ranked documents: %s", reranked_docs)
            item["reranked_ids"] = [doc.id for doc in reranked_docs]
            item["isEqual"] = item["reranked_ids"] == item["retrieved_ids"]
            retrieved_docs = reRankerGPT.get_final_references(
                reranked_documents=reranked_docs, retrieved_documents=retrieved_docs
            )
            item["final_reorderd_ids"] = [doc.id for doc in retrieved_docs]
        # Applied in both evaluation modes so the scores reflect the same
        # pipeline; skipped when there is nothing to re-rank.
        if rerankerPipeline is not None and retrieved_docs:
            retrieved_docs = rerankerPipeline.predict(query=question, documents=retrieved_docs)
        debug.append({question: item})
        time_taken += perf_counter() - tic

        if return_preds:
            predictions.append({"question": question, "retrieved_docs": retrieved_docs})

        logger.debug("Gold labels for %r: %s", question, gold)
        # Score the final ranking; a document counts at most once, however many
        # gold answers / ids it matches.
        relevant_docs_found = 0
        current_avg_precision = 0.0
        for rank, doc in enumerate(retrieved_docs, start=1):
            if open_domain:
                is_relevant = any(gold_answer in doc.content for gold_answer in gold)
            else:
                is_relevant = str(doc.id) in gold
            if not is_relevant:
                continue
            relevant_docs_found += 1
            if relevant_docs_found == 1:
                correct_retrievals += 1
                summed_reciprocal_rank += 1 / rank
            current_avg_precision += relevant_docs_found / rank

        if relevant_docs_found:
            # Open domain: the number of relevant docs in the corpus is unknown,
            # so normalise by those found; closed domain: by all gold ids.
            denominator = relevant_docs_found if open_domain else len(set(gold))
            summed_avg_precision += current_avg_precision / denominator

    # Metrics
    number_of_questions = len(question_label_dict)
    if number_of_questions:
        recall = correct_retrievals / number_of_questions
        mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions
        mean_avg_precision = summed_avg_precision / number_of_questions
    else:
        # Avoid ZeroDivisionError when no answerable labels were found.
        logger.warning(
            "No answerable labels with origin %r found in index %r; all metrics are 0.",
            label_origin,
            label_index,
        )
        recall = mean_reciprocal_rank = mean_avg_precision = 0.0

    logger.info(
        "For %s out of %s questions (%.2f%%), the answer was in the top-%s candidate passages selected by the retriever.",
        correct_retrievals,
        number_of_questions,
        recall * 100,
        top_k,
    )

    metrics = {
        "recall": recall,
        "map": mean_avg_precision,
        "mrr": mean_reciprocal_rank,
        "retrieve_time": time_taken,
        "n_questions": number_of_questions,
        "top_k": top_k,
    }

    if debug_path:
        # Explicit UTF-8: with ensure_ascii=False non-ASCII questions are written
        # verbatim, which fails if the locale's default encoding is not UTF-8.
        with open(debug_path, "w", encoding="utf-8") as fp:
            json.dump(debug, fp, ensure_ascii=False)

    if return_preds:
        return {"metrics": metrics, "predictions": predictions}
    return metrics
|
|
|
|
|
|
def eval_llama(
    document_store: BaseDocumentStore,
    vector_store: BaseDocumentStore,
    retriever: BaseRetriever,
    reRanker: Optional[ReRanker] = None,
    label_index: str = "label",
    doc_index: str = "eval_document",
    label_origin: str = "gold-label",
    top_k: int = 10,
    open_domain: bool = False,
    return_preds: bool = False,
    headers: Optional[Dict[str, str]] = None,
    debug_path: Optional[str] = "debug.json",
) -> dict:
    """Evaluate ``retriever`` (optionally followed by a GPT re-ranker) against gold labels.

    Aggregated labels whose ``origin`` is ``label_origin`` are read from
    ``label_index`` of ``document_store``; each question is sent to
    ``retriever.retrieve`` against ``doc_index`` with ``top_k``, optionally
    re-ordered by ``reRanker``, and the final list is scored for recall, MAP
    and MRR.

    A retrieved document counts as relevant when
      * ``open_domain=True``: any gold answer string occurs in ``doc.content``;
      * ``open_domain=False``: ``str(doc.id)`` is one of the label's document ids.

    Args:
        document_store: Store holding the labels; must not be ``None``.
        vector_store: Accepted for interface compatibility but not used by this
            function.
        retriever: Retriever under evaluation.
        reRanker: Optional GPT re-ranker. Its ``rerank_documents_with_gpt35``
            output is merged back into the retrieved list via
            ``get_final_references``.
        label_index: Index containing the labels.
        doc_index: Index the retriever searches.
        label_origin: Only labels with this ``origin`` are evaluated.
        top_k: Number of documents requested from the retriever.
        open_domain: Select answer-string (True) or document-id (False) matching.
        return_preds: Also return the final document list for every question.
        headers: Passed through to the document store and the retriever.
        debug_path: File that receives the per-question re-ranking debug
            records as UTF-8 JSON. Defaults to ``"debug.json"`` in the current
            working directory; pass ``None`` or ``""`` to skip writing it.

    Returns:
        A metrics dict with ``recall``, ``map``, ``mrr``, ``retrieve_time``
        (seconds, including re-ranking), ``n_questions`` and ``top_k``; or
        ``{"metrics": ..., "predictions": ...}`` when ``return_preds`` is True.
        Rate metrics are 0.0 when there are no answerable labels.

    Raises:
        ValueError: If ``document_store`` is None.
    """
    if document_store is None:
        raise ValueError(
            "This Retriever was not initialized with a Document Store. Provide one to the eval() method."
        )

    # Extract all questions for evaluation
    filters: Dict = {"origin": [label_origin]}
    labels: List[MultiLabel] = document_store.get_all_labels_aggregated(
        index=label_index,
        filters=filters,
        open_domain=open_domain,
        drop_negative_labels=True,
        drop_no_answers=False,
        headers=headers,
    )

    # (first gold document id, question) -> gold answer strings (open domain)
    # or deduplicated gold document ids (closed domain).
    question_label_dict = {}
    for label in labels:
        # document_ids are empty if no_answer == True
        if label.no_answer:
            continue
        id_question_tuple = (label.document_ids[0], label.query)
        if open_domain:
            # here are no no_answer '' included if there are other actual answers
            question_label_dict[id_question_tuple] = label.answers
        else:
            question_label_dict[id_question_tuple] = list({str(x) for x in label.document_ids})

    correct_retrievals = 0
    summed_avg_precision = 0.0
    summed_reciprocal_rank = 0.0
    time_taken = 0.0
    predictions = []
    debug = []

    logger.info("Performing eval queries...")
    for (_, question), gold in tqdm(question_label_dict.items()):
        tic = perf_counter()
        retrieved_docs = retriever.retrieve(query=question, headers=headers, index=doc_index, top_k=top_k)
        logger.debug("Retrieved documents: %s", retrieved_docs)
        item = {"retrieved_ids": [doc.id for doc in retrieved_docs]}
        if reRanker is not None:
            reranked_docs = reRanker.rerank_documents_with_gpt35(query=question, documents=retrieved_docs)
            logger.debug("GPT re-ranked documents: %s", reranked_docs)
            item["reranked_ids"] = [doc.id for doc in reranked_docs]
            item["isEqual"] = item["reranked_ids"] == item["retrieved_ids"]
            retrieved_docs = reRanker.get_final_references(
                reranked_documents=reranked_docs, retrieved_documents=retrieved_docs
            )
            item["final_reorderd_ids"] = [doc.id for doc in retrieved_docs]
        debug.append({question: item})
        time_taken += perf_counter() - tic

        if return_preds:
            predictions.append({"question": question, "retrieved_docs": retrieved_docs})

        logger.debug("Gold labels for %r: %s", question, gold)
        # Score the final ranking; a document counts at most once, however many
        # gold answers / ids it matches.
        relevant_docs_found = 0
        current_avg_precision = 0.0
        for rank, doc in enumerate(retrieved_docs, start=1):
            if open_domain:
                is_relevant = any(gold_answer in doc.content for gold_answer in gold)
            else:
                is_relevant = str(doc.id) in gold
            if not is_relevant:
                continue
            relevant_docs_found += 1
            if relevant_docs_found == 1:
                correct_retrievals += 1
                summed_reciprocal_rank += 1 / rank
            current_avg_precision += relevant_docs_found / rank

        if relevant_docs_found:
            # Open domain: the number of relevant docs in the corpus is unknown,
            # so normalise by those found; closed domain: by all gold ids.
            denominator = relevant_docs_found if open_domain else len(set(gold))
            summed_avg_precision += current_avg_precision / denominator

    # Metrics
    number_of_questions = len(question_label_dict)
    if number_of_questions:
        recall = correct_retrievals / number_of_questions
        mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions
        mean_avg_precision = summed_avg_precision / number_of_questions
    else:
        # Avoid ZeroDivisionError when no answerable labels were found.
        logger.warning(
            "No answerable labels with origin %r found in index %r; all metrics are 0.",
            label_origin,
            label_index,
        )
        recall = mean_reciprocal_rank = mean_avg_precision = 0.0

    logger.info(
        "For %s out of %s questions (%.2f%%), the answer was in the top-%s candidate passages selected by the retriever.",
        correct_retrievals,
        number_of_questions,
        recall * 100,
        top_k,
    )

    metrics = {
        "recall": recall,
        "map": mean_avg_precision,
        "mrr": mean_reciprocal_rank,
        "retrieve_time": time_taken,
        "n_questions": number_of_questions,
        "top_k": top_k,
    }

    if debug_path:
        # Explicit UTF-8: with ensure_ascii=False non-ASCII questions are written
        # verbatim, which fails if the locale's default encoding is not UTF-8.
        with open(debug_path, "w", encoding="utf-8") as fp:
            json.dump(debug, fp, ensure_ascii=False)

    if return_preds:
        return {"metrics": metrics, "predictions": predictions}
    return metrics
|
|
|