47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
from haystack.nodes.base import BaseComponent
|
|
from typing import List, Optional
|
|
|
|
|
|
class MethodRetrieverClassifier(BaseComponent):
    """Route a query to one of several retriever branches based on ``method``.

    This node is a routing component for a Haystack pipeline. It does not
    inspect or modify the query. It only chooses an outgoing edge and emits
    ``{"top_k": top_k, "index": index}`` as its output.

    Supported methods and the edge each one selects (exact, case-sensitive
    match):

    * ``"mpnet"``      -> ``"output_1"``
    * ``"distilbert"`` -> ``"output_2"``
    * ``"ada"``        -> ``"output_3"``
    * ``"llama"``      -> ``"output_4"``

    Any other value falls back to ``"output_1"`` (the mpnet branch).
    """

    # NOTE(review): only output_1..output_4 are ever returned by this node.
    # The value 7 is kept unchanged so existing pipeline wiring keeps working.
    # Confirm nothing connects to output_5..output_7 before reducing it to 4.
    outgoing_edges = 7

    # Method name -> outgoing edge name. Shared by run() and run_batch() so
    # the two entry points always route identically.
    _METHOD_TO_EDGE = {
        "mpnet": "output_1",
        "distilbert": "output_2",
        "ada": "output_3",
        "llama": "output_4",
    }
    # Edge chosen when ``method`` is not one of the keys above.
    _DEFAULT_EDGE = "output_1"

    def _route(self, method: str, index: str, top_k) -> tuple:
        """Build the node output and pick the edge for ``method``."""
        # A fresh dict per call, so callers never share mutable output state.
        params = {"top_k": top_k, "index": index}
        return params, self._METHOD_TO_EDGE.get(method, self._DEFAULT_EDGE)

    def run(self, method: str, index: str, query: str, intent=None, top_k=5):
        """Select the retriever branch for a single query.

        Args:
            method: Retrieval technique name. See the class docstring for the
                accepted values. Unknown values route to ``"output_1"``.
            index: Index name, forwarded unchanged in the output dict.
            query: The query string. Not used for routing. It is kept in the
                signature, presumably so the pipeline passes it through to
                this node (TODO confirm).
            intent: Unused. Kept for signature compatibility.
            top_k: Value forwarded unchanged in the output dict (default 5).

        Returns:
            A tuple ``({"top_k": top_k, "index": index}, edge_name)``, where
            ``edge_name`` is one of ``"output_1"`` .. ``"output_4"``.
        """
        return self._route(method, index, top_k)

    def run_batch(
        self,
        method: str,
        index: str,
        queries: List[str],
        top_k=5,
        my_arg: Optional[int] = 10,
    ):
        """Select the retriever branch for a batch of queries.

        All queries in the batch follow the same branch, chosen by ``method``
        exactly as in :meth:`run`.

        Args:
            method: Retrieval technique name. Unknown values route to
                ``"output_1"``.
            index: Index name, forwarded unchanged in the output dict.
            queries: The batch of queries. Not used for routing.
            top_k: Value forwarded unchanged in the output dict (default 5).
            my_arg: Unused. It looks like a leftover from a component
                template and is kept only so existing callers passing it
                do not break.

        Returns:
            A tuple ``({"top_k": top_k, "index": index}, edge_name)``.
        """
        return self._route(method, index, top_k)
|