# BA-Chatbot/backend/evaluation/eval_retriever/custom_evaluation.py
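"""Custom retriever evaluation for the BA-Chatbot backend.

Re-implements Haystack-style retriever evaluation (recall@k, MRR, MAP) with
optional GPT-3.5 based re-ranking (ReRanker) or a SentenceTransformersRanker,
and dumps per-query retrieval/re-ranking IDs to debug.json.
"""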
import json
import logging
import sys
from time import perf_counter
from typing import Dict, List, Optional

from tqdm import tqdm
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes import SentenceTransformersRanker
from haystack.nodes.retriever import BaseRetriever
from haystack.schema import Document, MultiLabel

sys.path.append("../..")
from reranker import ReRanker

logger = logging.getLogger(__name__)

def eval(
    document_store: BaseDocumentStore,
    retriever: BaseRetriever,
    reRankerGPT: Optional[ReRanker] = None,
    rerankerPipeline: Optional[SentenceTransformersRanker] = None,
    label_index: str = "label",
    doc_index: str = "eval_document",
    label_origin: str = "gold-label",
    top_k: int = 10,
    open_domain: bool = False,
    return_preds: bool = False,
    headers: Optional[Dict[str, str]] = None,
) -> dict:
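    """Evaluate retriever performance on gold labels stored in the document store.

    In open-domain mode a retrieved document counts as a hit if it contains one of
    the gold answer strings; otherwise documents are matched strictly by gold
    document IDs. Retrieved results can optionally be re-ranked with the GPT-3.5
    based ReRanker or a SentenceTransformersRanker before scoring. Per-query
    retrieval and re-ranking IDs are written to debug.json.
    """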
    # Extract all questions for evaluation
    filters: Dict = {"origin": [label_origin]}
    debug = []
    time_taken = 0.0
    if document_store is None:
        raise ValueError(
            "This Retriever was not initialized with a Document Store. Provide one to the eval() method."
        )
    labels: List[MultiLabel] = document_store.get_all_labels_aggregated(
        index=label_index,
        filters=filters,
        open_domain=open_domain,
        drop_negative_labels=True,
        drop_no_answers=False,
        headers=headers,
    )
    correct_retrievals = 0
    summed_avg_precision = 0.0
    summed_reciprocal_rank = 0.0

    # Collect questions and corresponding answers/document_ids in a dict
    question_label_dict = {}
    for label in labels:
        # document_ids are empty if no_answer == True
        if not label.no_answer:
            id_question_tuple = (label.document_ids[0], label.query)
            if open_domain:
                # no_answer '' entries are not included here if there are other actual answers
                question_label_dict[id_question_tuple] = label.answers
            else:
                deduplicated_doc_ids = list({str(x) for x in label.document_ids})
                question_label_dict[id_question_tuple] = deduplicated_doc_ids

    predictions = []
    # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
    logger.info("Performing eval queries...")
    if open_domain:
        for (_, question), gold_answers in tqdm(question_label_dict.items()):
            tic = perf_counter()
            retrieved_docs = retriever.retrieve(query=question, headers=headers, index=doc_index, top_k=top_k)
            item = {"retrieved_ids": [doc.id for doc in retrieved_docs]}
            if reRankerGPT:
                reranked_docs = reRankerGPT.rerank_documents_with_gpt35(query=question, documents=retrieved_docs)
                print(reranked_docs)
                item["reranked_ids"] = [doc.id for doc in reranked_docs]
                item["isEqual"] = item["reranked_ids"] == item["retrieved_ids"]
                retrieved_docs = reRankerGPT.get_final_references(
                    reranked_documents=reranked_docs, retrieved_documents=retrieved_docs
                )
                item["final_reordered_ids"] = [doc.id for doc in retrieved_docs]
            if rerankerPipeline:
                retrieved_docs = rerankerPipeline.predict(query=question, documents=retrieved_docs)
            debug.append({question: item})
            toc = perf_counter()
            time_taken += toc - tic
            if return_preds:
                predictions.append({"question": question, "retrieved_docs": retrieved_docs})

            # check if correct doc in retrieved docs
            found_relevant_doc = False
            relevant_docs_found = 0
            current_avg_precision = 0.0
            print("GOLD ANSWERS: ", gold_answers)
            for doc_idx, doc in enumerate(retrieved_docs):
                for gold_answer in gold_answers:
                    if gold_answer in doc.content:
                        relevant_docs_found += 1
                        if not found_relevant_doc:
                            correct_retrievals += 1
                            summed_reciprocal_rank += 1 / (doc_idx + 1)
                        current_avg_precision += relevant_docs_found / (doc_idx + 1)
                        found_relevant_doc = True
                        break
            if found_relevant_doc:
                summed_avg_precision += current_avg_precision / relevant_docs_found
    # Option 2: Strict evaluation by document ids that are listed in the labels
    else:
        for (_, question), gold_ids in tqdm(question_label_dict.items()):
            tic = perf_counter()
            retrieved_docs = retriever.retrieve(query=question, headers=headers, index=doc_index, top_k=top_k)
            item = {"retrieved_ids": [doc.id for doc in retrieved_docs]}
            if reRankerGPT:
                reranked_docs = reRankerGPT.rerank_documents_with_gpt35(query=question, documents=retrieved_docs)
                print(reranked_docs)
                item["reranked_ids"] = [doc.id for doc in reranked_docs]
                item["isEqual"] = item["reranked_ids"] == item["retrieved_ids"]
                retrieved_docs = reRankerGPT.get_final_references(
                    reranked_documents=reranked_docs, retrieved_documents=retrieved_docs
                )
                item["final_reordered_ids"] = [doc.id for doc in retrieved_docs]
            debug.append({question: item})
            toc = perf_counter()
            time_taken += toc - tic
            if return_preds:
                predictions.append({"question": question, "retrieved_docs": retrieved_docs})

            # check if correct doc in retrieved docs
            found_relevant_doc = False
            relevant_docs_found = 0
            current_avg_precision = 0.0
            for doc_idx, doc in enumerate(retrieved_docs):
                for gold_id in gold_ids:
                    if str(doc.id) == gold_id:
                        relevant_docs_found += 1
                        if not found_relevant_doc:
                            correct_retrievals += 1
                            summed_reciprocal_rank += 1 / (doc_idx + 1)
                        current_avg_precision += relevant_docs_found / (doc_idx + 1)
                        found_relevant_doc = True
                        break
            if found_relevant_doc:
                all_relevant_docs = len(set(gold_ids))
                summed_avg_precision += current_avg_precision / all_relevant_docs
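    # Note on the metrics below: "recall" is the fraction of questions for which
    # at least one relevant document was retrieved in the top-k, "mrr" averages
    # the reciprocal rank of the first hit, and "map" averages the per-question
    # average precision accumulated above.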
    # Metrics
    number_of_questions = len(question_label_dict)
    recall = correct_retrievals / number_of_questions
    mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions
    mean_avg_precision = summed_avg_precision / number_of_questions

    logger.info(
        "For {} out of {} questions ({:.2%}), the answer was in the top-{} candidate passages selected by the retriever.".format(
            correct_retrievals, number_of_questions, recall, top_k
        )
    )

    metrics = {
        "recall": recall,
        "map": mean_avg_precision,
        "mrr": mean_reciprocal_rank,
        "retrieve_time": time_taken,
        "n_questions": number_of_questions,
        "top_k": top_k,
    }

    with open("debug.json", "w") as fp:
        json.dump(debug, fp, ensure_ascii=False)

    if return_preds:
        return {"metrics": metrics, "predictions": predictions}
    else:
        return metrics


def eval_llama(
    document_store: BaseDocumentStore,
    vector_store: BaseDocumentStore,
    retriever: BaseRetriever,
    reRanker: Optional[ReRanker] = None,
    label_index: str = "label",
    doc_index: str = "eval_document",
    label_origin: str = "gold-label",
    top_k: int = 10,
    open_domain: bool = False,
    return_preds: bool = False,
    headers: Optional[Dict[str, str]] = None,
) -> dict:
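    """Variant of eval() that additionally accepts a ``vector_store``.

    The vector store is currently not used in the evaluation itself; retrieval,
    optional GPT-3.5 re-ranking, and metric computation mirror eval().
    """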
    # Extract all questions for evaluation
    filters: Dict = {"origin": [label_origin]}
    debug = []
    time_taken = 0.0
    if document_store is None:
        raise ValueError(
            "This Retriever was not initialized with a Document Store. Provide one to the eval() method."
        )
    labels: List[MultiLabel] = document_store.get_all_labels_aggregated(
        index=label_index,
        filters=filters,
        open_domain=open_domain,
        drop_negative_labels=True,
        drop_no_answers=False,
        headers=headers,
    )
    correct_retrievals = 0
    summed_avg_precision = 0.0
    summed_reciprocal_rank = 0.0

    # Collect questions and corresponding answers/document_ids in a dict
    question_label_dict = {}
    for label in labels:
        # document_ids are empty if no_answer == True
        if not label.no_answer:
            id_question_tuple = (label.document_ids[0], label.query)
            if open_domain:
                # no_answer '' entries are not included here if there are other actual answers
                question_label_dict[id_question_tuple] = label.answers
            else:
                deduplicated_doc_ids = list({str(x) for x in label.document_ids})
                question_label_dict[id_question_tuple] = deduplicated_doc_ids

    predictions = []
    # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
    logger.info("Performing eval queries...")
    if open_domain:
        for (_, question), gold_answers in tqdm(question_label_dict.items()):
            tic = perf_counter()
            retrieved_docs = retriever.retrieve(query=question, headers=headers, index=doc_index, top_k=top_k)
            print("retrieved_docs: ", retrieved_docs)
            item = {"retrieved_ids": [doc.id for doc in retrieved_docs]}
            if reRanker:
                reranked_docs = reRanker.rerank_documents_with_gpt35(query=question, documents=retrieved_docs)
                print(reranked_docs)
                item["reranked_ids"] = [doc.id for doc in reranked_docs]
                item["isEqual"] = item["reranked_ids"] == item["retrieved_ids"]
                retrieved_docs = reRanker.get_final_references(
                    reranked_documents=reranked_docs, retrieved_documents=retrieved_docs
                )
                item["final_reordered_ids"] = [doc.id for doc in retrieved_docs]
            debug.append({question: item})
            toc = perf_counter()
            time_taken += toc - tic
            if return_preds:
                predictions.append({"question": question, "retrieved_docs": retrieved_docs})

            # check if correct doc in retrieved docs
            found_relevant_doc = False
            relevant_docs_found = 0
            current_avg_precision = 0.0
            print("GOLD ANSWERS: ", gold_answers)
            for doc_idx, doc in enumerate(retrieved_docs):
                for gold_answer in gold_answers:
                    if gold_answer in doc.content:
                        relevant_docs_found += 1
                        if not found_relevant_doc:
                            correct_retrievals += 1
                            summed_reciprocal_rank += 1 / (doc_idx + 1)
                        current_avg_precision += relevant_docs_found / (doc_idx + 1)
                        found_relevant_doc = True
                        break
            if found_relevant_doc:
                summed_avg_precision += current_avg_precision / relevant_docs_found
    # Option 2: Strict evaluation by document ids that are listed in the labels
    else:
        for (_, question), gold_ids in tqdm(question_label_dict.items()):
            tic = perf_counter()
            retrieved_docs = retriever.retrieve(query=question, headers=headers, index=doc_index, top_k=top_k)
            item = {"retrieved_ids": [doc.id for doc in retrieved_docs]}
            if reRanker:
                reranked_docs = reRanker.rerank_documents_with_gpt35(query=question, documents=retrieved_docs)
                print(reranked_docs)
                item["reranked_ids"] = [doc.id for doc in reranked_docs]
                item["isEqual"] = item["reranked_ids"] == item["retrieved_ids"]
                retrieved_docs = reRanker.get_final_references(
                    reranked_documents=reranked_docs, retrieved_documents=retrieved_docs
                )
                item["final_reordered_ids"] = [doc.id for doc in retrieved_docs]
            debug.append({question: item})
            toc = perf_counter()
            time_taken += toc - tic
            if return_preds:
                predictions.append({"question": question, "retrieved_docs": retrieved_docs})

            # check if correct doc in retrieved docs
            found_relevant_doc = False
            relevant_docs_found = 0
            current_avg_precision = 0.0
            for doc_idx, doc in enumerate(retrieved_docs):
                for gold_id in gold_ids:
                    if str(doc.id) == gold_id:
                        relevant_docs_found += 1
                        if not found_relevant_doc:
                            correct_retrievals += 1
                            summed_reciprocal_rank += 1 / (doc_idx + 1)
                        current_avg_precision += relevant_docs_found / (doc_idx + 1)
                        found_relevant_doc = True
                        break
            if found_relevant_doc:
                all_relevant_docs = len(set(gold_ids))
                summed_avg_precision += current_avg_precision / all_relevant_docs
    # Metrics
    number_of_questions = len(question_label_dict)
    recall = correct_retrievals / number_of_questions
    mean_reciprocal_rank = summed_reciprocal_rank / number_of_questions
    mean_avg_precision = summed_avg_precision / number_of_questions

    logger.info(
        "For {} out of {} questions ({:.2%}), the answer was in the top-{} candidate passages selected by the retriever.".format(
            correct_retrievals, number_of_questions, recall, top_k
        )
    )

    metrics = {
        "recall": recall,
        "map": mean_avg_precision,
        "mrr": mean_reciprocal_rank,
        "retrieve_time": time_taken,
        "n_questions": number_of_questions,
        "top_k": top_k,
    }

    with open("debug.json", "w") as fp:
        json.dump(debug, fp, ensure_ascii=False)

    if return_preds:
        return {"metrics": metrics, "predictions": predictions}
    else:
        return metrics
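

# Minimal usage sketch (assumption: a Haystack 1.x Elasticsearch setup with eval
# documents and gold labels already indexed; the host, index names, and embedding
# model below are illustrative placeholders, not values taken from this repo).
if __name__ == "__main__":
    from haystack.document_stores import ElasticsearchDocumentStore
    from haystack.nodes import EmbeddingRetriever

    eval_store = ElasticsearchDocumentStore(
        host="localhost",          # placeholder host
        index="eval_document",     # matches the default doc_index above
        label_index="label",       # matches the default label_index above
    )
    eval_retriever = EmbeddingRetriever(
        document_store=eval_store,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    )
    # Run the strict (document-id based) evaluation without re-ranking.
    results = eval(
        document_store=eval_store,
        retriever=eval_retriever,
        top_k=10,
    )
    print(json.dumps(results, indent=2, ensure_ascii=False))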