from typing import Any, Dict, List, Optional, Tuple

from haystack.nodes.base import BaseComponent


class MethodRetrieverClassifier(BaseComponent):
    """Route a query to one of several retrievers based on a retrieval method name.

    The component acts as a decision node inside a Haystack pipeline. It
    passes ``top_k`` and ``index`` through as its output and selects the
    outgoing edge from the ``method`` argument:

    - ``"mpnet"``      -> ``"output_1"``
    - ``"distilbert"`` -> ``"output_2"``
    - ``"ada"``        -> ``"output_3"``
    - ``"llama"``      -> ``"output_4"``

    Any other method name falls back to ``"output_1"``.

    Inherits from Haystack's ``BaseComponent`` so it can be added to a pipeline.
    """

    # NOTE(review): seven edges are declared but only output_1..output_4 are
    # ever returned. Kept unchanged because pipelines may already wire nodes to
    # the higher edge names (which must not exceed this count); confirm whether
    # output_5..output_7 are still needed.
    outgoing_edges = 7

    # Method name -> outgoing edge. Shared by run() and run_batch() so the
    # single-query and batch paths always route identically.
    _METHOD_TO_EDGE = {
        "mpnet": "output_1",
        "distilbert": "output_2",
        "ada": "output_3",
        "llama": "output_4",
    }
    # Edge used when `method` is not one of the known names above.
    _DEFAULT_EDGE = "output_1"

    def _route(self, method: str, index: str, top_k: int) -> Tuple[Dict[str, Any], str]:
        """Build the forwarded params and pick the outgoing edge for `method`."""
        params = {"top_k": top_k, "index": index}
        return params, self._METHOD_TO_EDGE.get(method, self._DEFAULT_EDGE)

    def run(
        self,
        method: str,
        index: str,
        query: str,
        intent: Any = None,
        top_k: int = 5,
    ) -> Tuple[Dict[str, Any], str]:
        """Route a single query.

        Args:
            method: Retrieval method name ("mpnet", "distilbert", "ada", "llama").
            index: Index name forwarded to the downstream retriever.
            query: The query; it is not inspected by this node.
            intent: Unused; kept so existing callers and pipeline params still work.
            top_k: Number of results forwarded to the downstream retriever.

        Returns:
            A ``({"top_k": top_k, "index": index}, edge_name)`` tuple.
        """
        return self._route(method, index, top_k)

    def run_batch(
        self,
        method: str,
        index: str,
        queries: List[str],
        top_k: int = 5,
        my_arg: Optional[int] = 10,
    ) -> Tuple[Dict[str, Any], str]:
        """Route a batch of queries; all queries share one edge chosen by `method`.

        Args:
            method: Retrieval method name ("mpnet", "distilbert", "ada", "llama").
            index: Index name forwarded to the downstream retriever.
            queries: The queries; they are not inspected by this node.
            top_k: Number of results forwarded to the downstream retriever.
            my_arg: Unused; kept for backward compatibility of the signature.

        Returns:
            A ``({"top_k": top_k, "index": index}, edge_name)`` tuple.
        """
        return self._route(method, index, top_k)