forked from 1827133/BA-Chatbot
254 lines
9.4 KiB
Python
254 lines
9.4 KiB
Python
# pylint: disable=ungrouped-imports
|
|
|
|
|
|
"""
|
|
---------------------------------------------------------------------------
|
|
NOTE:
|
|
Custom implementation of a retriever based on the LLaMA model that is compatible with the Haystack retriever pipeline.
|
|
Under the hood, it calls the MODEL SERVICE.
|
|
|
|
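NOTE: See the functions embed_queries and embed_documents for how embeddings are obtained. Pooling strategy and layer extraction are not applied locally (presumably the model service handles them).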
|
|
---------------------------------------------------------------------------
|
|
"""
|
|
from typing import List, Dict, Union, Optional, Any, Literal, Callable
|
|
import logging
|
|
from pathlib import Path
|
|
from copy import deepcopy
|
|
from requests.exceptions import HTTPError
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
import pandas as pd
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from haystack.errors import HaystackError
|
|
from haystack.schema import Document, FilterType
|
|
from haystack.document_stores import BaseDocumentStore
|
|
from haystack.telemetry import send_event
|
|
from haystack.lazy_imports import LazyImport
|
|
from haystack.nodes.retriever import DenseRetriever
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import:
|
|
import torch
|
|
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
|
|
from transformers import AutoConfig
|
|
|
|
import sys
|
|
sys.path.append("../..")
|
|
from api.embeddingsServiceCaller import EmbeddingServiceCaller
|
|
|
|
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
|
|
"llama": {}
|
|
}
|
|
|
|
class LlamaRetriever(DenseRetriever):
    """Haystack dense retriever whose embeddings come from an external embedding service.

    Every text (query or document) is sent to ``EmbeddingServiceCaller.get_embeddings``
    one at a time. The resulting vectors are passed to the document store's
    ``query_by_embedding`` / ``query_by_embedding_batch`` methods.

    Several constructor arguments (``model_format``, ``model_version``, ``max_seq_len``,
    ``pooling_strategy``, ``emb_extraction_layer``, ``progress_bar``, ``use_auth_token``
    and the OpenAI/Azure settings) are only stored as attributes. Nothing in this class
    reads them; they are kept so existing call sites continue to work.
    """

    def __init__(
        self,
        model_format: str = "llama",
        document_store: Optional[BaseDocumentStore] = None,
        model_version: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 32,
        max_seq_len: int = 512,
        pooling_strategy: str = "reduce_mean",
        emb_extraction_layer: int = -1,
        top_k: int = 10,
        progress_bar: bool = True,
        devices: Optional[List[Union[str, "torch.device"]]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        scale_score: bool = True,
        embed_meta_fields: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        azure_api_version: str = "2022-12-01",
        azure_base_url: Optional[str] = None,
        azure_deployment_name: Optional[str] = None,
        api_base: str = "https://api.openai.com/v1",
        openai_organization: Optional[str] = None,
    ):
        """Store the retriever configuration and create the embedding-service client.

        :param document_store: Default store used by ``retrieve``/``retrieve_batch`` when
            none is passed to them.
        :param use_gpu: Forwarded to ``initialize_device_settings`` as ``use_cuda``.
        :param batch_size: Stored as ``self.batch_size``. Also compared against the number
            of resolved devices to emit a warning.
        :param top_k: Default number of documents returned per query.
        :param devices: Forwarded to ``initialize_device_settings``.
        :param scale_score: Default for the document store's ``scale_score`` flag.
        :param embed_meta_fields: Names of ``doc.meta`` fields whose values are prepended
            to a document's content before it is embedded.

        The remaining arguments are stored as attributes for compatibility only.
        """
        # Fails with an install hint if torch/transformers are missing.
        torch_and_transformers_import.check()

        # Avoid a shared mutable default argument.
        if embed_meta_fields is None:
            embed_meta_fields = []
        super().__init__()

        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True)

        if batch_size < len(self.devices):
            logger.warning("Batch size is less than the number of devices. All gpus will not be utilized.")

        self.document_store = document_store
        self.model_version = model_version
        self.use_gpu = use_gpu
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.pooling_strategy = pooling_strategy
        self.emb_extraction_layer = emb_extraction_layer
        self.top_k = top_k
        self.progress_bar = progress_bar
        self.use_auth_token = use_auth_token
        self.scale_score = scale_score
        self.api_key = api_key
        self.api_base = api_base
        self.api_version = azure_api_version
        self.azure_base_url = azure_base_url
        self.azure_deployment_name = azure_deployment_name
        self.openai_organization = openai_organization
        self.model_format = model_format
        # Client for the external embedding service; every embedding goes through it.
        self.emb_caller = EmbeddingServiceCaller()
        self.embed_meta_fields = embed_meta_fields

    def retrieve(
        self,
        query: str,
        filters: Optional[FilterType] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[Document]:
        """Return the documents most similar to a single query.

        Arguments left as ``None`` fall back to the instance defaults
        (``top_k``, ``scale_score``) or to the document store's ``index``.

        :raises ValueError: if no document store was given here or at construction.
        """
        document_store = document_store or self.document_store
        if document_store is None:
            raise ValueError(
                "This Retriever was not initialized with a Document Store. Provide one to the retrieve() method."
            )
        if top_k is None:
            top_k = self.top_k
        if index is None:
            index = document_store.index
        if scale_score is None:
            scale_score = self.scale_score

        # For a single query, embed_queries returns that query's vector directly.
        query_emb = self.embed_queries(queries=[query])
        documents = document_store.query_by_embedding(
            query_emb=query_emb, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
        )
        return documents

    def retrieve_batch(
        self,
        queries: List[str],
        filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        batch_size: Optional[int] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[List[Document]]:
        """Return, for each query, the documents most similar to it.

        ``batch_size`` is accepted for interface compatibility. The embedding service is
        called once per query regardless of its value.

        :return: One list of documents per query, in query order. Empty when ``queries``
            is empty.
        :raises ValueError: if no document store was given here or at construction.
        """
        document_store = document_store or self.document_store
        if document_store is None:
            raise ValueError(
                "This Retriever was not initialized with a Document Store. Provide one to the retrieve_batch() method."
            )
        if top_k is None:
            top_k = self.top_k
        if index is None:
            index = document_store.index
        if scale_score is None:
            scale_score = self.scale_score

        if not queries:
            return []

        # One embedding vector per query. Previously only the first query was embedded,
        # and its vector was split into scalars.
        query_embs: List[np.ndarray] = self._embed_texts(queries)

        documents = document_store.query_by_embedding_batch(
            query_embs=query_embs,
            top_k=top_k,
            filters=filters,
            index=index,
            headers=headers,
            scale_score=scale_score,
        )
        return documents

    def _embed_texts(self, texts: List[str]) -> List[np.ndarray]:
        """Embed each text with one embedding-service call per text.

        Each element is ``np.array(get_embeddings(text))``. This is presumably a 1-D
        vector per text (TODO confirm against EmbeddingServiceCaller).
        """
        return [np.array(self.emb_caller.get_embeddings(text)) for text in texts]

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Embed one or more queries via the embedding service.

        A bare string is treated as a single query.

        :return: For a single query, that query's embedding as returned by the service
            (this shape is what ``retrieve`` passes to the document store). For several
            queries, an array with one row per query. For an empty list, an empty array.
        """
        if isinstance(queries, str):
            queries = [queries]
        assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

        embeddings = self._embed_texts(queries)
        if len(embeddings) == 1:
            # Backward compatible: a single query yields its own embedding, not a 1-row matrix.
            return embeddings[0]
        # Also covers the empty case: np.array([]) mirrors embed_documents([]).
        return np.array(embeddings)

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Embed documents after linearizing tables and prepending ``embed_meta_fields``.

        :return: ``np.array`` of the per-document embeddings, in input order.
        """
        documents = self._preprocess_documents(documents)
        return np.array(self._embed_texts([doc.content for doc in documents]))

    def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
        """Return copies of ``docs`` with text ready for embedding.

        Table documents have their DataFrame content serialized to CSV (without the index).
        The truthy values of each field in ``self.embed_meta_fields`` are stringified and
        joined with newlines before the content; list values contribute each item.
        The input documents are not modified (each is deep-copied).

        :raises HaystackError: if a ``table`` document's content is not a ``pd.DataFrame``.
        """
        linearized_docs = []
        for doc in docs:
            doc = deepcopy(doc)
            if doc.content_type == "table":
                if isinstance(doc.content, pd.DataFrame):
                    doc.content = doc.content.to_csv(index=False)
                else:
                    raise HaystackError("Documents of type 'table' need to have a pd.DataFrame as content field")
            meta_data_fields = []
            for key in self.embed_meta_fields:
                if key in doc.meta and doc.meta[key]:
                    if isinstance(doc.meta[key], list):
                        meta_data_fields.extend(doc.meta[key])
                    else:
                        meta_data_fields.append(doc.meta[key])
            meta_data_fields = [str(field) for field in meta_data_fields]
            doc.content = "\n".join(meta_data_fields + [doc.content])
            linearized_docs.append(doc)
        return linearized_docs

    @staticmethod
    def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
        """Guess an embedding-model format from a model name or local path.

        Checks, in order: OpenAI model names, Cohere model names, then a
        ``config_sentence_transformers.json`` on disk (local path) or on the Hugging Face
        Hub. It falls back to the model's ``AutoConfig`` type.
        Not called anywhere in this class.

        :return: One of ``"openai"``, ``"cohere"``, ``"sentence_transformers"``,
            ``"retribert"`` or ``"farm"``.
        """
        valid_openai_model_name = model_name_or_path in ["ada", "babbage", "davinci", "curie"] or any(
            m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"]
        )
        if valid_openai_model_name:
            return "openai"
        if model_name_or_path in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]:
            return "cohere"
        if Path(model_name_or_path).exists():
            if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
                return "sentence_transformers"
        else:
            try:
                hf_hub_download(
                    repo_id=model_name_or_path,
                    filename="config_sentence_transformers.json",
                    use_auth_token=use_auth_token,
                )
                return "sentence_transformers"
            except HTTPError:
                # No sentence-transformers config on the Hub: fall through to AutoConfig.
                pass

        config = AutoConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
        if config.model_type == "retribert":
            return "retribert"

        return "farm"

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: Optional[int] = None,
        batch_size: int = 16,
        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
        num_workers: int = 0,
        use_amp: bool = False,
        **kwargs,
    ) -> None:
        """No-op: embeddings come from the external service, so nothing is trained locally.

        Kept so callers of the retriever interface do not fail. All arguments are ignored,
        and a warning is logged so the no-op is visible instead of silent.
        """
        logger.warning("LlamaRetriever.train() is not supported; the call is ignored.")
|
|
|