BA-Chatbot/backend/module_recommendation.py

114 lines
4.5 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
from typing import List
from haystack.schema import Document
from reranker import ReRanker
from reader import Reader
from retriever.retriever import Retriever
from haystack.nodes import FARMReader
class WPMRecommendation:
    """Recommend Wahlpflichtmodule (WPM) by chaining retrieval, filtering,
    optional reranking and optional answer generation.

    The class owns no retrieval or generation logic itself; it only wires
    together the collaborators handed to the constructor.
    """

    def __init__(
        self,
        retriever: Retriever,
        reader: Reader,
        reRanker: ReRanker,
        farm_reader: FARMReader,
    ) -> None:
        """
        Store the collaborators used by the recommendation pipeline.

        Args:
            retriever (Retriever): Used via ``get_top_k_passages`` to fetch candidate documents.
            reader (Reader): Used via ``get_gpt_wpm_recommendation`` to generate an answer.
            reRanker (ReRanker): Used via ``rerank_documents_with_gpt35`` and
                ``get_final_references`` to rerank the filtered candidates.
            farm_reader (FARMReader): Stored on the instance; not used by any
                method visible in this class yet.
        """
        self.retriever = retriever
        self.reader = reader
        self.reranker = reRanker
        self.farm_reader = farm_reader

    def _filter_wpms(self, documents: List[Document]) -> List[Document]:
        """
        Keep only the documents flagged as Wahlpflichtmodule (WPM).

        Args:
            documents (List[Document]): Documents to filter.

        Returns:
            List[Document]: The documents whose ``meta["is_wpm"]`` is exactly ``True``,
            in their original order.
        """
        # Identity check on purpose: truthy non-bool values (e.g. "yes", 1)
        # and a missing key do not count as a WPM flag.
        return [doc for doc in documents if doc.meta.get("is_wpm") is True]

    def _build_query_for_prompt(
        self, interets: str, future_carrer: str, previous_courses: str
    ) -> str:
        """
        Build the German query text from the user's profile.

        Each non-empty argument contributes one sentence; empty arguments are
        skipped entirely.

        Args:
            interets (str): User's interests.
            future_carrer (str): User's future career aspirations.
            previous_courses (str): Wahlpflichtmodule the user has already taken.

        Returns:
            str: The concatenated sentences, or ``""`` if all arguments are empty.
        """
        parts = []
        if interets:
            parts.append(f"Ich habe folgende Interessen: \n{interets}.\n")
        if future_carrer:
            parts.append(
                f"Zudem möchte ich zukünftig im folgenden Bereich arbeiten:\n{future_carrer}.\n"
            )
        if previous_courses:
            parts.append(
                f"Ich habe bereits schon folgenden Wahlpflichtmodule belegt:\n{previous_courses}.\n"
            )
        return "".join(parts)

    def recommend_wpms(
        self,
        interets: str,
        future_carrer: str,
        previous_courses: str,
        retrieval_model_or_method="mpnet",
        recommendation_method: str = "get_retrieved_results",
        rerank_retrieved_results=True,
    ):
        """
        Recommend Wahlpflichtmodule (WPM) for a user profile.

        Args:
            interets (str): User's interests. This is the only text used as the
                retrieval query.
            future_carrer (str): User's future career aspirations (used in the
                rerank/generation query only).
            previous_courses (str): Previously taken courses (used in the
                rerank/generation query only).
            retrieval_model_or_method (str, optional): Passed to the retriever as
                ``method``. Defaults to "mpnet".
            recommendation_method (str, optional): "generate_llm_answer" returns the
                reader's generated recommendation; any other value returns the
                reference documents. Defaults to "get_retrieved_results".
            rerank_retrieved_results (bool, optional): Whether to rerank the filtered
                documents. Defaults to True.

        Returns:
            Varies: The result of ``reader.get_gpt_wpm_recommendation`` when
            ``recommendation_method == "generate_llm_answer"``; otherwise the
            reference documents (the result of ``reranker.get_final_references``
            if reranking ran, else the filtered retrieval results).
        """
        top_k_docs = self.retriever.get_top_k_passages(
            query=interets, index="ib", method=retrieval_model_or_method
        )["documents"]
        retrieved_wpms = self._filter_wpms(top_k_docs)
        final_references = retrieved_wpms

        query = self._build_query_for_prompt(
            interets=interets,
            future_carrer=future_carrer,
            previous_courses=previous_courses,
        )

        # Nothing to reorder when the filter left no WPM documents, so skip
        # the reranker call instead of invoking it on an empty list.
        if rerank_retrieved_results and retrieved_wpms:
            reranked_top_k = self.reranker.rerank_documents_with_gpt35(
                documents=retrieved_wpms, query=query
            )
            final_references = self.reranker.get_final_references(
                reranked_documents=reranked_top_k, retrieved_documents=retrieved_wpms
            )

        if recommendation_method == "generate_llm_answer":
            return self.reader.get_gpt_wpm_recommendation(
                query=query, top_k_wpms=final_references
            )

        # "generate_farm_reader_answer" has no implementation yet; like every
        # other method name it falls back to returning the references.
        return final_references