BA-Chatbot/backend/evaluation/eval_retriever/eval_llama.py

37 lines
1.5 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
import numpy as np
import sys
from LlamaRetriever import LlamaRetriever
sys.path.append("../..")
from retriever.retriever_pipeline import CustomPipeline
from api.embeddingsServiceCaller import EmbeddingServiceCaller
from haystack.nodes import PreProcessor
from custom_evaluation import eval, eval_llama
# Evaluate the Llama-based retriever with eval_llama and print the results.
#
# open_domain=True  -> query the existing "stupo" document index.
# open_domain=False -> first load squad_format.json into doc_index/label_index
#                      (split with the PreProcessor below), then query doc_index.

# NOTE(review): `caller` is not used anywhere in this script; it is kept only in
# case EmbeddingServiceCaller's constructor has side effects other code relies on.
caller = EmbeddingServiceCaller()

doc_index = "stupo_eval_docs_llama"
label_index = "stupo_eval_labels_llama"

pipeline = CustomPipeline(label_index=label_index, doc_index=doc_index)
retriever = LlamaRetriever(document_store=pipeline.vector_doc_store_llama)

open_domain = True
top_k = 30  # number of retrieved documents passed to eval_llama

if not open_domain:
    # Split by words into chunks of split_length=100 with no overlap, ignoring
    # sentence boundaries and leaving empty lines / whitespace uncleaned.
    preprocessor = PreProcessor(
        split_by="word",
        split_length=100,
        split_overlap=0,
        split_respect_sentence_boundary=False,
        clean_empty_lines=False,
        clean_whitespace=False,
    )
    # NOTE(review): the eval data is written via doc_store_mpnet while the
    # retriever under test is built on vector_doc_store_llama; confirm that
    # eval_llama (which receives both stores) bridges the two.
    pipeline.doc_store_mpnet.add_eval_data(
        filename="squad_format.json",
        doc_index=doc_index,
        label_index=label_index,
        preprocessor=preprocessor,
    )

# In open-domain mode only the document index switches to "stupo"; labels are
# still read from label_index. This presumably requires that index to have been
# populated by an earlier closed-domain run (verify).
index = "stupo" if open_domain else doc_index

retriever_eval_results = eval_llama(
    label_index=label_index,
    doc_index=index,
    top_k=top_k,
    document_store=pipeline.doc_store_mpnet,
    vector_store=pipeline.vector_doc_store_llama,
    retriever=retriever,
    reRanker=None,
    open_domain=open_domain,
)
print(retriever_eval_results)