forked from 1827133/BA-Chatbot
101 lines
3.9 KiB
Python
101 lines
3.9 KiB
Python
|
from typing import Dict, List
|
||
|
from haystack.schema import Document
|
||
|
from api.embeddingsServiceCaller import EmbeddingServiceCaller
|
||
|
from helper.openai import (
|
||
|
openai_doc_reference_prompt_v1,
|
||
|
openai_doc_citation_prompt_v2,
|
||
|
MAX_GPT4_TOKENS,
|
||
|
GPT4_COMPLETION_TOKENS,
|
||
|
MAX_GPT35_TURBO_TOKENS,
|
||
|
RERANKING_TOKENS,
|
||
|
count_prompt_tokens_gpt4,
|
||
|
count_prompt_tokens_gpt35,
|
||
|
)
|
||
|
import json
|
||
|
import ast
|
||
|
|
||
|
|
||
|
class ReRanker:
    """Reorders retrieved documents by relevance using a GPT model behind the model service."""

    def __init__(self) -> None:
        """
        Initializes the ReRanker class with a caller for the MODEL SERVICE.
        """
        self.caller = EmbeddingServiceCaller()

    def rerank_documents_with_gpt35(
        self, documents: List["Document"], query: str
    ) -> List["Document"]:
        """
        Reranks a list of documents using GPT-3.5 based on a given query.

        Documents are added to the prompt in their given order for as long as the
        running token count (system prompt included) stays below
        ``MAX_GPT35_TURBO_TOKENS - RERANKING_TOKENS``; the remaining documents are
        not sent to the model and therefore cannot be returned.

        Args:
            documents (List[Document]): A list of Document objects to be reranked.
            query (str): The query string used for reranking.

        Returns:
            List[Document]: The documents whose ids the model returned, in the order
            the model ranked them. Ids that do not match a document sent in the
            prompt, and repeated ids, are ignored. An empty list is returned when no
            document fits into the prompt, when the service response does not have
            the expected ``choices[0].message.content`` shape, or when that content
            is not a string representation of a list.
        """
        documents_in_prompt = self._select_documents_within_budget(documents)
        if not documents_in_prompt:
            # Nothing fits into the prompt, so there is nothing to rank; skip the call.
            return []

        formatted_documents = [
            {"content": doc.content, "id": doc.id} for doc in documents_in_prompt
        ]
        payload = json.dumps(
            {
                "system_prompt": openai_doc_citation_prompt_v2,
                "documents": formatted_documents,
                "query": query,
            }
        )
        sorted_document_ids = self.caller.rerank_documents_gpt(payload=payload)
        print(sorted_document_ids, "sorted_document_ids")

        try:
            message_content = sorted_document_ids["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Unexpected response shape: treat it like an unparsable answer so the
            # caller falls back to the retriever order.
            return []

        return self._order_documents_by_ids(documents_in_prompt, message_content)

    @staticmethod
    def _select_documents_within_budget(
        documents: List["Document"],
    ) -> List["Document"]:
        """Return the leading documents that fit into the GPT-3.5 reranking token budget."""
        budget = MAX_GPT35_TURBO_TOKENS - RERANKING_TOKENS
        token_count = count_prompt_tokens_gpt35(openai_doc_citation_prompt_v2)
        selected = []
        for doc in documents:
            token_count += count_prompt_tokens_gpt35(doc.content)
            if token_count >= budget:
                # The running total only grows, so no later document can fit either.
                break
            selected.append(doc)
        return selected

    @staticmethod
    def _order_documents_by_ids(
        candidates: List["Document"], message_content: str
    ) -> List["Document"]:
        """Parse the model reply as a list of ids and return matching candidates in that order."""
        # The reply must be a Python/JSON-style list literal such as "['id1', 'id2']";
        # anything else (prose, a tuple, None, malformed text) yields no ranking.
        try:
            ranked_ids = ast.literal_eval(message_content)
        except (SyntaxError, ValueError, TypeError):
            return []
        if not isinstance(ranked_ids, list):
            return []

        docs_by_id = {}
        for doc in candidates:
            # Keep the first document for an id, mirroring retrieval order.
            docs_by_id.setdefault(doc.id, doc)

        ordered = []
        seen_ids = set()
        for doc_id in ranked_ids:
            try:
                doc = docs_by_id.get(doc_id)
            except TypeError:
                # Unhashable entry (e.g. a nested list) cannot be a document id.
                continue
            if doc is None or doc.id in seen_ids:
                # Unknown id (possibly hallucinated) or an id the model repeated.
                continue
            seen_ids.add(doc.id)
            ordered.append(doc)
        return ordered

    def get_final_references(
        self, reranked_documents: List["Document"], retrieved_documents: List["Document"]
    ) -> List["Document"]:
        """
        Combines reranked and retrieved documents, ensuring no duplicates and maintaining order.

        The reranked documents come first in their reranked order (an id appearing
        more than once is kept only at its first position). Every retrieved document
        whose id the reranker did not return is appended afterwards, in retrieval
        order. If nothing was reranked, the retrieved documents are returned as is.

        Args:
            reranked_documents (List[Document]): The documents after reranking.
            retrieved_documents (List[Document]): The original set of retrieved documents.

        Returns:
            List[Document]: A combined list of reranked and retrieved documents.
        """
        if not reranked_documents:
            return retrieved_documents

        final_references = []
        reranked_ids = set()
        for doc in reranked_documents:
            if doc.id not in reranked_ids:
                reranked_ids.add(doc.id)
                final_references.append(doc)

        # The model may not return every document id; keep the ones it skipped
        # after the reranked ones so no retrieved reference is lost.
        final_references.extend(
            doc for doc in retrieved_documents if doc.id not in reranked_ids
        )
        return final_references
|