BA-Chatbot/backend/retriever/custom_components/retrieval_model_classifier.py

47 lines
1.5 KiB
Python
Raw Permalink Normal View History

2023-11-15 14:28:48 +01:00
from haystack.nodes.base import BaseComponent
from typing import List, Optional
class MethodRetrieverClassifier(BaseComponent):
    """
    Routing component that selects a retriever branch from a method name.

    The class performs no retrieval itself. Both ``run`` and ``run_batch``
    map ``method`` to an outgoing edge name and return it together with a
    params dict containing ``top_k`` and ``index``.

    Routing table:
        "mpnet"       -> "output_1"
        "distilbert"  -> "output_2"
        "ada"         -> "output_3"
        "llama"       -> "output_4"
        anything else -> "output_1" (fallback)

    Inherits from Haystack's BaseComponent for compatibility.
    """

    # Only "output_1" .. "output_4" are ever returned by this class.
    # NOTE(review): 7 edges are declared; confirm whether the pipeline wires
    # the remaining ones before changing this value.
    outgoing_edges = 7

    # Edge used for "mpnet" and for every unrecognised method.
    _DEFAULT_EDGE = "output_1"

    # Single source of truth for the method -> edge mapping, shared by
    # run() and run_batch() so the two entry points cannot drift apart.
    _EDGE_BY_METHOD = {
        "mpnet": "output_1",
        "distilbert": "output_2",
        "ada": "output_3",
        "llama": "output_4",
    }

    @staticmethod
    def _route(method, index, top_k):
        """Build the params dict and pick the outgoing edge for ``method``."""
        params = {"top_k": top_k, "index": index}
        # Non-string values (possibly unhashable) are treated like any other
        # unknown method and fall back to the default edge instead of raising.
        if isinstance(method, str):
            edge = MethodRetrieverClassifier._EDGE_BY_METHOD.get(
                method, MethodRetrieverClassifier._DEFAULT_EDGE
            )
        else:
            edge = MethodRetrieverClassifier._DEFAULT_EDGE
        return params, edge

    def run(self, method: str, index: str, query: str, intent=None, top_k=5):
        """
        Select the outgoing edge for a single query.

        Args:
            method: Retrieval method name; see the class routing table.
            index: Index name, forwarded unchanged in the returned params.
            query: Accepted as part of the signature; not used here.
            intent: Accepted as part of the signature; not used here.
            top_k: Forwarded unchanged in the returned params.

        Returns:
            A tuple ``({"top_k": top_k, "index": index}, edge_name)``.
        """
        return self._route(method, index, top_k)

    def run_batch(
        self,
        method: str,
        index: str,
        queries: List[str],
        top_k=5,
        my_arg: Optional[int] = 10,
    ):
        """
        Select the outgoing edge for a batch of queries.

        The routing is identical to ``run``; one edge is chosen for the
        whole batch based on ``method`` alone.

        Args:
            method: Retrieval method name; see the class routing table.
            index: Index name, forwarded unchanged in the returned params.
            queries: Accepted as part of the signature; not used here.
            top_k: Forwarded unchanged in the returned params.
            my_arg: Accepted as part of the signature; not used here.

        Returns:
            A tuple ``({"top_k": top_k, "index": index}, edge_name)``.
        """
        return self._route(method, index, top_k)