forked from 1827133/BA-Chatbot
244 lines
9.1 KiB
Python
244 lines
9.1 KiB
Python
|
from typing import Dict, List
|
||
|
from api.embeddingsServiceCaller import EmbeddingServiceCaller
|
||
|
from retriever.retriever import Retriever
|
||
|
from reader import Reader
|
||
|
from embeddings.transformer_llama import LlamaTransformerEmbeddings
|
||
|
from retriever.retriever_pipeline import CustomPipeline
|
||
|
from embeddings.llama import Embedder
|
||
|
from haystack import Document
|
||
|
import json
|
||
|
import ast
|
||
|
import numpy as np
|
||
|
from scipy.special import softmax
|
||
|
from helper.openai import (
|
||
|
openai_doc_reference_prompt_v1,
|
||
|
openai_doc_citation_prompt_v2,
|
||
|
MAX_GPT4_TOKENS,
|
||
|
GPT4_COMPLETION_TOKENS,
|
||
|
MAX_GPT35_TURBO_TOKENS,
|
||
|
RERANKING_TOKENS,
|
||
|
count_prompt_tokens_gpt4,
|
||
|
count_prompt_tokens_gpt35,
|
||
|
)
|
||
|
from reranker import ReRanker
|
||
|
from expert_search import ExpertSearch
|
||
|
from module_recommendation import WPMRecommendation
|
||
|
from haystack.nodes import FARMReader
|
||
|
|
||
|
# Prompt delimiter markers (presumably the Llama-2 chat template tags --
# TODO confirm against the code that consumes them).
# Instruction span delimiters.
B_INST = "[INST]"
E_INST = "[/INST]"
# System-prompt span delimiters; the newlines are part of the markers.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
|
||
|
|
||
|
|
||
|
class QuestionAnswering:
    """
    Facade that wires together retrieval, reranking and reading components.

    The class owns one instance of each component and exposes thin entry
    points for expert search, elective-module (Wahlpflichtmodul, WPM)
    recommendation and question answering over the "stupo" and
    "crawled_hsma" indices.

    Attributes:
        qa_pipeline (CustomPipeline): A pipeline for document retrieval and processing.
        caller (LlamaTransformerEmbeddings | EmbeddingServiceCaller): Model service caller
            handed to the retriever and the reader.
        reranker (ReRanker): A component for reranking documents based on relevance.
        retriever (Retriever): A component for retrieving documents.
        reader (Reader): A component for reading and interpreting documents.
        bert_reader (FARMReader): An extractive FARM reader used for the "Bert" reader model.
        expert_search (ExpertSearch): A component for conducting expert searches.
        wpm_recommendation (WPMRecommendation): A component for recommending
            Wahlpflichtmodule (elective modules).
    """

    # Score threshold. Not referenced by any method visible in this class --
    # NOTE(review): confirm which external code reads it before changing it.
    THRESHOLD = 0.5

    def __init__(
        self,
        pipeline: CustomPipeline,
        embedder: LlamaTransformerEmbeddings | EmbeddingServiceCaller,
    ):
        """
        Initializes the QuestionAnswering class with required components.

        Args:
            pipeline (CustomPipeline): A pipeline for document retrieval and processing.
            embedder (LlamaTransformerEmbeddings | EmbeddingServiceCaller): Model service
                caller shared by the retriever and the reader.
        """
        self.qa_pipeline = pipeline
        self.caller = embedder
        self.reranker = ReRanker()
        self.retriever = Retriever(pipeline=self.qa_pipeline, caller=self.caller)
        self.reader = Reader(caller=self.caller)
        # NOTE: The BERT Reader is here and not in reader.py
        # TODO: Shift this to reader.py
        self.bert_reader = FARMReader(
            model_name_or_path="deepset/gelectra-base-germanquad-distilled",
            use_gpu=True,
            use_confidence_scores=False,
        )
        # Both sub-managers share the retriever/reader/reranker instances
        # created above rather than building their own.
        self.expert_search = ExpertSearch(
            pipeline=self.qa_pipeline,
            retriever=self.retriever,
            reader=self.reader,
            reRanker=self.reranker,
            farm_reader=self.bert_reader,
        )
        self.wpm_recommendation = WPMRecommendation(
            reader=self.reader,
            retriever=self.retriever,
            reRanker=self.reranker,
            farm_reader=self.bert_reader,
        )

    def search_experts(
        self,
        query: str,
        search_method: str,
        retriever_model: str,
        generate_answer: bool,
        rerank: bool,
    ):
        """
        Conducts an expert search based on the specified parameters.

        This is a pure delegation to ``ExpertSearch.search_experts``.

        Args:
            query (str): The search query.
            search_method (str): The method of search.
            retriever_model (str): The retrieval model to be used.
            generate_answer (bool): Whether to generate an answer using a reader.
            rerank (bool): Whether to rerank the retrieved documents.

        Returns:
            Varies: Whatever ``ExpertSearch.search_experts`` returns.
        """
        return self.expert_search.search_experts(
            query=query,
            rerank_documents=rerank,
            retrieval_method=retriever_model,
            # The keyword is misspelled in the callee's signature; keep it as is.
            generate_anwser=generate_answer,
            search_method=search_method,
        )

    def recommend_wpm(
        self,
        interets: str,
        future_carrer: str,
        previous_courses: str,
        retrieval_method_or_model: str,
        recommendation_method: str,
        rerank_retrieved_results: bool,
    ):
        """
        Provides recommendations for elective modules (Wahlpflichtmodule, WPM).

        This is a pure delegation to ``WPMRecommendation.recommend_wpms``. The
        misspelled parameter names are part of the public interface and are
        kept for caller compatibility.

        Args:
            interets (str): User's interests.
            future_carrer (str): User's future career aspirations.
            previous_courses (str): Previously taken courses.
            retrieval_method_or_model (str): The retrieval model/method.
            recommendation_method (str): The recommendation method.
            rerank_retrieved_results (bool): Whether to rerank retrieved results.

        Returns:
            Varies: Whatever ``WPMRecommendation.recommend_wpms`` returns.
        """
        return self.wpm_recommendation.recommend_wpms(
            interets=interets,
            future_carrer=future_carrer,
            previous_courses=previous_courses,
            recommendation_method=recommendation_method,
            rerank_retrieved_results=rerank_retrieved_results,
            retrieval_model_or_method=retrieval_method_or_model,
        )

    def get_top_k(self, query, index, meta, retrieval_method_or_model):
        """
        Retrieves the top k passages based on the query and retrieval method.

        Args:
            query (str): The search query.
            index (str): The index to search in.
            meta (Dict): Additional metadata for the query.
            retrieval_method_or_model (str): The retrieval method or model.

        Returns:
            The unmodified result of ``Retriever.get_top_k_passages``. Within
            this class that result is indexed with ``["documents"]``, so it is
            a mapping holding the retrieved documents under that key rather
            than a bare list.
        """
        return self.retriever.get_top_k_passages(
            index=index, query=query, meta=meta, method=retrieval_method_or_model
        )

    # Answers for STUPO and Crawled Data
    def get_answers(
        self,
        query: str,
        index: str = "",
        meta: Dict | None = None,
        retrieval_method_or_model: str = "mpnet",
        reader_model: str = "",
        rerank_documents=True,
    ):
        """
        Retrieves answers for a given query using various models and methods.

        NOTE: This only provides answers for stupo or crawled data questions.
        Expert Search and WPMs have their own functions.

        Args:
            query (str): The query to answer.
            index (str, optional): The index to search in. Only "stupo" and
                "crawled_hsma" are answered by a reader.
            meta (Dict, optional): Additional metadata. ``None`` (the default)
                is treated as an empty dict.
            retrieval_method_or_model (str, optional): Retrieval method/model.
            reader_model (str, optional): Reader model for generating answers;
                one of "GPT", "Bert" or "Llama".
            rerank_documents (bool, optional): Whether to rerank documents.

        Returns:
            Varies by ``reader_model``:
                * "GPT": the result of ``Reader.get_gpt_answer``.
                * "Bert": a tuple ``(prediction, final_passages)``.
                * "Llama": a tuple ``({"answers": [{"answer": ...}]}, final_passages)``.
                * anything else, or an unsupported index: the fallback dict
                  ``{"choices": [{"text": "Ich weiß die Antwort nicht"}]}``.
        """
        # A fresh dict per call: a literal ``{}`` default would be shared
        # between all calls and leak any mutation done downstream.
        meta = {} if meta is None else meta

        top_k_passages = self.retriever.get_top_k_passages(
            query=query, index=index, meta=meta, method=retrieval_method_or_model
        )["documents"]

        reranked_passages = None
        if rerank_documents:
            reranked_passages = self.reranker.rerank_documents_with_gpt35(
                documents=top_k_passages, query=query
            )
        # When reranking is disabled (or yields nothing) an empty list is
        # handed over together with the raw retrieval result.
        final_passages = self.reranker.get_final_references(
            reranked_documents=reranked_passages or [],
            retrieved_documents=top_k_passages,
        )

        if index in ("stupo", "crawled_hsma"):
            if reader_model == "GPT":
                return self.reader.get_gpt_answer(
                    top_k_passages=final_passages, query=query
                )
            if reader_model == "Bert":
                prediction = self.bert_reader.predict(
                    query=query,
                    documents=final_passages,
                    top_k=10,
                )
                return prediction, final_passages
            if reader_model == "Llama":
                llama_answer = self.reader.generate_llama_answer(
                    top_k_passages=final_passages, query=query
                )
                return {"answers": [{"answer": llama_answer}]}, final_passages

        # NOTE(review): the original indentation of the trailing ``else`` was
        # not recoverable, so the fallback now covers every unhandled path
        # (unsupported index as well as unknown reader model) instead of
        # letting one of them return None implicitly.
        return {"choices": [{"text": "Ich weiß die Antwort nicht"}]}

    def get_module_credits(self, module: str, index: str = "ib"):
        """
        Looks up the credits of a module via the retriever.

        Args:
            module (str): Module title used as the ``title`` filter value.
            index (str, optional): The index to query. Defaults to "ib".

        Returns:
            Varies: Whatever ``Retriever.get_module_credits`` returns.
        """
        return self.retriever.get_module_credits(
            query="", index=index, params={"title": [module]}
        )

    def apply_softmax(self, documents: Dict):
        """
        Applies softmax to the scores of the documents, in place.

        Each entry of ``documents["documents"]`` gets its ``score`` attribute
        replaced by its softmax-normalised value.

        Args:
            documents (Dict): Response in Haystack format; must hold the
                scored items under the "documents" key.

        Returns:
            numpy.ndarray: The normalised scores, in document order. These are
            the same values that were written to the documents. The array is
            empty when there are no documents.
        """
        scored_items = documents["documents"]
        if not scored_items:
            # softmax needs at least one value for its max-reduction.
            return np.array([])
        scores = softmax(np.array([item.score for item in scored_items]))
        for item, score in zip(scored_items, scores):
            item.score = score
        # Return the normalised scores themselves; normalising them a second
        # time would disagree with the values just stored.
        return scores
|