BA-Chatbot/backend/retriever/retriever.py

45 lines
1.9 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
from typing import Dict, Optional

import numpy as np

from api.embeddingsServiceCaller import EmbeddingServiceCaller
from retriever.retriever_pipeline import CustomPipeline


class Retriever:
    """Route retrieval requests to a CustomPipeline.

    For the "llama" method the query is embedded through an EmbeddingServiceCaller
    and the pipeline is searched with that vector; every other method name is
    passed straight to the pipeline's ``run`` method.
    """

    # Index used by get_module_credits when the caller does not supply one
    # (this was previously the only index ever queried there).
    _DEFAULT_MODULE_INDEX = "ib"

    def __init__(self, pipeline: CustomPipeline, caller: EmbeddingServiceCaller) -> None:
        """
        Initializes the Retriever class with a CustomPipeline and an EmbeddingServiceCaller.

        Args:
            pipeline (CustomPipeline): Pipeline that executes retrieval and filter queries.
            caller (EmbeddingServiceCaller): Client used to fetch query embeddings from the
                **MODEL SERVICE**; only used by the "llama" retrieval method.
        """
        self.pipeline = pipeline
        self.caller = caller

    def get_top_k_passages(
        self,
        query: str,
        index: str = "",
        meta: Optional[Dict] = None,
        method: str = "mpnet",
    ):
        """
        Retrieves the top K passages for a given query using the specified retrieval method.

        Args:
            query (str): The search query.
            index (str, optional): The index to search in. Defaults to "".
            meta (Dict, optional): Additional metadata for the query. Currently not used
                by this method; kept for interface compatibility. Defaults to None.
            method (str, optional): The retrieval method (e.g. 'mpnet', 'llama').
                Defaults to "mpnet".

        Returns:
            The pipeline's result: from ``query_by_emb`` for "llama",
            otherwise from ``run``.
        """
        if method == "llama":
            # Embeddings for this method come from the external model service, so the
            # pipeline is queried directly with the resulting vector.
            emb_query = np.array(self.caller.get_embeddings(query))
            return self.pipeline.query_by_emb(index=index, emb=emb_query)
        # Any other method name is delegated to the pipeline's own retrieval logic.
        return self.pipeline.run(query=query, index=index, retrieval_method=method)

    def get_module_credits(self, module: str, index: str = ""):
        """
        Looks up the entries whose title matches the given module.

        Args:
            module (str): Module title to filter on (exact value passed as the
                "title" filter).
            index (str, optional): Index to query. An empty string falls back to
                the default module index ("ib"). Defaults to "".

        Returns:
            Whatever ``self.pipeline.filter_query`` returns.
        """
        # Previously the ``index`` argument was silently ignored; honour it now,
        # but keep "ib" as the fallback so default calls behave as before.
        return self.pipeline.filter_query(
            query="",
            index=index or self._DEFAULT_MODULE_INDEX,
            params={"title": [module]},
        )