from typing import Dict, Optional

import numpy as np

from api.embeddingsServiceCaller import EmbeddingServiceCaller
from retriever.retriever_pipeline import CustomPipeline


class Retriever:
    """Thin facade over a retrieval pipeline and an embedding service client.

    The class holds no state beyond the two collaborators passed to the
    constructor; every public method delegates to one of them.
    """

    def __init__(
        self, pipeline: CustomPipeline, caller: EmbeddingServiceCaller
    ) -> None:
        """Store the collaborators used by the query methods.

        Args:
            pipeline: Pipeline that executes the actual retrieval
                (``run``, ``query_by_emb`` and ``filter_query`` are called
                on it).
            caller: Client whose ``get_embeddings`` is used to embed the
                query for the ``"llama"`` retrieval method.
        """
        self.pipeline = pipeline
        self.caller = caller

    def get_top_k_passages(
        self,
        query: str,
        index: str = "",
        meta: Optional[Dict] = None,
        method: str = "mpnet",
    ):
        """Retrieve passages for ``query`` with the given retrieval method.

        Args:
            query: The search query.
            index: Name of the index to search. Defaults to ``""``.
            meta: Additional metadata for the query. Currently unused by
                this method; kept for interface compatibility. Defaults to
                ``None`` (instead of a shared mutable ``{}``).
            method: Retrieval method. ``"llama"`` embeds the query through
                the embedding service and searches by embedding; any other
                value is forwarded to the pipeline as ``retrieval_method``.
                Defaults to ``"mpnet"``.

        Returns:
            Whatever the pipeline returns for the query
            (``query_by_emb`` for ``"llama"``, ``run`` otherwise).
        """
        if method == "llama":
            # The pipeline is handed a NumPy array rather than the raw
            # payload returned by the embedding service.
            query_embedding = np.array(self.caller.get_embeddings(query))
            return self.pipeline.query_by_emb(index=index, emb=query_embedding)

        # NOTE: results are returned exactly as produced by the pipeline;
        # no score post-processing (e.g. softmax) is applied here.
        return self.pipeline.run(
            query=query, index=index, retrieval_method=method
        )

    def get_module_credits(self, module: str, index: str = ""):
        """Run a title-filtered query on the pipeline for ``module``.

        Args:
            module: Title to filter on; passed as ``{"title": [module]}``.
            index: Currently ignored -- the lookup always targets the
                ``"ib"`` index. NOTE(review): confirm whether this
                parameter was meant to override the hard-coded index.

        Returns:
            Whatever ``pipeline.filter_query`` returns for the filter.
        """
        return self.pipeline.filter_query(
            query="", index="ib", params={"title": [module]}
        )