# Listing metadata: 37 lines, 1.6 KiB, Python.
# NOTE(review): this import shadows the builtin ``eval`` for the rest of the
# module; the evaluation call further down relies on the name ``eval``.
from custom_evaluation import eval

# Names of the document-store indices holding the evaluation documents and
# the gold labels, respectively.
doc_index = "stupo_eval_docs_distilbert"
label_index = "stupo_eval_labels_distilbert"

from haystack.nodes import PreProcessor

import sys
from pathlib import Path

# Put the directory two levels above this file on the import path so that
# ``retriever`` can be imported (presumably the repository root). The
# original appended "../.." relative to the current working directory, which
# only worked when the script was launched from its own folder; anchoring to
# this file's location gives the same directory regardless of where it runs.
sys.path.append(str(Path(__file__).resolve().parents[2]))

from retriever.retriever_pipeline import CustomPipeline

# Project pipeline bound to the evaluation indices; the rest of the script
# uses its ``doc_store_distilbert``, ``retriever_distilbert`` and ``ranker``.
pipeline = CustomPipeline(doc_index=doc_index, label_index=label_index)

from reranker import ReRanker

# NOTE(review): ``reranker`` is constructed but never used below (the
# evaluation is called with ``reRankerGPT=None``). Construction is kept so any
# side effects are unchanged; remove it if it is genuinely unneeded.
reranker = ReRanker()
# When True, evaluate against the pre-existing "stupo" index as-is; when
# False, first rebuild the dedicated evaluation indices from
# ``squad_format.json``.
open_domain = True

if not open_domain:
    # Split documents into 100-word passages with no overlap, ignoring
    # sentence boundaries and leaving empty lines and whitespace untouched.
    preprocessor = PreProcessor(
        split_by="word",
        split_length=100,
        split_overlap=0,
        split_respect_sentence_boundary=False,
        clean_empty_lines=False,
        clean_whitespace=False,
    )

    # Clear both evaluation indices before re-ingesting the dataset.
    pipeline.doc_store_distilbert.delete_documents(index=doc_index)
    # NOTE(review): labels are cleared with delete_documents on the label
    # index; some Haystack document stores keep labels separately and may
    # need ``delete_labels(index=label_index)`` instead — confirm for the
    # store used here.
    pipeline.doc_store_distilbert.delete_documents(index=label_index)

    # add_eval_data() converts a SQuAD-format JSON dataset into Haystack
    # Document and Label objects and indexes them into the document and
    # label indices respectively. It works with any dataset in SQuAD format.
    pipeline.doc_store_distilbert.add_eval_data(
        filename="squad_format.json",
        doc_index=doc_index,
        label_index=label_index,
        preprocessor=preprocessor,
    )

    # Compute retriever embeddings for the newly indexed passages.
    pipeline.doc_store_distilbert.update_embeddings(
        pipeline.retriever_distilbert, index=doc_index
    )
# Open-domain evaluation searches the "stupo" index; otherwise search the
# evaluation document index.
index = "stupo" if open_domain else doc_index

# Run the project's custom evaluation of the retriever, re-ranked by the
# pipeline's ranker, retrieving the top 20 candidates per query.
retriever_eval_results = eval(
    label_index=label_index,
    doc_index=index,
    top_k=20,
    document_store=pipeline.doc_store_distilbert,
    retriever=pipeline.retriever_distilbert,
    reRankerGPT=None,
    rerankerPipeline=pipeline.ranker,
    open_domain=open_domain,
)

print(retriever_eval_results)