# Forked from 1827133/BA-Chatbot (223 lines, 7.8 KiB, Python).
|
import time
|
||
|
|
||
|
# Elasticsearch and Weaviate need time to ramp up before this service talks
# to them, so wait before wiring up clients. The delay defaults to 30 seconds
# and can be overridden (e.g. 0 for local development) through the
# STARTUP_DELAY_SECONDS environment variable.
# NOTE(review): polling both services for readiness would be more robust
# than a fixed sleep.
import os

time.sleep(float(os.environ.get("STARTUP_DELAY_SECONDS", "30")))
|
||
|
from flask import Flask, request, abort, render_template, jsonify
|
||
|
from api.embeddingsServiceCaller import EmbeddingServiceCaller
|
||
|
from retriever.retriever_pipeline import CustomPipeline
|
||
|
from question_answering import QuestionAnswering
|
||
|
from elasticsearch import Elasticsearch
|
||
|
import os
|
||
|
|
||
|
# Connection settings are read from the environment, with defaults that
# match a local Elasticsearch on the standard port.
es_host = os.environ.get("ELASTIC_HOST", "localhost")
es_port = int(os.environ.get("ELASTIC_PORT", "9200"))
es = Elasticsearch([{"host": es_host, "port": es_port}])

server = Flask(__name__, static_folder="static")

caller = EmbeddingServiceCaller()

# Secrets must not live in source control: the OpenAI API key is taken from
# the OPENAI_API_KEY environment variable. Any key that was ever committed to
# this file is exposed in the repository history and must be revoked.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    # NOTE(review): CustomPipeline then receives None; confirm how it copes.
    print("WARNING: OPENAI_API_KEY is not set", flush=True)
pipeline = CustomPipeline(api_key=openai_api_key)

question_answering = QuestionAnswering(pipeline=pipeline, embedder=caller)
|
||
|
|
||
|
|
||
|
# Keys copied from the POSTed JSON into each stored feedback document.
# "user_queston" (sic) is the established wire/storage key and must not be
# renamed without migrating clients and stored data.
_FEEDBACK_FIELDS = (
    "type",
    "user_queston",
    "provided_answer",
    "retrieval_method_or_model",
    "reader_model",
    "feedback",
    "last_searched_index",
)


@server.route("/feedback", methods=["POST", "GET"])
def feedback():
    """Store a feedback entry (POST) or list stored feedback (GET).

    POST expects a JSON object. Each key in ``_FEEDBACK_FIELDS`` is copied
    into a new document in the ``feedback`` Elasticsearch index; a missing
    key is stored as ``None``. The Elasticsearch index response is returned
    as JSON. Responds 400 if the body is not a JSON object.

    GET (and HEAD) returns a JSON list of the ``_source`` of ``match_all``
    hits from the ``feedback`` index. The optional ``size`` query parameter
    sets how many hits are returned. It defaults to 10, Elasticsearch's
    default page size, so omitting it behaves exactly as before.
    """
    if request.method == "POST":
        request_data = request.get_json()
        if not isinstance(request_data, dict):
            abort(400, description="Bad Request: Expecting JSON data")

        document = {key: request_data.get(key) for key in _FEEDBACK_FIELDS}
        response = es.index(index="feedback", body=document)
        server.logger.debug("Indexed feedback document: %s", response)
        return jsonify(response)

    # Any non-POST request here is GET (or Flask's automatic HEAD).
    # NOTE: Elasticsearch rejects sizes above index.max_result_window
    # (10000 by default).
    size = request.args.get("size", default=10, type=int)
    res = es.search(
        index="feedback",
        body={"query": {"match_all": {}}, "size": size},
    )
    feedbacks = [hit["_source"] for hit in res["hits"]["hits"]]
    return jsonify(feedbacks)
|
||
|
|
||
|
|
||
|
@server.route("/get_module_credits", methods=["POST"])
def get_credits():
    """Return the credits for a module, looked up in the ``ib`` index.

    Expects a JSON object with a ``module`` key. Responds 400 if the body is
    not a non-empty JSON object or ``module`` is absent; previously these
    cases crashed with a 500. The value returned by
    ``question_answering.get_module_credits`` is returned as the response.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict) or not request_data:
        abort(400, description="Bad Request: Expecting JSON data")
    if "module" not in request_data:
        abort(400, description="Missing parameter 'module'")

    module = request_data["module"]
    print(module, flush=True)
    result = question_answering.get_module_credits(module=module, index="ib")
    print(result, flush=True)
    return result
|
||
|
|
||
|
|
||
|
@server.route("/get_relevant_documents", methods=["POST"])
def get_relevant_documents():
    """Return the top-k documents for a query from the given index.

    Expects a JSON object with required ``query`` and ``index`` keys.
    Optional keys are ``retrieval_method_or_model`` (default ``"mpnet"``)
    and ``meta`` (default ``{}``). Responds 400 if the body is not a
    non-empty JSON object or a required key is absent; previously these
    cases crashed with a 500. The result of ``question_answering.get_top_k``
    is returned as the response.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict) or not request_data:
        abort(400, description="Bad Request: Expecting JSON data")
    missing = [key for key in ("query", "index") if key not in request_data]
    if missing:
        abort(
            400,
            description="Missing parameter(s): "
            + ", ".join(f"'{key}'" for key in missing),
        )

    result = question_answering.get_top_k(
        query=request_data["query"],
        index=request_data["index"],
        retrieval_method_or_model=request_data.get(
            "retrieval_method_or_model", "mpnet"
        ),
        meta=request_data.get("meta", {}),
    )
    print(result, flush=True)
    return result
|
||
|
|
||
|
|
||
|
@server.route("/get_answer", methods=["POST"])
def get_answer():
    """Answer a query using documents retrieved from the given index.

    Expects a JSON object with required ``query`` and ``index`` keys.
    Optional keys are ``retrieval_method_or_model`` (default ``"mpnet"``),
    ``reader_model`` (default ``"GPT"``) and ``rerank_documents`` (default
    ``True``). Responds 400 if the body is not a non-empty JSON object or a
    required key is absent.

    If ``question_answering.get_answers`` returns a tuple of at least two
    items, the response is ``{"answer": result[0], "documents": result[1]}``.
    Any other result is returned unchanged.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict) or not request_data:
        abort(400, description="Bad Request: Expecting JSON data")
    missing = [key for key in ("query", "index") if key not in request_data]
    if missing:
        abort(
            400,
            description="Missing parameter(s): "
            + ", ".join(f"'{key}'" for key in missing),
        )

    result = question_answering.get_answers(
        query=request_data["query"],
        index=request_data["index"],
        retrieval_method_or_model=request_data.get(
            "retrieval_method_or_model", "mpnet"
        ),
        reader_model=request_data.get("reader_model", "GPT"),
        rerank_documents=request_data.get("rerank_documents", True),
    )
    # Guard on length so a short tuple can no longer raise IndexError here.
    if isinstance(result, tuple) and len(result) >= 2:
        return {"answer": result[0], "documents": result[1]}
    return result
|
||
|
|
||
|
|
||
|
@server.route("/search_experts", methods=["POST"])
def search_experts():
    """Search for experts matching a free-text query.

    Expects a JSON object with a required, non-empty ``query``. Optional keys:
    ``retriever_model`` (default ``"mpnet"``), ``search_method`` (default
    ``"classic_retriever_reader"``), ``generate_answer`` (default ``False``)
    and ``rerank_retrieved_results`` (default ``True``). Responds 400 when
    the body is missing or empty, or when ``query`` is falsy.

    If ``question_answering.search_experts`` returns a tuple of at least two
    items, the response is ``{"answer": result[0], "documents": result[1]}``.
    Any other result is returned unchanged.
    """
    request_data = request.get_json()
    if not request_data or not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")

    query = request_data.get("query")
    if not query:
        # abort() raises, so no further "missing query" check is needed.
        abort(400, description="Missing parameter 'query'")

    # NOTE(review): a "reader_model" field in the request is not forwarded to
    # question_answering.search_experts; confirm whether it should be.
    result = question_answering.search_experts(
        query=query,
        search_method=request_data.get(
            "search_method", "classic_retriever_reader"
        ),
        retriever_model=request_data.get("retriever_model", "mpnet"),
        generate_answer=request_data.get("generate_answer", False),
        rerank=request_data.get("rerank_retrieved_results", True),
    )
    if isinstance(result, tuple) and len(result) >= 2:
        return {"answer": result[0], "documents": result[1]}
    return result
|
||
|
|
||
|
|
||
|
@server.route("/recommend_wpms", methods=["POST"])
def recommend_wpms():
    """Return WPM recommendations for a student profile.

    Expects a JSON object containing at least one non-empty value among
    ``interests``, ``previous_courses`` and ``future_carrer``. Optional keys:
    ``retrieval_method_or_model`` (default ``"mpnet"``),
    ``recommendation_method`` (default ``"get_retrieved_results"``) and
    ``rerank_retrieved_results`` (default ``True``). Responds 400 when the
    body is missing or empty, or when all three profile fields are empty.

    If ``question_answering.recommend_wpm`` returns a tuple of at least two
    items, the response is ``{"answer": result[0], "documents": result[1]}``.
    Any other result is returned unchanged.
    """
    request_data = request.get_json()
    if not request_data or not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")

    interests = request_data.get("interests")
    previous_courses = request_data.get("previous_courses")
    future_carrer = request_data.get("future_carrer")
    # Reject only when every profile field is empty. Any one of them is
    # enough, as the error message states. (A previous `and` check wrongly
    # demanded all three.)
    # NOTE(review): recommend_wpm may now receive None for the fields that
    # are not provided; confirm it handles that.
    if not (interests or previous_courses or future_carrer):
        abort(
            400,
            description="Provide at least one of the parameters: 'interests', 'previous_courses' or 'future_carrer' ",
        )

    result = question_answering.recommend_wpm(
        # "interets" (sic) presumably matches recommend_wpm's parameter name;
        # do not rename it here alone.
        interets=interests,
        previous_courses=previous_courses,
        future_carrer=future_carrer,
        recommendation_method=request_data.get(
            "recommendation_method", "get_retrieved_results"
        ),
        rerank_retrieved_results=request_data.get("rerank_retrieved_results", True),
        retrieval_method_or_model=request_data.get(
            "retrieval_method_or_model", "mpnet"
        ),
    )
    if isinstance(result, tuple) and len(result) >= 2:
        return {"answer": result[0], "documents": result[1]}
    return result
|
||
|
|
||
|
|
||
|
@server.route("/get_all_weaviate_data", methods=["GET"])
def get_weaviate_data():
    """Return ``pipeline.get_all_weaviate_data`` for the ``?index=`` argument.

    The ``index`` query parameter is passed through unchanged, as ``None``
    when it is absent.
    """
    return pipeline.get_all_weaviate_data(index=request.args.get("index"))
|
||
|
|
||
|
|
||
|
@server.route("/get_all_es_data", methods=["GET"])
def get_elastic_data():
    """Return ``pipeline.get_all_elastic_data`` for the ``?index=`` argument.

    The ``index`` query parameter is passed through unchanged, as ``None``
    when it is absent.
    """
    return pipeline.get_all_elastic_data(index=request.args.get("index"))
|
||
|
|
||
|
|
||
|
@server.route("/get_document_by_id", methods=["POST"])
def get_doc_by_id():
    """Look up one document by id through ``pipeline.query_by_ids``.

    Expects a JSON object with an ``id`` key. Responds 400 if the body is
    not a JSON object or ``id`` is absent. The parsed body is kept in a local
    variable instead of being written over ``request.data``.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        abort(400, description="Bad Request: Expecting JSON data")
    if "id" not in request_data:
        abort(400, description="Missing parameter 'id'")

    document_id = request_data["id"]
    return pipeline.query_by_ids([document_id])
|
||
|
|
||
|
|
||
|
@server.route("/conf1")
def config1():
    """Render and return the ``config1.html`` template."""
    template_name = "config1.html"
    return render_template(template_name)
|
||
|
|
||
|
|
||
|
# @server.route("/conf2")
|
||
|
# def config2():
|
||
|
# return render_template("config2.html")
|
||
|
|
||
|
|
||
|
# @server.route("/conf3")
|
||
|
# def config3():
|
||
|
# return render_template("config3.html")
|
||
|
|
||
|
|
||
|
# @server.route("/conf4")
|
||
|
# def config4():
|
||
|
# return render_template("config4.html")
|
||
|
|
||
|
|
||
|
# @server.route("/conf5")
|
||
|
# def config5():
|
||
|
# return render_template("config5.html")
|
||
|
|
||
|
|
||
|
# @server.route("/conf6")
|
||
|
# def config6():
|
||
|
# return render_template("config6.html")
|
||
|
|
||
|
|
||
|
# Explicit mapping applied only when the "feedback" index does not exist yet.
# NOTE(review): these field names ("question", "answer", "timestamp") do not
# match the keys feedback() stores ("user_queston", "provided_answer", ...).
# Those keys are presumably mapped dynamically, which is Elasticsearch's
# default. Confirm which schema is intended.
FEEDBACK_INDEX_MAPPING = {
    "mappings": {
        "properties": {
            "question": {"type": "text"},
            "answer": {"type": "text"},
            "feedback": {"type": "text"},
            "timestamp": {"type": "date"},
        }
    }
}


def _ensure_feedback_index():
    """Create the "feedback" index with FEEDBACK_INDEX_MAPPING if it is missing."""
    if not es.indices.exists(index="feedback"):
        es.indices.create(index="feedback", body=FEEDBACK_INDEX_MAPPING)


if __name__ == "__main__":
    _ensure_feedback_index()
    # "::" is the IPv6 wildcard address. The port defaults to 8080 and can be
    # overridden through the PORT environment variable.
    server.run(host="::", port=int(os.environ.get("PORT", "8080")))
|