# Python source (file-view header: 45 lines, 1.9 KiB).
from typing import Dict, Optional

import numpy as np

from api.embeddingsServiceCaller import EmbeddingServiceCaller
from retriever.retriever_pipeline import CustomPipeline


class Retriever:
    """Route passage-retrieval requests to a CustomPipeline.

    Queries are either embedded through an EmbeddingServiceCaller and
    searched by vector, or forwarded to the pipeline's own retrieval path,
    depending on the requested method.
    """

    def __init__(self, pipeline: CustomPipeline, caller: EmbeddingServiceCaller) -> None:
        """
        Initializes the Retriever class with a CustomPipeline and an EmbeddingServiceCaller.

        Args:
            pipeline (CustomPipeline): An instance of the CustomPipeline to handle query retrieval.
            caller (EmbeddingServiceCaller): An instance of EmbeddingServiceCaller to fetch
                embeddings for the **MODEL SERVICE**.
        """
        self.pipeline = pipeline
        self.caller = caller

    def get_top_k_passages(
        self,
        query: str,
        index: str = "",
        meta: Optional[Dict] = None,
        method: str = "mpnet",
    ):
        """
        Retrieve passages for a query using the specified retrieval method.

        When ``method`` is "llama", the query is embedded via the embedding
        service and searched by vector with ``pipeline.query_by_emb``. Any
        other method name is forwarded to ``pipeline.run`` as
        ``retrieval_method``.

        Args:
            query (str): The search query.
            index (str, optional): The index to search in. Defaults to "".
            meta (Dict, optional): Additional metadata for the query. Currently
                not used by this method; kept for backward compatibility.
                Defaults to None (previously a shared mutable ``{}``).
            method (str, optional): The retrieval method (e.g., 'mpnet', 'llama').
                Defaults to "mpnet".

        Returns:
            The value returned by ``pipeline.query_by_emb`` (for "llama") or
            ``pipeline.run`` (for any other method), passed through unchanged.
            No top-k limit is applied here; presumably the pipeline decides
            how many passages come back (TODO confirm).
        """
        if method == "llama":
            # Embed the query through the external model service, then search
            # by vector instead of letting the pipeline embed the text itself.
            emb_query = np.array(self.caller.get_embeddings(query))
            return self.pipeline.query_by_emb(index=index, emb=emb_query)
        return self.pipeline.run(query=query, index=index, retrieval_method=method)

    def get_module_credits(self, module: str, index: str = ""):
        """
        Fetch documents whose ``title`` metadata matches ``module``.

        Args:
            module (str): Module title to filter on.
            index (str, optional): Index to search. An empty value (the
                default) falls back to "ib", the index this method always
                used before ``index`` was honoured.

        Returns:
            The value returned by ``pipeline.filter_query``, unchanged.
            Presumably this holds the module's credit information, going by
            the method name (verify against the index contents).
        """
        return self.pipeline.filter_query(
            query="", index=index or "ib", params={"title": [module]}
        )