BA-Chatbot/backend/reranker.py

101 lines
3.9 KiB
Python
Raw Permalink Normal View History

2023-11-15 14:28:48 +01:00
import ast
import json
import logging
from typing import Dict, List

from haystack.schema import Document

from api.embeddingsServiceCaller import EmbeddingServiceCaller
from helper.openai import (
    openai_doc_reference_prompt_v1,
    openai_doc_citation_prompt_v2,
    MAX_GPT4_TOKENS,
    GPT4_COMPLETION_TOKENS,
    MAX_GPT35_TURBO_TOKENS,
    RERANKING_TOKENS,
    count_prompt_tokens_gpt4,
    count_prompt_tokens_gpt35,
)
class ReRanker:
    """Reorders documents by relevance to a query using GPT-3.5 through the model service."""

    # Diagnostic output goes through logging (instead of print) so the raw
    # model response is only emitted when debug logging is enabled.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        """
        Initializes the ReRanker class with a caller for the MODEL SERVICE.
        """
        self.caller = EmbeddingServiceCaller()

    def rerank_documents_with_gpt35(
        self, documents: List[Document], query: str
    ) -> List[Document]:
        """
        Reranks a list of documents using GPT-3.5 based on a given query.

        Documents are sent to the model in input order until the token budget
        (MAX_GPT35_TURBO_TOKENS - RERANKING_TOKENS) would be reached; the rest
        are not sent. The model is expected to answer with a list literal of
        document ids, most relevant first.

        Args:
            documents (List[Document]): A list of Document objects to be reranked.
            query (str): The query string used for reranking.

        Returns:
            List[Document]: The documents whose ids the model returned, in the
            model's ranking order, each id at most once. Ids that match no input
            document are ignored. Empty when nothing could be sent or the model's
            answer is not a list of ids.
        """
        formatted_documents = self._select_documents_within_budget(documents, query)
        if not formatted_documents:
            # No input, or not even the first document fits the budget: there is
            # nothing for the model to rank, so skip the service call.
            return []

        payload = json.dumps(
            {
                "system_prompt": openai_doc_citation_prompt_v2,
                "documents": formatted_documents,
                "query": query,
            }
        )
        response = self.caller.rerank_documents_gpt(payload=payload)
        self._logger.debug("Re-ranking response: %s", response)

        try:
            message_content = response["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Error payload or otherwise unexpected response shape. An empty
            # result makes get_final_references fall back to the retrieved docs.
            self._logger.warning("Unexpected re-ranking response: %s", response)
            return []

        ranked_ids = self._parse_id_list(message_content)
        return self._order_by_ids(documents, ranked_ids)

    @staticmethod
    def _select_documents_within_budget(
        documents: List[Document], query: str
    ) -> list:
        """Build the {"content", "id"} payload entries for the leading documents that fit the prompt budget."""
        budget = MAX_GPT35_TURBO_TOKENS - RERANKING_TOKENS
        # The system prompt and the query are part of every request, so they
        # consume budget too.
        # NOTE(review): document ids and JSON framing are not counted; the
        # service may add a few tokens per document — confirm the margin.
        used_tokens = count_prompt_tokens_gpt35(
            openai_doc_citation_prompt_v2
        ) + count_prompt_tokens_gpt35(query)
        selected = []
        for doc in documents:
            used_tokens += count_prompt_tokens_gpt35(doc.content)
            if used_tokens >= budget:
                # Token counts are non-negative, so the running total never
                # shrinks: no later document can fit either.
                break
            selected.append({"content": doc.content, "id": doc.id})
        return selected

    @staticmethod
    def _parse_id_list(message_content: object) -> list:
        """Parse the model's answer into a list of ids; return [] if it is not a list literal."""
        if not isinstance(message_content, str):
            return []
        candidates = [message_content]
        start, end = message_content.find("["), message_content.rfind("]")
        if 0 <= start < end:
            # Fallback for answers that wrap the list in prose or a markdown
            # code fence; only tried when the full text does not parse.
            candidates.append(message_content[start : end + 1])
        for candidate in candidates:
            try:
                parsed = ast.literal_eval(candidate.strip())
            except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
                continue
            # A valid literal that is not a list (dict, tuple, ...) is rejected.
            return parsed if isinstance(parsed, list) else []
        return []

    @staticmethod
    def _order_by_ids(documents: List[Document], ranked_ids: list) -> List[Document]:
        """Return the documents matching ranked_ids, in ranked_ids order, each id at most once."""
        docs_by_id = {}
        for doc in documents:
            # Keep the first document for an id if several share it.
            docs_by_id.setdefault(doc.id, doc)
        ordered = []
        seen_ids = set()
        for doc_id in ranked_ids:
            try:
                doc = docs_by_id.get(doc_id)
            except TypeError:
                # Unhashable entry (e.g. a nested list) cannot be a document id.
                continue
            if doc is None or doc_id in seen_ids:
                continue
            seen_ids.add(doc_id)
            ordered.append(doc)
        return ordered

    def get_final_references(
        self, reranked_documents: List[Document], retrieved_documents: List[Document]
    ) -> List[Document]:
        """
        Combines reranked and retrieved documents, ensuring no duplicates and maintaining order.

        The reranked documents come first, in reranked order. They are followed
        by every retrieved document the reranker did not return, in retrieval
        order. Each document id appears at most once (first occurrence wins).

        Args:
            reranked_documents (List[Document]): The documents after reranking.
            retrieved_documents (List[Document]): The original set of retrieved documents.

        Returns:
            List[Document]: A combined list of reranked and retrieved documents.
            If reranking produced nothing, ``retrieved_documents`` itself is returned.
        """
        if not reranked_documents:
            return retrieved_documents

        final_references = []
        seen_ids = set()
        # The model may omit or repeat ids, so fill the gaps from the retriever
        # and drop repeats instead of relying on list lengths.
        for source in (reranked_documents, retrieved_documents):
            for doc in source:
                if doc.id not in seen_ids:
                    seen_ids.add(doc.id)
                    final_references.append(doc)
        return final_references