# BA-Chatbot/backend/app.py
# (source-view metadata: 223 lines, 7.8 KiB, Python; 2023-11-15 14:28:48 +01:00)
import os
import time

# Elasticsearch and Weaviate need time to ramp up, so wait before the rest of
# this module imports and constructs its service clients. The delay in seconds
# can be overridden with STARTUP_DELAY_SECONDS; it defaults to 30.
time.sleep(float(os.environ.get("STARTUP_DELAY_SECONDS", "30")))
from flask import Flask, request, abort, render_template, jsonify
from api.embeddingsServiceCaller import EmbeddingServiceCaller
from retriever.retriever_pipeline import CustomPipeline
from question_answering import QuestionAnswering
from elasticsearch import Elasticsearch
import os
# Module-level singletons shared by all request handlers below.
es_host = os.environ.get("ELASTIC_HOST", "localhost")
es_port = int(os.environ.get("ELASTIC_PORT", "9200"))
es = Elasticsearch([{"host": es_host, "port": es_port}])
server = Flask(__name__, static_folder="static")
caller = EmbeddingServiceCaller()
# SECURITY: an API key used to be hard-coded on this line and committed to
# version control. Treat that key as leaked and revoke/rotate it. The key is
# now supplied through the OPENAI_API_KEY environment variable.
# NOTE(review): when the variable is unset CustomPipeline receives None;
# confirm how CustomPipeline handles that.
pipeline = CustomPipeline(api_key=os.environ.get("OPENAI_API_KEY"))
question_answering = QuestionAnswering(pipeline=pipeline, embedder=caller)
@server.route("/feedback", methods=["POST", "GET"])
def feedback():
    """Store a feedback record (POST) or list stored feedback records (GET).

    POST expects a JSON object. The fields ``type``, ``user_queston``,
    ``provided_answer``, ``retrieval_method_or_model``, ``reader_model``,
    ``feedback`` and ``last_searched_index`` are copied (missing ones as
    ``None``) into a new document in the ``feedback`` Elasticsearch index.
    The indexing response is returned as JSON. A body that is not a JSON
    object yields 400.

    GET (and HEAD) runs a ``match_all`` query on the ``feedback`` index and
    returns the ``_source`` of each hit as a JSON list. Elasticsearch returns
    only 10 hits unless told otherwise, so the optional ``size`` query
    parameter sets the hit count (default 10000, the default
    ``index.max_result_window``).
    """
    if request.method == "POST":
        request_data = request.get_json()
        if not isinstance(request_data, dict):
            abort(400, description="Bad Request: Expecting JSON data")
        fields = (
            "type",
            "user_queston",  # (sic) key used by existing clients and stored docs
            "provided_answer",
            "retrieval_method_or_model",
            "reader_model",
            "feedback",
            "last_searched_index",
        )
        document = {field: request_data.get(field) for field in fields}
        response = es.index(index="feedback", body=document)
        print(response, "response", flush=True)
        return jsonify(response)

    # Route only allows GET/POST; Flask also routes HEAD here, which must not
    # fall through and return None.
    size = request.args.get("size", default=10000, type=int)
    res = es.search(
        index="feedback",
        body={"query": {"match_all": {}}, "size": size},
    )
    return jsonify([hit["_source"] for hit in res["hits"]["hits"]])
@server.route("/get_module_credits", methods=["POST"])
def get_credits():
    """Return the credits of a module via ``QuestionAnswering.get_module_credits``.

    Expects a JSON object with a required ``module`` field and an optional
    ``index`` field, which defaults to ``"ib"``. Responds with 400 when the
    body is not a JSON object or ``module`` is missing.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    module = request_data.get("module")
    if module is None:
        abort(400, description="Missing parameter 'module'")
    index = request_data.get("index", "ib")
    print(module, flush=True)
    result = question_answering.get_module_credits(module=module, index=index)
    print(result, flush=True)
    return result
@server.route("/get_relevant_documents", methods=["POST"])
def get_relevant_documents():
    """Return the top-k retrieved documents for a query.

    Expects a JSON object with required ``query`` and ``index`` fields. It
    also accepts optional ``retrieval_method_or_model`` (default ``"mpnet"``)
    and ``meta`` (default ``{}``). Responds with 400 when the body is not a
    JSON object or a required field is missing or empty.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    for required in ("query", "index"):
        if not request_data.get(required):
            abort(400, description=f"Missing parameter '{required}'")
    result = question_answering.get_top_k(
        query=request_data["query"],
        index=request_data["index"],
        retrieval_method_or_model=request_data.get(
            "retrieval_method_or_model", "mpnet"
        ),
        meta=request_data.get("meta", {}),
    )
    print(result, flush=True)
    return result
@server.route("/get_answer", methods=["POST"])
def get_answer():
    """Answer a query using retrieval plus a reader model.

    Expects a JSON object with required ``query`` and ``index`` fields. It
    also accepts optional ``retrieval_method_or_model`` (default ``"mpnet"``),
    ``reader_model`` (default ``"GPT"``) and ``rerank_documents`` (default
    ``True``). Responds with 400 when the body is not a JSON object or a
    required field is missing or empty.

    If ``get_answers`` returns a tuple of at least two items, they are
    returned as ``{"answer": ..., "documents": ...}``. Otherwise the result
    is returned unchanged.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    for required in ("query", "index"):
        if not request_data.get(required):
            abort(400, description=f"Missing parameter '{required}'")
    result = question_answering.get_answers(
        query=request_data["query"],
        index=request_data["index"],
        retrieval_method_or_model=request_data.get(
            "retrieval_method_or_model", "mpnet"
        ),
        reader_model=request_data.get("reader_model", "GPT"),
        rerank_documents=request_data.get("rerank_documents", True),
    )
    # Both result[0] and result[1] are read, so require at least two items.
    if isinstance(result, tuple) and len(result) >= 2:
        return {"answer": result[0], "documents": result[1]}
    return result
@server.route("/search_experts", methods=["POST"])
def search_experts():
    """Search for experts matching a query.

    Expects a JSON object with a required ``query`` field. It also accepts
    optional ``retriever_model`` (default ``"mpnet"``), ``search_method``
    (default ``"classic_retriever_reader"``), ``generate_answer`` (default
    ``False``) and ``rerank_retrieved_results`` (default ``True``). Responds
    with 400 when the body is missing or not a JSON object, or ``query`` is
    missing or empty.

    If the search returns a tuple of at least two items, they are returned
    as ``{"answer": ..., "documents": ...}``. Otherwise the result is
    returned unchanged.
    """
    request_data = request.get_json()
    if not request_data or not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    query = request_data.get("query")
    if not query:
        abort(400, description="Missing parameter 'query'")
    retriever_model = request_data.get("retriever_model", "mpnet")
    search_method = request_data.get("search_method", "classic_retriever_reader")
    generate_answer = request_data.get("generate_answer", False)
    rerank_retrieved_results = request_data.get("rerank_retrieved_results", True)
    # NOTE(review): clients may send "reader_model", but it is not forwarded
    # to QuestionAnswering.search_experts; confirm whether it should be.
    result = question_answering.search_experts(
        query=query,
        search_method=search_method,
        retriever_model=retriever_model,
        generate_answer=generate_answer,
        rerank=rerank_retrieved_results,
    )
    # Both result[0] and result[1] are read, so require at least two items.
    if isinstance(result, tuple) and len(result) >= 2:
        return {"answer": result[0], "documents": result[1]}
    return result
@server.route("/recommend_wpms", methods=["POST"])
def recommend_wpms():
    """Recommend elective modules (WPMs) from a student's profile.

    Expects a JSON object containing at least one of ``interests``,
    ``previous_courses`` or ``future_carrer`` (sic, wire key). It also
    accepts optional ``retrieval_method_or_model`` (default ``"mpnet"``),
    ``recommendation_method`` (default ``"get_retrieved_results"``) and
    ``rerank_retrieved_results`` (default ``True``). Responds with 400 when
    the body is missing or none of the profile fields is provided.

    If the recommender returns a tuple of at least two items, they are
    returned as ``{"answer": ..., "documents": ...}``. Otherwise the result
    is returned unchanged.
    """
    request_data = request.get_json()
    if not request_data or not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    interests = request_data.get("interests")
    previous_courses = request_data.get("previous_courses")
    future_carrer = request_data.get("future_carrer")
    # "At least one" means reject only when all three are missing/empty
    # (previously `and` wrongly required all three).
    if not (interests or previous_courses or future_carrer):
        abort(
            400,
            description="Provide at least one of the parameters: 'interests', 'previous_courses' or 'future_carrer' ",
        )
    retrieval_method_or_model = request_data.get("retrieval_method_or_model", "mpnet")
    recommendation_method = request_data.get(
        "recommendation_method", "get_retrieved_results"
    )
    rerank_retrieved_results = request_data.get("rerank_retrieved_results", True)
    result = question_answering.recommend_wpm(
        interets=interests,  # (sic) keyword name expected by recommend_wpm
        previous_courses=previous_courses,
        future_carrer=future_carrer,
        recommendation_method=recommendation_method,
        rerank_retrieved_results=rerank_retrieved_results,
        retrieval_method_or_model=retrieval_method_or_model,
    )
    # Both result[0] and result[1] are read, so require at least two items.
    if isinstance(result, tuple) and len(result) >= 2:
        return {"answer": result[0], "documents": result[1]}
    return result
@server.route("/get_all_weaviate_data", methods=["GET"])
def get_weaviate_data():
    """Return ``pipeline.get_all_weaviate_data`` for the requested index.

    The index name is taken from the ``index`` query-string parameter and
    forwarded unchanged (``None`` when the parameter is absent).
    """
    return pipeline.get_all_weaviate_data(index=request.args.get("index"))
@server.route("/get_all_es_data", methods=["GET"])
def get_elastic_data():
    """Return ``pipeline.get_all_elastic_data`` for the requested index.

    The index name is taken from the ``index`` query-string parameter and
    forwarded unchanged (``None`` when the parameter is absent).
    """
    return pipeline.get_all_elastic_data(index=request.args.get("index"))
@server.route("/get_document_by_id", methods=["POST"])
def get_doc_by_id():
    """Return the result of ``pipeline.query_by_ids`` for a single id.

    Expects a JSON object with an ``id`` field. Responds with 400 when the
    body is not a JSON object or ``id`` is missing.
    """
    # Keep the parsed body in a local; do not assign to ``request.data``,
    # which holds the raw request body.
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    doc_id = request_data.get("id")
    if doc_id is None:
        abort(400, description="Missing parameter 'id'")
    return pipeline.query_by_ids([doc_id])
@server.route("/conf1")
def config1():
    """Serve the ``config1.html`` template."""
    template_name = "config1.html"
    return render_template(template_name)
# @server.route("/conf2")
# def config2():
# return render_template("config2.html")
# @server.route("/conf3")
# def config3():
# return render_template("config3.html")
# @server.route("/conf4")
# def config4():
# return render_template("config4.html")
# @server.route("/conf5")
# def config5():
# return render_template("config5.html")
# @server.route("/conf6")
# def config6():
# return render_template("config6.html")
if __name__ == "__main__":
    # Create the "feedback" index with an explicit mapping if it does not
    # exist yet. This only runs when the module is executed directly, not
    # when another program imports ``server``.
    # NOTE(review): the mapped fields ("question", "answer", "timestamp") do
    # not match the keys the /feedback endpoint writes ("user_queston",
    # "provided_answer", ...). Those keys fall back to Elasticsearch's default
    # dynamic mapping. Confirm which mapping is intended.
    if not es.indices.exists(index="feedback"):
        feedback_mapping = {
            "mappings": {
                "properties": {
                    "question": {"type": "text"},
                    "answer": {"type": "text"},
                    "feedback": {"type": "text"},
                    "timestamp": {"type": "date"},
                }
            }
        }
        es.indices.create(index="feedback", body=feedback_mapping)
    # Bind address and port default to the previous hard-coded values
    # ("::" and 8080) and can be overridden via the environment.
    server.run(
        host=os.environ.get("SERVER_HOST", "::"),
        port=int(os.environ.get("SERVER_PORT", "8080")),
    )