118 lines
4.4 KiB
Python
118 lines
4.4 KiB
Python
|
|
||
|
"""
The CustomPipeline class initializes and manages the retriever components of a data-processing pipeline.

It focuses on generating embeddings and storing them for efficient retrieval.

It supports multiple document stores (Elasticsearch and Weaviate) and integrates several types of retrievers, such as EmbeddingRetriever and LlamaRetriever.

Each retriever is configured to work with a specific embedding model (Ada, MPNet, DistilBERT, or Llama) and a matching document store.

The module also defines a custom Elasticsearch mapping that declares a separate dense-vector field for each embedding model alongside the regular document properties.

This setup supports advanced information-retrieval tasks by combining the strengths of different embedding models and retrieval strategies.
"""
|
||
|
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||
|
from haystack.document_stores import WeaviateDocumentStore, FAISSDocumentStore
|
||
|
from haystack.nodes import (
|
||
|
EmbeddingRetriever,
|
||
|
BM25Retriever,
|
||
|
JoinDocuments,
|
||
|
SentenceTransformersRanker,
|
||
|
FilterRetriever,
|
||
|
)
|
||
|
from haystack import Pipeline
|
||
|
from typing import List, Dict, Optional
|
||
|
import os
|
||
|
from dotenv import load_dotenv
|
||
|
from .LlamaRetriever import LlamaRetriever
|
||
|
|
||
|
# Populate os.environ from a local .env file (if one exists) before the
# connection settings below are read.
load_dotenv()

# Value of the SYS_PATH environment variable (None if unset). It is not
# referenced elsewhere in this file; presumably importers use it — verify.
sys_path = os.environ.get("SYS_PATH")

# Elasticsearch host, configurable via ELASTIC_HOST (default "localhost").
es_host = os.environ.get("ELASTIC_HOST", "localhost")

# Elasticsearch port. Defaults to 9210 when the host is "localhost" and to
# 9200 otherwise. NOTE(review): presumably 9210 is a host-side port mapping of
# a containerized Elasticsearch that listens on 9200 inside its network —
# confirm. ELASTIC_PORT, when set, overrides the default, which lets
# non-standard deployments work without code changes.
PORT = int(os.environ.get("ELASTIC_PORT", 9210 if es_host == "localhost" else 9200))
|
||
|
|
||
|
# Explicit Elasticsearch index definition used by every
# ElasticsearchDocumentStore that CustomPipeline creates. Those stores all
# target the same index, so the mapping declares one dense_vector field per
# embedding model. Each "dims" value matches the embedding_dim given to the
# store that reads/writes that field.
custom_mapping = dict(
    mappings=dict(
        properties=dict(
            content=dict(type="text"),
            content_type=dict(type="text"),
            # One vector field per embedding model; dims = model output size.
            ada_embedding=dict(type="dense_vector", dims=1536),
            mpnet_embedding=dict(type="dense_vector", dims=768),
            distilbert_embedding=dict(type="dense_vector", dims=512),
            name=dict(type="keyword"),
        ),
        # Any other string field that appears dynamically is mapped as an
        # exact-match keyword.
        dynamic_templates=[
            dict(
                strings=dict(
                    path_match="*",
                    match_mapping_type="string",
                    mapping=dict(type="keyword"),
                )
            )
        ],
    ),
    # Elasticsearch's built-in German analyzer is the index-wide default.
    settings=dict(analysis=dict(analyzer=dict(default=dict(type="german")))),
)
|
||
|
|
||
|
|
||
|
class CustomPipeline:
    """Build the document stores and retrievers used for embedding-based retrieval.

    Three ElasticsearchDocumentStore instances share one Elasticsearch index
    (``doc_index``). They differ only in the dense-vector field they use and
    its dimensionality, so each document can carry Ada, MPNet and DistilBERT
    embeddings side by side. Llama embeddings are kept in a separate
    WeaviateDocumentStore.
    """

    def __init__(self, doc_index="document", label_index="label", api_key="") -> None:
        """Create every document store and retriever.

        Args:
            doc_index: Name of the Elasticsearch index that holds the documents.
            label_index: Name of the Elasticsearch index that holds labels.
            api_key: OpenAI API key for the Ada embedding retriever. If empty,
                the OPENAI_API_KEY environment variable is used instead
                (still empty if that variable is unset).
        """
        # One Elasticsearch store per embedding field, all on the same index.
        self.doc_store_ada = self._elasticsearch_store(
            doc_index, label_index, "ada_embedding", 1536
        )
        self.doc_store_mpnet = self._elasticsearch_store(
            doc_index, label_index, "mpnet_embedding", 768
        )
        # NOTE(review): distiluse-base-multilingual-cased-v2 is usually paired
        # with cosine similarity; the shared "dot_product" setting may skew
        # ranking for unnormalized vectors — confirm.
        self.doc_store_distilbert = self._elasticsearch_store(
            doc_index, label_index, "distilbert_embedding", 512
        )

        # Llama embeddings (4096-dim) live in Weaviate. The endpoint can be
        # overridden through the environment, like ELASTIC_HOST for
        # Elasticsearch; the defaults are the previous hard-coded values.
        self.vector_doc_store_llama = WeaviateDocumentStore(
            host=os.environ.get("WEAVIATE_HOST", "http://localhost"),
            port=int(os.environ.get("WEAVIATE_PORT", 3434)),
            embedding_dim=4096,
        )

        self.emb_retriever_ada = EmbeddingRetriever(
            document_store=self.doc_store_ada,
            batch_size=8,
            embedding_model="text-embedding-ada-002",
            # An empty key can never authenticate, so fall back to the
            # conventional environment variable (populated by load_dotenv).
            api_key=api_key or os.environ.get("OPENAI_API_KEY", ""),
            # NOTE(review): 1536 equals Ada's embedding dimension rather than
            # its input-length limit; confirm this truncation length is
            # intended.
            max_seq_len=1536,
        )

        # Local sentence-transformers models, run on CPU.
        self.emb_retriever_mpnet = EmbeddingRetriever(
            document_store=self.doc_store_mpnet,
            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
            model_format="sentence_transformers",
            use_gpu=False,
        )
        self.retriever_distilbert = EmbeddingRetriever(
            document_store=self.doc_store_distilbert,
            embedding_model="sentence-transformers/distiluse-base-multilingual-cased-v2",
            model_format="sentence_transformers",
            use_gpu=False,
        )

        self.llama_retriever = LlamaRetriever(
            document_store=self.vector_doc_store_llama
        )

    @staticmethod
    def _elasticsearch_store(doc_index, label_index, embedding_field, embedding_dim):
        """Return an ElasticsearchDocumentStore bound to one embedding field.

        Every store uses the module-level connection settings (es_host, PORT),
        the German analyzer, dot-product similarity and the shared
        custom_mapping. Only the vector field and its dimension vary.
        """
        return ElasticsearchDocumentStore(
            host=es_host,
            port=PORT,
            analyzer="german",
            index=doc_index,
            label_index=label_index,
            embedding_dim=embedding_dim,
            similarity="dot_product",
            embedding_field=embedding_field,
            custom_mapping=custom_mapping,
        )
|