# BA-Chatbot/backend/retriever/retriever_pipeline.py

from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.nodes import (
    EmbeddingRetriever,
    BM25Retriever,
    SentenceTransformersRanker,
    FilterRetriever,
)
from haystack import Pipeline
from typing import List, Dict, Optional
import os
from dotenv import load_dotenv
from haystack.document_stores import WeaviateDocumentStore
from .LlamaRetriever import LlamaRetriever
from .custom_components.retrieval_model_classifier import MethodRetrieverClassifier
load_dotenv()
sys_path = os.environ.get("SYS_PATH")
es_host = os.environ.get("ELASTIC_HOST", "localhost")
PORT = 9210 if es_host == "localhost" else 9200
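# Environment sketch (values are illustrative, not taken from the repo):
#   SYS_PATH=/app/backend
#   ELASTIC_HOST=es01   # "localhost" switches the port to 9210 for local development
#   ES_HOSTNAME=es01    # used by the __init_doc_store helper below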
# Custom Elasticsearch mapping with one dense-vector embedding field per retrieval model
custom_mapping = {
    "mappings": {
        "properties": {
            "content": {"type": "text"},
            "content_type": {"type": "text"},
            "ada_embedding": {"type": "dense_vector", "dims": 1536},
            "mpnet_embedding": {"type": "dense_vector", "dims": 768},
            "distilbert_embedding": {"type": "dense_vector", "dims": 512},
            "name": {"type": "keyword"},
        },
        "dynamic_templates": [
            {
                "strings": {
                    "path_match": "*",
                    "match_mapping_type": "string",
                    "mapping": {"type": "keyword"},
                }
            }
        ],
    },
    "settings": {"analysis": {"analyzer": {"default": {"type": "german"}}}},
}
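# Illustrative document shape under this mapping (values are placeholders, not real data):
#   {
#       "content": "Modulbeschreibung ...",
#       "content_type": "text",
#       "name": "module_handbook",
#       "ada_embedding": [...],         # 1536-dim OpenAI ada-002 vector
#       "mpnet_embedding": [...],       # 768-dim multi-qa-mpnet vector
#       "distilbert_embedding": [...],  # 512-dim distiluse multilingual vector
#   }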

class CustomPipeline:
    """
    The CustomPipeline class orchestrates a variety of retrievers and document stores,
    using a MethodRetrieverClassifier to route each query to the appropriate retriever
    based on the requested retrieval method. It integrates multiple embedding-based
    retrieval methods and a cross-encoder reranking step.
    """

    def __init__(self, doc_index="document", label_index="label", api_key="") -> None:
        """Initializes the question-answering pipeline with retrievers, document stores for the DB connections, and reranking components.

        Args:
            doc_index (str, optional): Default Elasticsearch / Weaviate index. Defaults to "document".
            label_index (str, optional): Label index for evaluation purposes. Defaults to "label".
            api_key (str, optional): API key for external provider services. Defaults to "".
        """
        # Registry of document stores keyed by index name; used by the __init_doc_store helper.
        self.doc_stores: Dict[str, ElasticsearchDocumentStore] = {}
        self.doc_store_ada = ElasticsearchDocumentStore(
            host=es_host,
            port=PORT,
            analyzer="german",
            index=doc_index,
            label_index=label_index,
            embedding_dim=1536,
            similarity="dot_product",
            embedding_field="ada_embedding",
            custom_mapping=custom_mapping,
        )
        self.doc_store_mpnet = ElasticsearchDocumentStore(
            host=es_host,
            port=PORT,
            analyzer="german",
            index=doc_index,
            label_index=label_index,
            embedding_dim=768,
            similarity="dot_product",
            embedding_field="mpnet_embedding",
            custom_mapping=custom_mapping,
        )
        self.doc_store_distilbert = ElasticsearchDocumentStore(
            host=es_host,
            port=PORT,
            analyzer="german",
            index=doc_index,
            label_index=label_index,
            embedding_dim=512,
            similarity="dot_product",
            embedding_field="distilbert_embedding",
            custom_mapping=custom_mapping,
        )
        # self.vector_doc_store_llama = WeaviateDocumentStore(
        #     host="http://localhost", port=3434, embedding_dim=4096
        # )
        self.emb_retriever_ada = EmbeddingRetriever(
            document_store=self.doc_store_ada,
            batch_size=8,
            embedding_model="text-embedding-ada-002",
            api_key=api_key,
            max_seq_len=1536,
        )
        self.emb_retriever_mpnet = EmbeddingRetriever(
            document_store=self.doc_store_mpnet,
            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
            model_format="sentence_transformers",
        )
        self.retriever_distilbert = EmbeddingRetriever(
            document_store=self.doc_store_distilbert,
            embedding_model="sentence-transformers/distiluse-base-multilingual-cased-v2",
            model_format="sentence_transformers",
        )
        # self.llama_retriever = LlamaRetriever(
        #     document_store=self.vector_doc_store_llama
        # )
        self.bm25_retriever = BM25Retriever(document_store=self.doc_store_mpnet)
        self.ranker = SentenceTransformersRanker(
            model_name_or_path="svalabs/cross-electra-ms-marco-german-uncased",
            use_gpu=True,
        )
        self.init_qa_pipeline()
        self.filter_retriever = FilterRetriever(
            document_store=self.doc_store_mpnet, all_terms_must_match=True
        )

    def __init_doc_store(
        self,
        host: str = os.getenv("ES_HOSTNAME"),
        port: int = 9200,
        analyzer: str = "german",
        index: str = "",
        embedding_dim: int = 768,
        similarity: str = "dot_product",
        custom_mapping: Optional[dict] = None,
    ):
        """Helper function to initialize a document store with the provided configuration.

        Args:
            host (str, optional): Hostname where the DB is running, e.g. es01 or localhost. Defaults to os.getenv("ES_HOSTNAME").
            port (int, optional): Port where the DB is running. Defaults to 9200.
            analyzer (str, optional): Elasticsearch analyzer. Defaults to "german".
            index (str, optional): Index which the document store refers to. Defaults to "".
            embedding_dim (int, optional): Dimensions of the embedding model. Defaults to 768.
            similarity (str, optional): Similarity function for retrieval. Defaults to "dot_product".
            custom_mapping (Optional[dict], optional): Custom DB mapping. Defaults to None.

        Returns:
            ElasticsearchDocumentStore: The initialized document store, also registered in self.doc_stores.
        """
        doc_store = ElasticsearchDocumentStore(
            host=host,
            port=port,
            analyzer=analyzer,
            index=index,
            embedding_dim=embedding_dim,
            similarity=similarity,
            custom_mapping=custom_mapping,
        )
        self.doc_stores[index] = doc_store
        return doc_store
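
    # Illustrative call (hypothetical; the helper is not invoked elsewhere in this module):
    #   store = self.__init_doc_store(index="document", embedding_dim=768)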

    def init_qa_pipeline(self):
        """
        Initializes the question-answering pipeline by adding the retriever nodes,
        the reranking node, and the custom component used for retriever routing.

        Returns:
            Pipeline: The initialized QA pipeline.
        """
        pipe = Pipeline()
        pipe.add_node(
            component=MethodRetrieverClassifier(),
            name="RetrieverClassifier",
            inputs=["Query"],
        )
        pipe.add_node(
            component=self.emb_retriever_mpnet,
            name="EMBRetrieverMPNET",
            inputs=["RetrieverClassifier.output_1"],
        )
        pipe.add_node(
            component=self.retriever_distilbert,
            name="EMBRetrieverDISTILBERT",
            inputs=["RetrieverClassifier.output_2"],
        )
        pipe.add_node(
            component=self.emb_retriever_ada,
            name="EMBRetrieverADA",
            inputs=["RetrieverClassifier.output_3"],
        )
        # pipe.add_node(
        #     component=self.llama_retriever,
        #     name="EMBRetrieverLLAMA",
        #     inputs=["RetrieverClassifier.output_4"],
        # )
        pipe.add_node(
            component=self.ranker,
            name="Ranker",
            inputs=["EMBRetrieverADA", "EMBRetrieverDISTILBERT", "EMBRetrieverMPNET"],
        )
        self.qa_pipeline = pipe
        return self.qa_pipeline
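
    # Resulting graph (sketch): RetrieverClassifier routes each query to one of the
    # retrievers via its outputs (output_1 -> EMBRetrieverMPNET, output_2 ->
    # EMBRetrieverDISTILBERT, output_3 -> EMBRetrieverADA); whichever retriever ran
    # passes its documents on to the cross-encoder Ranker node.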

    def filter_query(self, query, index, params):
        """
        Filters a query based on specified parameters.

        Args:
            query (str): The query string.
            index (str): The index to search in.
            params (dict): Additional parameters for filtering.

        Returns:
            list: The filtered query results.
        """
        return self.filter_retriever.retrieve(query=query, index=index, filters=params)
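
    # Example filters (illustrative only; usable field names depend on the indexed metadata,
    # "name" is the keyword field defined in custom_mapping above):
    #   pipeline.filter_query(
    #       query="Software Engineering",
    #       index="document",
    #       params={"name": "module_handbook"},
    #   )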

    def query_by_ids(self, ids):
        """Retrieves documents by their IDs from the Elasticsearch document store."""
        return self.doc_store_mpnet.get_documents_by_id(ids)

    def query_by_emb(self, index, emb):
        """Queries the Weaviate document store by a raw embedding vector.

        Note: requires the (currently commented-out) vector_doc_store_llama to be initialized.
        """
        return self.vector_doc_store_llama.query_by_embedding(
            query_emb=emb, index=index
        )

    def get_qa_pipeline(self):
        """
        Gets the question-answering pipeline.

        Returns:
            Pipeline: The QA pipeline.
        """
        return self.qa_pipeline

    def get_all_weaviate_data(self, index):
        """
        Retrieves all documents from the Weaviate document store.

        Note: requires the (currently commented-out) vector_doc_store_llama to be initialized.

        Args:
            index (str): The index to retrieve documents from.

        Returns:
            list: All documents from the specified index in Weaviate.
        """
        return self.vector_doc_store_llama.get_all_documents(index=index)

    def get_all_elastic_data(self, index):
        """
        Retrieves all documents from the Elasticsearch document store.

        Args:
            index (str): The Elasticsearch index to retrieve documents from.

        Returns:
            list: All documents from the specified index in Elasticsearch.
        """
        return self.doc_store_mpnet.get_all_documents(index=index)

    def run(self, query, index, retrieval_method):
        """
        Runs the QA pipeline with the given query, index, and retrieval method.

        Args:
            query (str): The query string.
            index (str): The index to search in.
            retrieval_method (str): The retrieval method to use.

        Returns:
            dict: The results from running the QA pipeline.
        """
        return self.qa_pipeline.run(
            query=query,
            params={
                "RetrieverClassifier": {
                    "method": retrieval_method,
                    "index": index,
                    "top_k": 10,
                },
            },
        )
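

# Minimal usage sketch (assumes a reachable Elasticsearch instance and, for the ada
# retriever, a valid OpenAI API key; the env var name and the retrieval_method value
# are assumptions and must match what MethodRetrieverClassifier actually expects):
if __name__ == "__main__":
    pipeline = CustomPipeline(
        doc_index="document",
        api_key=os.environ.get("OPENAI_API_KEY", ""),  # env var name is illustrative
    )
    results = pipeline.run(
        query="Welche Module gibt es im Studiengang Informatik?",
        index="document",
        retrieval_method="mpnet",  # assumed method name, not confirmed by this module
    )
    for doc in results["documents"]:
        print(doc.meta.get("name"), doc.score)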