114 lines
4.5 KiB
Python
114 lines
4.5 KiB
Python
from typing import List
|
|
from haystack.schema import Document
|
|
from reranker import ReRanker
|
|
from reader import Reader
|
|
from retriever.retriever import Retriever
|
|
from haystack.nodes import FARMReader
|
|
|
|
|
|
class WPMRecommendation:
    """Recommend Wahlpflichtmodule (WPM, elective modules) to a student.

    Pipeline: retrieve passages with the retriever, keep only documents
    flagged as WPM, optionally rerank them with the reranker, and then either
    return the resulting references or let the reader generate an answer.
    """

    # Supported values for ``recommend_wpms(recommendation_method=...)``.
    RETRIEVED_RESULTS = "get_retrieved_results"
    LLM_ANSWER = "generate_llm_answer"
    FARM_READER_ANSWER = "generate_farm_reader_answer"

    def __init__(
        self,
        retriever: Retriever,
        reader: Reader,
        reRanker: ReRanker,
        farm_reader: FARMReader,
    ) -> None:
        """
        Initializes the WPMRecommendation class with the components for
        retrieving, reranking, and reading documents.

        Args:
            retriever (Retriever): Fetches candidate passages via
                ``get_top_k_passages``.
            reader (Reader): Produces the LLM-generated recommendation in
                ``"generate_llm_answer"`` mode.
            reRanker (ReRanker): Reranks the retrieved WPM documents.
            farm_reader (FARMReader): Stored for the
                ``"generate_farm_reader_answer"`` mode, which is not
                implemented yet; currently unused.
        """
        self.retriever = retriever
        self.reader = reader
        self.reranker = reRanker
        self.farm_reader = farm_reader

    def _filter_wpms(self, documents: List[Document]) -> List[Document]:
        """
        Filters documents to include only those marked as Wahlpflichtmodule (WPM).

        A document is kept only if ``doc.meta.get("is_wpm")`` is the boolean
        ``True``; truthy non-bool values such as ``"true"`` or ``1`` are rejected.

        Args:
            documents (List[Document]): A list of documents to be filtered.

        Returns:
            List[Document]: The WPM documents, in their original order.
        """
        return [doc for doc in documents if doc.meta.get("is_wpm") is True]

    def _build_query_for_prompt(
        self, interets: str, future_carrer: str, previous_courses: str
    ) -> str:
        """
        Constructs a German first-person query from the user's interests,
        future career plans, and previously taken courses.

        Falsy inputs (e.g. ``""`` or ``None``) are skipped, so the result is
        ``""`` when all three are empty.

        Args:
            interets (str): User's interests.
            future_carrer (str): User's future career aspirations.
            previous_courses (str): Previously taken courses by the user.

        Returns:
            str: The concatenated query sections, each ending in ``".\\n"``.
        """
        # Parameter names keep their original spelling so keyword callers still work.
        sections = []
        if interets:
            sections.append(f"Ich habe folgende Interessen: \n{interets}.\n")
        if future_carrer:
            sections.append(
                f"Zudem möchte ich zukünftig im folgenden Bereich arbeiten:\n{future_carrer}.\n"
            )
        if previous_courses:
            sections.append(
                f"Ich habe bereits folgende Wahlpflichtmodule belegt:\n{previous_courses}.\n"
            )
        return "".join(sections)

    def recommend_wpms(
        self,
        interets: str,
        future_carrer: str,
        previous_courses: str,
        retrieval_model_or_method: str = "mpnet",
        recommendation_method: str = "get_retrieved_results",
        rerank_retrieved_results: bool = True,
        index: str = "ib",
    ):
        """
        Recommends Wahlpflichtmodule (WPM) based on the user's interests,
        future career plans, and previous courses.

        Retrieval is driven by ``interets`` only; ``future_carrer`` and
        ``previous_courses`` enter through the prompt query used for
        reranking and LLM generation.

        Args:
            interets (str): User's interests (also used as the retrieval query).
            future_carrer (str): User's future career aspirations.
            previous_courses (str): Previously taken courses by the user.
            retrieval_model_or_method (str, optional): The retrieval model or
                method passed to the retriever. Defaults to "mpnet".
            recommendation_method (str, optional): One of
                "get_retrieved_results" (default), "generate_llm_answer", or
                "generate_farm_reader_answer" (not implemented yet; falls
                back to the default with a warning). Unrecognized values also
                fall back to the default with a warning.
            rerank_retrieved_results (bool, optional): Whether to rerank the
                retrieved WPM documents. Defaults to True.
            index (str, optional): Index passed to the retriever. Defaults to "ib".

        Returns:
            For "generate_llm_answer": whatever
            ``reader.get_gpt_wpm_recommendation`` returns. Otherwise the final
            references: the filtered WPM documents, or the output of
            ``reranker.get_final_references`` when reranking was applied.
        """
        import warnings

        top_k_docs = self.retriever.get_top_k_passages(
            query=interets, index=index, method=retrieval_model_or_method
        )["documents"]
        retrieved_wpms = self._filter_wpms(top_k_docs)
        final_references = retrieved_wpms

        query = self._build_query_for_prompt(
            interets=interets,
            future_carrer=future_carrer,
            previous_courses=previous_courses,
        )

        # Skip the (GPT-3.5 based, per its method name) reranker when there is
        # nothing to rank; the references are then simply the empty list.
        if rerank_retrieved_results and retrieved_wpms:
            reranked_top_k = self.reranker.rerank_documents_with_gpt35(
                documents=retrieved_wpms, query=query
            )
            final_references = self.reranker.get_final_references(
                reranked_documents=reranked_top_k, retrieved_documents=retrieved_wpms
            )

        if recommendation_method == self.LLM_ANSWER:
            return self.reader.get_gpt_wpm_recommendation(
                query=query, top_k_wpms=final_references
            )

        if recommendation_method == self.FARM_READER_ANSWER:
            # TODO: implement using self.farm_reader. Until then this mode keeps
            # returning the references, but no longer does so silently.
            warnings.warn(
                "recommendation_method 'generate_farm_reader_answer' is not "
                "implemented; returning the retrieved references instead.",
                stacklevel=2,
            )
        elif recommendation_method != self.RETRIEVED_RESULTS:
            warnings.warn(
                f"Unknown recommendation_method {recommendation_method!r}; "
                "returning the retrieved references.",
                stacklevel=2,
            )

        return final_references
|