BA-Chatbot/backend/retriever/LlamaRetriever.py

253 lines
9.4 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
# pylint: disable=ungrouped-imports
"""
---------------------------------------------------------------------------
NOTE:
Custom implementation of a Retriever based on the LLaMA model that is compatible with the Haystack Retriever pipeline.
It calls the MODEL SERVICE under the hood.
NOTE: SEE functions embed_queries and embed_documents for pooling strategy and layer extraction
---------------------------------------------------------------------------
"""
from typing import List, Dict, Union, Optional, Any, Literal, Callable
import logging
from pathlib import Path
from copy import deepcopy
from requests.exceptions import HTTPError
import numpy as np
from tqdm import tqdm
import pandas as pd
from huggingface_hub import hf_hub_download
from haystack.errors import HaystackError
from haystack.schema import Document, FilterType
from haystack.document_stores import BaseDocumentStore
from haystack.telemetry import send_event
from haystack.lazy_imports import LazyImport
from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import:
import torch
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
from transformers import AutoConfig
import sys
sys.path.append("../..")
from api.embeddingsServiceCaller import EmbeddingServiceCaller
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
"llama": {}
}
class LlamaRetriever(DenseRetriever):
    """
    Haystack-compatible dense retriever whose embeddings are fetched through ``EmbeddingServiceCaller``.

    The retriever does not load a model itself: every embedding is requested via
    ``self.emb_caller.get_embeddings(text)``.

    NOTE(review): ``get_embeddings`` is assumed to return the embedding of exactly one text per call
    (this is how ``embed_documents`` uses it) — confirm against the embedding service.
    """

    def __init__(
        self,
        model_format: str = "llama",
        document_store: Optional[BaseDocumentStore] = None,
        model_version: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 32,
        max_seq_len: int = 512,
        pooling_strategy: str = "reduce_mean",
        emb_extraction_layer: int = -1,
        top_k: int = 10,
        progress_bar: bool = True,
        devices: Optional[List[Union[str, "torch.device"]]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        scale_score: bool = True,
        embed_meta_fields: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        azure_api_version: str = "2022-12-01",
        azure_base_url: Optional[str] = None,
        azure_deployment_name: Optional[str] = None,
        api_base: str = "https://api.openai.com/v1",
        openai_organization: Optional[str] = None,
    ):
        """
        Store the configuration and create the embedding-service client.

        Only ``document_store``, ``top_k``, ``scale_score`` and ``embed_meta_fields`` are read by the
        methods of this class; ``batch_size`` is additionally used here for the device-count warning.
        All remaining arguments are kept as attributes but not read by any method of this class.

        :param model_format: Name of the embedding backend; stored as ``self.model_format``.
        :param document_store: Default document store queried by ``retrieve`` / ``retrieve_batch``.
        :param top_k: Default number of documents to return per query.
        :param scale_score: Default value forwarded to the document store as ``scale_score``.
        :param embed_meta_fields: Meta keys whose values are prepended to the document text before
                                  embedding. ``None`` is treated as an empty list.
        :param devices: Forwarded to ``initialize_device_settings`` together with ``use_gpu``.
        """
        torch_and_transformers_import.check()

        if embed_meta_fields is None:
            embed_meta_fields = []
        super().__init__()

        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True)
        if batch_size < len(self.devices):
            logger.warning("Batch size is less than the number of devices.All gpus will not be utilized.")

        self.document_store = document_store
        self.model_version = model_version
        self.use_gpu = use_gpu
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.pooling_strategy = pooling_strategy
        self.emb_extraction_layer = emb_extraction_layer
        self.top_k = top_k
        self.progress_bar = progress_bar
        self.use_auth_token = use_auth_token
        self.scale_score = scale_score
        self.api_key = api_key
        self.api_base = api_base
        self.api_version = azure_api_version
        self.azure_base_url = azure_base_url
        self.azure_deployment_name = azure_deployment_name
        self.openai_organization = openai_organization
        self.model_format = model_format
        self.emb_caller = EmbeddingServiceCaller()
        self.embed_meta_fields = embed_meta_fields

    def retrieve(
        self,
        query: str,
        filters: Optional[FilterType] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[Document]:
        """
        Embed ``query`` and return the documents the document store ranks closest to it.

        :param query: The query text.
        :param filters: Forwarded unchanged to ``query_by_embedding``.
        :param top_k: Number of documents to return; defaults to ``self.top_k``.
        :param index: Index to search; defaults to ``document_store.index``.
        :param headers: Forwarded unchanged to ``query_by_embedding``.
        :param scale_score: Forwarded to the document store; defaults to ``self.scale_score``.
        :param document_store: Overrides the store given at construction time.
        :raises ValueError: If neither this call nor the constructor supplied a document store.
        """
        document_store = document_store or self.document_store
        if document_store is None:
            raise ValueError(
                "This Retriever was not initialized with a Document Store. Provide one to the retrieve() method."
            )
        if top_k is None:
            top_k = self.top_k
        if index is None:
            index = document_store.index
        if scale_score is None:
            scale_score = self.scale_score

        # embed_queries() returns the embedding of the single query as-is, so it is passed on directly.
        query_emb = self.embed_queries(queries=[query])
        documents = document_store.query_by_embedding(
            query_emb=query_emb, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
        )
        return documents

    def retrieve_batch(
        self,
        queries: List[str],
        filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        batch_size: Optional[int] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[List[Document]]:
        """
        Embed every query and fetch the matching documents for all of them in one document-store call.

        :param queries: The query texts. A single string is treated as a one-element list.
        :param filters: Forwarded unchanged to ``query_by_embedding_batch``.
        :param top_k: Number of documents to return per query; defaults to ``self.top_k``.
        :param index: Index to search; defaults to ``document_store.index``.
        :param headers: Forwarded unchanged to ``query_by_embedding_batch``.
        :param batch_size: Accepted for interface compatibility; not used, because the embedding
                           service is called once per query.
        :param scale_score: Forwarded to the document store; defaults to ``self.scale_score``.
        :param document_store: Overrides the store given at construction time.
        :return: One list of documents per query; an empty list if ``queries`` is empty.
        :raises ValueError: If neither this call nor the constructor supplied a document store.
        """
        document_store = document_store or self.document_store
        if document_store is None:
            raise ValueError(
                "This Retriever was not initialized with a Document Store. Provide one to the retrieve_batch() method."
            )
        if top_k is None:
            top_k = self.top_k
        if index is None:
            index = document_store.index
        if scale_score is None:
            scale_score = self.scale_score

        if isinstance(queries, str):
            queries = [queries]
        if not queries:
            return []

        # embed_queries() only embeds its first query, so each query is embedded on its own.
        # Passing the whole list at once would yield a single embedding whose components were then
        # handed to the document store as if each of them were a separate query embedding.
        query_embs: List[np.ndarray] = [self.embed_queries(queries=[single_query]) for single_query in queries]

        documents = document_store.query_by_embedding_batch(
            query_embs=query_embs,
            top_k=top_k,
            filters=filters,
            index=index,
            headers=headers,
            scale_score=scale_score,
        )
        return documents

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """
        Return the embedding of the first query as delivered by the embedding service.

        Only ``queries[0]`` is embedded; a warning is logged when further queries are passed so that
        they are no longer dropped silently. The return value is kept exactly as before so existing
        callers keep working.

        NOTE(review): Haystack's retriever contract describes one embedding per input query. Changing
        the return shape accordingly requires checking every caller of this method first.

        :param queries: A list of query texts, or a single string.
        :return: ``np.array`` of whatever ``get_embeddings`` returns for the first query.
        :raises IndexError: If ``queries`` is an empty list.
        """
        if isinstance(queries, str):
            queries = [queries]
        assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
        if len(queries) > 1:
            logger.warning(
                "embed_queries() received %d queries but only embeds the first one; use retrieve_batch() "
                "or call embed_queries() once per query.",
                len(queries),
            )
        return np.array(self.emb_caller.get_embeddings(queries[0]))

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """
        Embed each document's (preprocessed) content with one embedding-service call per document.

        :param documents: Documents to embed; they are not modified (``_preprocess_documents`` copies them).
        :return: ``np.array`` built from the list of per-document embeddings, in input order.
        """
        documents = self._preprocess_documents(documents)
        embeddings = [self.emb_caller.get_embeddings(doc.content) for doc in documents]
        return np.array(embeddings)

    def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
        """
        Return copies of ``docs`` whose ``content`` is a single string ready for embedding.

        Table documents are serialised to CSV, and the values of all ``self.embed_meta_fields`` that are
        present and truthy in ``doc.meta`` are prepended to the content, one per line.

        :raises HaystackError: If a document of type ``"table"`` does not hold a ``pd.DataFrame``.
        """
        linearized_docs = []
        for doc in docs:
            # Work on a copy so the caller's documents keep their original content.
            doc = deepcopy(doc)
            if doc.content_type == "table":
                if isinstance(doc.content, pd.DataFrame):
                    doc.content = doc.content.to_csv(index=False)
                else:
                    raise HaystackError("Documents of type 'table' need to have a pd.DataFrame as content field")

            # Collect the configured meta values; list-valued fields contribute one entry per item.
            meta_data_fields = []
            for key in self.embed_meta_fields:
                if key in doc.meta and doc.meta[key]:
                    if isinstance(doc.meta[key], list):
                        meta_data_fields.extend(doc.meta[key])
                    else:
                        meta_data_fields.append(doc.meta[key])

            # Meta values may be non-strings (e.g. numbers), so convert before joining.
            meta_data_fields = [str(field) for field in meta_data_fields]
            doc.content = "\n".join(meta_data_fields + [doc.content])
            linearized_docs.append(doc)
        return linearized_docs

    @staticmethod
    def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
        """
        Guess the embedding backend from a model name or path.

        Not called by any method of this class.

        :return: One of ``"openai"``, ``"cohere"``, ``"sentence_transformers"``, ``"retribert"`` or ``"farm"``.
        """
        valid_openai_model_name = model_name_or_path in ["ada", "babbage", "davinci", "curie"] or any(
            m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"]
        )
        if valid_openai_model_name:
            return "openai"
        if model_name_or_path in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]:
            return "cohere"

        if Path(model_name_or_path).exists():
            # Local path: look for the sentence-transformers config file next to the model.
            if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
                return "sentence_transformers"
        else:
            # Not a local path: look for the same config file on the model hub.
            try:
                hf_hub_download(
                    repo_id=model_name_or_path,
                    filename="config_sentence_transformers.json",
                    use_auth_token=use_auth_token,
                )
                return "sentence_transformers"
            except HTTPError:
                # Config file not retrievable from the hub: fall through to the generic checks below.
                pass

        config = AutoConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
        if config.model_type == "retribert":
            return "retribert"
        return "farm"

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: Optional[int] = None,
        batch_size: int = 16,
        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
        num_workers: int = 0,
        use_amp: bool = False,
        **kwargs,
    ) -> None:
        """
        No-op: this retriever does not train anything; all arguments are ignored.

        A warning is logged so that a training call does not pass unnoticed.
        """
        logger.warning("LlamaRetriever.train() is a no-op; the retriever was not trained.")