# File listing metadata: 37 lines, 1.5 KiB, Python
import numpy as np
import sys

from LlamaRetriever import LlamaRetriever

# Extend the import search path before the imports below run. The entry is
# relative, so it is resolved against the current working directory.
# NOTE(review): presumably this makes the project packages imported below
# resolvable -- confirm which of them actually depend on it.
sys.path.append("../..")

from retriever.retriever_pipeline import CustomPipeline
from api.embeddingsServiceCaller import EmbeddingServiceCaller
from haystack.nodes import PreProcessor
from custom_evaluation import eval, eval_llama

# Script: run `eval_llama` for a LlamaRetriever and print the returned result.
#
# With `open_domain = True` the evaluation is pointed at the "stupo" index and
# no evaluation data is loaded. With `open_domain = False` the SQuAD-format
# file "squad_format.json" is first loaded into `doc_index` / `label_index`
# and the evaluation is pointed at `doc_index` instead.

# NOTE(review): `caller` is referenced only by the disabled query experiment
# further down. It is kept because constructing it may have side effects --
# confirm whether it can be dropped.
caller = EmbeddingServiceCaller()

doc_index = "stupo_eval_docs_llama"
label_index = "stupo_eval_labels_llama"
pipeline = CustomPipeline(label_index=label_index, doc_index=doc_index)

retriever = LlamaRetriever(document_store=pipeline.vector_doc_store_llama)

# Toggle between evaluating against the "stupo" index (True) and against the
# dedicated evaluation index that is populated below (False).
open_domain = True

if not open_domain:
    # Split documents into fixed 100-word chunks with no overlap, ignoring
    # sentence boundaries and leaving whitespace / empty lines untouched.
    preprocessor = PreProcessor(
        split_by="word",
        split_length=100,
        split_overlap=0,
        split_respect_sentence_boundary=False,
        clean_empty_lines=False,
        clean_whitespace=False,
    )

    # Disabled experiment (references names not defined in this script):
    # emb_query = np.array(caller.get_embeddings(query))
    # results = pipeline.query_by_emb(index=index, emb=emb_query)

    # `preprocessor` only exists in this branch, so the data load that
    # consumes it must live here as well.
    pipeline.doc_store_mpnet.add_eval_data(
        filename="squad_format.json",
        doc_index=doc_index,
        label_index=label_index,
        preprocessor=preprocessor,
    )

    # Disabled step:
    # pipeline.vector_doc_store.update_embeddings(retriever, index=doc_index)

index = "stupo" if open_domain else doc_index

retriever_eval_results = eval_llama(
    label_index=label_index,
    doc_index=index,
    top_k=30,
    document_store=pipeline.doc_store_mpnet,
    vector_store=pipeline.vector_doc_store_llama,
    retriever=retriever,
    reRanker=None,
    open_domain=open_domain,
)

print(retriever_eval_results)