BA-Chatbot/model_service/app.py

113 lines
5.4 KiB
Python
Raw Permalink Normal View History

2023-11-15 14:28:48 +01:00
"""
The Model Service is designed to handle various tasks related to language models, such as generating embeddings, reranking documents,
and creating responses using different transformer models.
It utilizes Flask to set up a server that responds to HTTP requests for these tasks.
The service can work with different models such as Llama, GPTQ (currently disabled in this file), and OpenAI's GPT models,
and it supports various operations like pooling strategies and extracting embeddings from different layers of the models.
The script is structured to be flexible, making it easy to integrate and switch between different models and methods based on runtime arguments.
This makes it a versatile tool for various NLP tasks in a production environment where different model capabilities are required.
The Llama class encapsulates functionality specific to the Llama model, providing methods to generate embeddings and responses.
It offers flexibility in using different versions of the model (like GGML, HF, and GPTQ versions) and supports operations like getting embeddings from different layers or using specific pooling strategies.
The openai_models.py file contains functions to interact with OpenAI's GPT models.
It includes functions to generate responses, rerank documents, and create embeddings using OpenAI's API, handling potential errors and offering fallback options if necessary.
+------------------------------------------------------------------+
| NOTE: To run the application, use the following command: |
| $ python app.py --model LLama (or a similar arg) |
+------------------------------------------------------------------+
"""
from typing import Dict
from flask import Flask, request
from openai_models import generate_answer_gpt, generate_embeddings_ada, rerank_documents_with_gpt35
from gptq import GPTQ
import argparse
import os
from embeddings.llama import LLama
def _resolve_model_arg(argv=None, environ=None):
    """Decide which local model backend to load at start-up.

    The ``MODEL_ARG`` environment variable takes precedence; only when it is
    unset is the command line parsed for ``--model`` / ``--method``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) means "use the real command line".
        environ: Mapping to read ``MODEL_ARG`` from instead of ``os.environ``.

    Returns:
        The ``MODEL_ARG`` value verbatim when it is set; otherwise ``"LLama"``
        if the command line contains ``--model LLama`` or ``--method ggml``;
        otherwise ``None``.
    """
    environ = os.environ if environ is None else environ
    model = environ.get('MODEL_ARG', None)
    if model is not None:
        return model
    parser = argparse.ArgumentParser(description='Flask App Configuration.')
    parser.add_argument('--model', type=str, default=None, help='Which model to use: LLama, GPTQ, OPENAI_GPT.')
    parser.add_argument('--method', type=str, default=None, help='Method to use: ggml or other_method_name.')
    # parse_known_args (not parse_args): this runs at import time, and options
    # that belong to a host process sharing sys.argv (WSGI server, test
    # runner, ...) must not abort the import with "unrecognized arguments".
    args, _unknown = parser.parse_known_args(argv)
    if args.model == 'LLama' or args.method == 'ggml':
        return "LLama"
    return None


model_arg = _resolve_model_arg()

# Always bind the name so handlers never hit a NameError; it stays None when
# no local LLama backend was requested.
llama = None
if model_arg == "LLama":
    # NOTE(review): ggml=False / hf_model=True are hard-coded, so selecting
    # "--method ggml" does not change these flags — confirm the GGML code
    # paths work with this configuration.
    llama = LLama(
        model_path_hf="./models/Llama-2-7b-hf",
        model_path_ggml="./models/openbuddy-llama2-13b-v8.1-q3_K.bin",
        lora_path=None,
        lora_base=None,
        ggml=False,
        hf_model=True,
        gptq_hf=False,
        model_path_gpqt_hf="./openbuddy-llama2-34b-v11.1-bf16-GPTQ")

# Disabled backends, kept for reference. NOTE: they refer to the CLI namespace
# `args`, which now lives inside _resolve_model_arg; adapt before re-enabling.
# if args.model== 'GPTQ':
#     model_name_or_path = "./openbuddy-llama2-34b-v11.1-bf16-GPTQ"
#     gptq= GPTQ(model_name_or_path)
# if args.model =="openBuddy":
#     openBuddy= OpenBuddy(model_path="./models/openbuddy-llama2-13b-v11.1-bf16")
# Flask application object: the route handlers in this module register on it,
# and the __main__ guard at the end of the file starts it.
server = Flask(import_name=__name__, static_folder="static")
# Extra packages the author noted for this service (presumably for the
# Hugging Face model loading path — TODO confirm):
#   pip install bitsandbytes
#   pip install accelerate
@server.route("/generate_embeddings", methods=["POST"])
def get_embeddings():
    """Return an embedding of ``query`` computed by the module-level LLama model.

    JSON body fields:
        query: Text to embed. Required; a 400 response is returned without it.
        embedding_type: ``"input_embeddings"`` (default), ``"last_layer"`` or
            ``"nth_layer"``; any other value falls through to the GGML path.
        operation: Pooling operation passed to the input/last-layer helpers
            (default ``"mean"``).
        layer: Layer index used by ``"nth_layer"`` (default ``-1``).
        embedding_model: Only ``"llama"`` (the default) is supported.

    Returns:
        The embedding converted with ``.tolist()`` (Flask serialises it as
        JSON), or ``({"error": ...}, 400)`` for a missing ``query`` or an
        unsupported ``embedding_model``.
    """
    request_data: Dict = request.get_json()
    embedding_type = request_data.get("embedding_type", "input_embeddings")
    operation = request_data.get("operation", "mean")
    layer = request_data.get("layer", -1)
    # No default: an absent query is a client error, not text to embed
    # (a "mean" default here would silently embed the literal word "mean").
    query = request_data.get("query")
    embedding_model = request_data.get("embedding_model", "llama")
    if query is None:
        return {"error": "Missing required field 'query'."}, 400
    if embedding_model != "llama":
        # Returning None from a view makes Flask answer with a 500; be explicit.
        return {"error": f"Unsupported embedding_model: {embedding_model!r}. Only 'llama' is supported."}, 400
    print("Generating Embeddings for Input: ", query)
    if embedding_type == "input_embeddings":
        print(f"Returning Input Embeddings with operation {operation}...")
        return llama.get_input_embeddings(text=query, operation=operation).tolist()
    if embedding_type == "last_layer":
        print(f"Returning Embeddings from Last layer with operation {operation}... ")
        return llama.get_embeddings_last_layer(text=query, operation=operation).tolist()
    if embedding_type == "nth_layer":
        print(f"Returning Embeddings from nth-layer: {layer} ....")
        return llama.get_embeddings(text=query, layer_num=layer).tolist()
    # Any other embedding_type uses the GGML embedding path.
    return llama.get_embeddings_ggml(text=query).tolist()
@server.route("/rerank_documents", methods=["POST"])
def rerank_documents():
    """Rerank ``documents`` for ``query`` via ``rerank_documents_with_gpt35``.

    JSON body fields:
        system_prompt: Forwarded as ``system_prompt``.
        query: Forwarded as ``question``.
        documents: Forwarded as ``reference``.
        completion_tokens: Optional completion budget (default 300), read
            from the request the same way ``/generate_answer`` does.

    Returns:
        Whatever ``rerank_documents_with_gpt35`` returns.
    """
    request_data: Dict = request.get_json()
    sys_prompt = request_data.get("system_prompt")
    query = request_data.get("query")
    documents = request_data.get("documents")
    completion_tokens = request_data.get("completion_tokens", 300)
    return rerank_documents_with_gpt35(
        system_prompt=sys_prompt,
        question=query,
        completion_tokens=completion_tokens,
        reference=documents,
    )
@server.route("/generate_answer", methods=["POST"])
def generate_answer():
    """Generate an answer for ``prompt`` with the backend named by ``model``.

    JSON body fields:
        prompt: Prompt text. Required.
        model: ``"GGML"`` or ``"HF"`` (module-level LLama model), or ``"GPT"``
            (``generate_answer_gpt``). Required.
        question, reference: Only forwarded to the ``"GPT"`` backend.
        completion_tokens: Only forwarded to ``"GPT"`` (default 1000).

    Returns:
        The selected backend's result, or ``({"error": ...}, 400)`` when a
        required field is missing or ``model`` is not supported.
    """
    request_data: Dict = request.get_json()
    print("REQUEST DATA MODEL SERVICE GENERATE ANSWER: ", request_data, flush=True)
    # A bare KeyError here would surface as a 500; report the client error instead.
    missing = [key for key in ("prompt", "model") if key not in request_data]
    if missing:
        return {"error": f"Missing required field(s): {', '.join(missing)}."}, 400
    prompt = request_data["prompt"]
    model = request_data["model"]
    question = request_data.get("question")
    reference = request_data.get("reference")
    completion_tokens = request_data.get("completion_tokens", 1000)
    if model == "GGML":
        return llama.generate_answer_ggml(text=prompt)
    # GPTQ backend is currently disabled:
    # if model == "GPTQ":
    #     return gptq.generate_answer(prompt=prompt)
    if model == "HF":
        return llama.generate_answer(prompt=prompt)
    if model == "GPT":
        return generate_answer_gpt(default_prompt=prompt, question=question, reference=reference, completion_tokens=completion_tokens)
    # Falling off the end would return None, which Flask turns into a 500.
    return {"error": f"Unsupported model: {model!r}. Expected one of 'GGML', 'HF', 'GPT'."}, 400
if __name__ == "__main__":
    # Start Flask's built-in server, listening on all interfaces, port 5000.
    server.run(port=5000, host='0.0.0.0')