# Forked from 1827133/BA-Chatbot (Python, executable file, 280 lines, 10 KiB).
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
|
from haystack.nodes import (
|
|
EmbeddingRetriever,
|
|
BM25Retriever,
|
|
SentenceTransformersRanker,
|
|
FilterRetriever,
|
|
)
|
|
from haystack import Pipeline
|
|
from typing import List, Dict, Optional
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from haystack.document_stores import WeaviateDocumentStore
|
|
from .LlamaRetriever import LlamaRetriever
|
|
from .custom_components.retrieval_model_classifier import MethodRetrieverClassifier
|
|
|
|
# Load variables from a local .env file (if present) before reading the environment.
load_dotenv()

sys_path = os.getenv("SYS_PATH")
es_host = os.getenv("ELASTIC_HOST", "localhost")
# A locally running Elasticsearch is addressed on 9210; every other host on 9200.
PORT = 9200 if es_host != "localhost" else 9210
|
|
|
|
|
|
# Elasticsearch index mapping shared by all document stores: one dense-vector
# field per retrieval model, so every model's embedding can live on the same
# document.
custom_mapping = {
    "mappings": {
        "properties": {
            "content": {"type": "text"},
            "content_type": {"type": "text"},
            # "<model>_embedding" fields, sized to each model's vector length.
            **{
                f"{model}_embedding": {"type": "dense_vector", "dims": dims}
                for model, dims in (
                    ("ada", 1536),
                    ("mpnet", 768),
                    ("distilbert", 512),
                )
            },
            "name": {"type": "keyword"},
        },
        # Any string field not listed above is indexed as a keyword.
        "dynamic_templates": [
            {
                "strings": {
                    "path_match": "*",
                    "match_mapping_type": "string",
                    "mapping": {"type": "keyword"},
                }
            }
        ],
    },
    "settings": {"analysis": {"analyzer": {"default": {"type": "german"}}}},
}
|
|
|
|
|
|
class CustomPipeline:
    """Retrieval pipeline that wires several embedding retrievers behind a routing node.

    One Elasticsearch document store is created per embedding model (ada, mpnet,
    distilbert). All of them point at the same index but read a different
    embedding field. A ``MethodRetrieverClassifier`` node sits at the pipeline
    entry; its numbered outputs feed the individual embedding retrievers, whose
    results are handed to a cross-encoder ranker.
    """

    def __init__(self, doc_index="document", label_index="label", api_key="") -> None:
        """Create the document stores, retrievers, ranker and the QA pipeline.

        Args:
            doc_index (str, optional): Elasticsearch index holding the documents.
                Defaults to "document".
            label_index (str, optional): Label index used for evaluation purposes.
                Defaults to "label".
            api_key (str, optional): API key passed to the ada embedding retriever.
                Defaults to "".
        """
        # Registry filled by __init_doc_store (keyed by index name). It has to
        # exist before that helper is called, otherwise the helper fails.
        self.doc_stores = {}

        # NOTE: Weaviate/LLaMA retrieval is currently disabled. To re-enable it,
        # assign a WeaviateDocumentStore (previously: host="http://localhost",
        # port=3434, embedding_dim=4096) here, create a LlamaRetriever on top of
        # it and register that retriever in init_qa_pipeline on
        # "RetrieverClassifier.output_4". The attribute is kept defined so the
        # Weaviate helpers can report the missing configuration explicitly.
        self.vector_doc_store_llama = None

        # One store per embedding model: same index, different embedding field.
        self.doc_store_ada = self._make_embedding_store(
            doc_index, label_index, "ada_embedding", 1536
        )
        self.doc_store_mpnet = self._make_embedding_store(
            doc_index, label_index, "mpnet_embedding", 768
        )
        self.doc_store_distilbert = self._make_embedding_store(
            doc_index, label_index, "distilbert_embedding", 512
        )

        self.emb_retriever_ada = EmbeddingRetriever(
            document_store=self.doc_store_ada,
            batch_size=8,
            embedding_model="text-embedding-ada-002",
            api_key=api_key,
            max_seq_len=1536,
        )
        self.emb_retriever_mpnet = EmbeddingRetriever(
            document_store=self.doc_store_mpnet,
            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
            model_format="sentence_transformers",
        )
        self.retriever_distilbert = EmbeddingRetriever(
            document_store=self.doc_store_distilbert,
            embedding_model="sentence-transformers/distiluse-base-multilingual-cased-v2",
            model_format="sentence_transformers",
        )

        # Keyword and metadata-filter retrieval both run against the mpnet store.
        self.bm25_retriever = BM25Retriever(document_store=self.doc_store_mpnet)
        self.filter_retriever = FilterRetriever(
            document_store=self.doc_store_mpnet, all_terms_must_match=True
        )

        self.ranker = SentenceTransformersRanker(
            model_name_or_path="svalabs/cross-electra-ms-marco-german-uncased",
            use_gpu=True,
        )

        self.init_qa_pipeline()

    @staticmethod
    def _make_embedding_store(doc_index, label_index, embedding_field, embedding_dim):
        """Build an Elasticsearch store bound to one embedding field of the shared mapping."""
        return ElasticsearchDocumentStore(
            host=es_host,
            port=PORT,
            analyzer="german",
            index=doc_index,
            label_index=label_index,
            embedding_dim=embedding_dim,
            similarity="dot_product",
            embedding_field=embedding_field,
            custom_mapping=custom_mapping,
        )

    def _weaviate_store(self):
        """Return the Weaviate document store or fail with an explanatory error."""
        store = getattr(self, "vector_doc_store_llama", None)
        if store is None:
            # AttributeError is kept as the exception type callers saw before,
            # but now carries an actionable message.
            raise AttributeError(
                "Weaviate document store is not configured: "
                "'vector_doc_store_llama' is disabled in CustomPipeline.__init__."
            )
        return store

    def __init_doc_store(
        self,
        host: str = os.getenv("ES_HOSTNAME"),
        port: int = 9200,
        analyzer: str = "german",
        index: str = "",
        embedding_dim: int = 768,
        similarity: str = "dot_product",
        custom_mapping: Optional[dict] = None,
    ):
        """Create a document store with the provided configuration and register it.

        Args:
            host (str, optional): Hostname where the DB is running, e.g. es01 or
                localhost. Defaults to the value of ES_HOSTNAME, read once when
                the class is defined.
            port (int, optional): Port where the DB is running. Defaults to 9200.
            analyzer (str, optional): Elasticsearch analyzer. Defaults to "german".
            index (str, optional): Index which the document store refers to.
                Defaults to "".
            embedding_dim (int, optional): Dimensions of the embedding model.
                Defaults to 768.
            similarity (str, optional): Similarity function for retrieval.
                Defaults to "dot_product".
            custom_mapping (Optional[dict], optional): Custom DB mapping.
                Defaults to None.

        Returns:
            ElasticsearchDocumentStore: The newly created store, which is also
            saved in ``self.doc_stores`` under ``index``.
        """
        doc_store = ElasticsearchDocumentStore(
            host=host,
            port=port,
            analyzer=analyzer,
            index=index,
            embedding_dim=embedding_dim,
            similarity=similarity,
            custom_mapping=custom_mapping,
        )
        self.doc_stores[index] = doc_store
        return doc_store

    def init_qa_pipeline(self):
        """Build the QA pipeline: routing classifier, embedding retrievers and ranker.

        The pipeline is stored in ``self.qa_pipeline``.

        Returns:
            Pipeline: The initialized QA pipeline.
        """
        pipe = Pipeline()
        pipe.add_node(
            component=MethodRetrieverClassifier(),
            name="RetrieverClassifier",
            inputs=["Query"],
        )

        # Each retriever hangs off its own numbered output of the classifier.
        routes = (
            ("EMBRetrieverMPNET", self.emb_retriever_mpnet, "output_1"),
            ("EMBRetrieverDISTILBERT", self.retriever_distilbert, "output_2"),
            ("EMBRetrieverADA", self.emb_retriever_ada, "output_3"),
        )
        for node_name, retriever, edge in routes:
            pipe.add_node(
                component=retriever,
                name=node_name,
                inputs=[f"RetrieverClassifier.{edge}"],
            )

        # The ranker is connected to all three retriever nodes.
        pipe.add_node(
            component=self.ranker,
            name="Ranker",
            inputs=["EMBRetrieverADA", "EMBRetrieverDISTILBERT", "EMBRetrieverMPNET"],
        )
        self.qa_pipeline = pipe
        return self.qa_pipeline

    def filter_query(self, query, index, params):
        """Retrieve documents through the filter retriever.

        Args:
            query (str): The query string.
            index (str): The index to search in.
            params (dict): Filters forwarded to the retriever.

        Returns:
            list: The filtered query results.
        """
        return self.filter_retriever.retrieve(query=query, index=index, filters=params)

    def query_by_ids(self, ids):
        """Fetch documents by their ids from the mpnet Elasticsearch store.

        Args:
            ids (list): Document ids to look up.

        Returns:
            list: Whatever the document store returns for the given ids.
        """
        return self.doc_store_mpnet.get_documents_by_id(ids)

    def query_by_emb(self, index, emb):
        """Query the Weaviate document store with a precomputed embedding.

        Args:
            index (str): The index to search in.
            emb: The query embedding.

        Returns:
            list: Documents returned by the embedding query.

        Raises:
            AttributeError: If the Weaviate document store is not configured.
        """
        return self._weaviate_store().query_by_embedding(query_emb=emb, index=index)

    def get_qa_pipeline(self):
        """Return the question-answering pipeline.

        Returns:
            Pipeline: The QA pipeline.
        """
        return self.qa_pipeline

    def get_all_weaviate_data(self, index):
        """Retrieve all documents from the Weaviate document store.

        Args:
            index (str): The index to retrieve documents from.

        Returns:
            list: All documents from the specified index in Weaviate.

        Raises:
            AttributeError: If the Weaviate document store is not configured.
        """
        return self._weaviate_store().get_all_documents(index=index)

    def get_all_elastic_data(self, index):
        """Retrieve all documents from the Elasticsearch (mpnet) document store.

        Args:
            index (str): The Elasticsearch index to retrieve documents from.

        Returns:
            list: All documents from the specified index in Elasticsearch.
        """
        return self.doc_store_mpnet.get_all_documents(index=index)

    def run(self, query, index, retrieval_method, top_k=10):
        """Run the QA pipeline with the given query, index and retrieval method.

        Args:
            query (str): The query string.
            index (str): The index to search in.
            retrieval_method (str): The retrieval method handed to the classifier node.
            top_k (int, optional): ``top_k`` value handed to the classifier node.
                Defaults to 10.

        Returns:
            dict: The results from running the QA pipeline.
        """
        classifier_params = {
            "method": retrieval_method,
            "index": index,
            "top_k": top_k,
        }
        return self.qa_pipeline.run(
            query=query,
            params={"RetrieverClassifier": classifier_params},
        )
|