# BA-Chatbot/backend/module_recommendation.py
#
# (Repository viewer metadata: 114 lines, 4.5 KiB, Python.)

from typing import List
from haystack.schema import Document
from reranker import ReRanker
from reader import Reader
from retriever.retriever import Retriever
from haystack.nodes import FARMReader
class WPMRecommendation:
    """
    Recommends Wahlpflichtmodule (WPM) by chaining document retrieval,
    an optional GPT-3.5 reranking step and an optional LLM-generated answer.
    """

    def __init__(
        self,
        retriever: Retriever,
        reader: Reader,
        reRanker: ReRanker,
        farm_reader: FARMReader,
    ) -> None:
        """
        Initializes the WPMRecommendation class with required components for
        retrieving, reranking, and reading documents.

        Args:
            retriever (Retriever): An instance of Retriever for fetching relevant documents.
            reader (Reader): An instance of Reader for interpreting and processing documents.
            reRanker (ReRanker): An instance of ReRanker for reranking documents based on relevance.
            farm_reader (FARMReader): An instance of FARMReader for additional reading capabilities.
        """
        self.retriever = retriever
        self.reader = reader
        self.reranker = reRanker
        # Stored for a future "generate_farm_reader_answer" path; recommend_wpms
        # does not use it yet.
        self.farm_reader = farm_reader

    def _filter_wpms(self, documents: List[Document]) -> List[Document]:
        """
        Filters documents to include only those marked as Wahlpflichtmodule (WPM).

        Args:
            documents (List[Document]): A list of documents to be filtered.

        Returns:
            List[Document]: The documents whose ``meta["is_wpm"]`` is the boolean
            ``True``, in their original order.
        """
        # `is True` is deliberate: truthy-but-not-True metadata values (1, "true")
        # are not treated as WPM.
        return [doc for doc in documents if doc.meta.get("is_wpm") is True]

    def _build_query_for_prompt(
        self, interets: str, future_carrer: str, previous_courses: str
    ) -> str:
        """
        Constructs a German-language query from the user's interests, future
        career plans, and previously taken courses.

        Each part is only included when its input is non-empty, so an all-empty
        input yields an empty string.

        Args:
            interets (str): User's interests.
            future_carrer (str): User's future career aspirations.
            previous_courses (str): Previously taken courses by the user.

        Returns:
            str: The constructed query.
        """
        query = ""
        if interets:
            query += f"Ich habe folgende Interessen: \n{interets}.\n"
        if future_carrer:
            query += f"Zudem möchte ich zukünftig im folgenden Bereich arbeiten:\n{future_carrer}.\n"
        if previous_courses:
            query += f"Ich habe bereits folgende Wahlpflichtmodule belegt:\n{previous_courses}.\n"
        return query

    def recommend_wpms(
        self,
        interets: str,
        future_carrer: str,
        previous_courses: str,
        retrieval_model_or_method: str = "mpnet",
        recommendation_method: str = "get_retrieved_results",
        rerank_retrieved_results: bool = True,
    ):
        """
        Recommends Wahlpflichtmodule (WPM) based on the user's interests, future
        career plans, and previous courses.

        Pipeline: retrieve passages from the "ib" index, keep only WPM documents,
        optionally rerank them with GPT-3.5, then either return the references or
        let the reader generate an answer from them.

        Args:
            interets (str): User's interests. Used as the retrieval query; when
                empty, the full prompt built from all inputs is used instead.
            future_carrer (str): User's future career aspirations.
            previous_courses (str): Previously taken courses by the user.
            retrieval_model_or_method (str, optional): The retrieval model or
                method to use. Defaults to "mpnet".
            recommendation_method (str, optional): "generate_llm_answer" returns
                the reader's GPT recommendation; any other value (including the
                not-yet-implemented "generate_farm_reader_answer") returns the
                reference documents. Defaults to "get_retrieved_results".
            rerank_retrieved_results (bool, optional): Whether to rerank the
                retrieved WPMs. Defaults to True.

        Returns:
            The reader's result for "generate_llm_answer"; otherwise the list of
            reference documents (possibly empty).
        """
        query = self._build_query_for_prompt(
            interets=interets,
            future_carrer=future_carrer,
            previous_courses=previous_courses,
        )
        # Interests drive retrieval; without them, fall back to the full prompt so
        # career / course information still steers the search instead of "".
        retrieval_query = interets or query
        top_k_docs = self.retriever.get_top_k_passages(
            query=retrieval_query, index="ib", method=retrieval_model_or_method
        )["documents"]
        retrieved_wpms = self._filter_wpms(top_k_docs)
        final_references = retrieved_wpms

        # Only rerank when there is something to rerank; the reranker presumably
        # makes an external GPT-3.5 call, which is pointless for an empty list.
        if rerank_retrieved_results and retrieved_wpms:
            reranked_top_k = self.reranker.rerank_documents_with_gpt35(
                documents=retrieved_wpms, query=query
            )
            final_references = self.reranker.get_final_references(
                reranked_documents=reranked_top_k, retrieved_documents=retrieved_wpms
            )

        if recommendation_method == "generate_llm_answer":
            return self.reader.get_gpt_wpm_recommendation(
                query=query, top_k_wpms=final_references
            )
        # TODO: "generate_farm_reader_answer" is not implemented yet (self.farm_reader
        # is unused); like any other method value it falls through to the references.
        return final_references