BA-Chatbot/backend/api/embeddingsServiceCaller.py

57 lines
2.0 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
import requests
import json
import os
# Address of the model service. The host may be overridden through the
# MODEL_SERVICE_HOST environment variable; the port (5000) is fixed.
host = os.getenv("MODEL_SERVICE_HOST", "127.0.0.1")
BASE_URL = "http://" + host + ":5000/"

# Endpoint URLs, all derived from BASE_URL.
ANSWER_URL = BASE_URL + "generate_answer"
EMBEDDINGS_URL = BASE_URL + "generate_embeddings"
RERANKING_URL = BASE_URL + "rerank_documents"
class EmbeddingServiceCaller:
    """Thin HTTP client for the model service's JSON endpoints.

    Each public method POSTs a JSON body to one of the module-level endpoint
    URLs (``EMBEDDINGS_URL``, ``ANSWER_URL``, ``RERANKING_URL``) and returns
    the response body. HTTP status codes are not inspected here: the body is
    decoded regardless of the status the service answered with.
    """

    # (connect, read) timeout in seconds applied to every request. Without a
    # timeout, ``requests`` waits indefinitely on an unresponsive service.
    DEFAULT_TIMEOUT = (10, 600)

    _JSON_HEADERS = {'Content-Type': 'application/json'}

    def __init__(self, timeout=DEFAULT_TIMEOUT) -> None:
        """Create a caller.

        Args:
            timeout: Value forwarded to ``requests`` as ``timeout=`` on every
                call: a number of seconds, a ``(connect, read)`` tuple, or
                ``None`` to wait forever (the behaviour before a timeout was
                introduced).
        """
        self.timeout = timeout

    def _post(self, url, body):
        """POST ``body`` (an already-serialized payload) to ``url`` as JSON.

        Returns the ``requests.Response``. Raises
        ``requests.exceptions.Timeout`` if the service does not connect or
        respond within ``self.timeout``.
        """
        return requests.request(
            "POST",
            url,
            # Copy so the shared class-level dict is never handed out.
            headers=dict(self._JSON_HEADERS),
            data=body,
            timeout=self.timeout,
        )

    def get_embeddings(self, text, embedding_type="input_embeddings", operation="mean", embedding_model="llama", layer=-1):
        """Request embeddings for ``text`` from the embeddings endpoint.

        The arguments are sent verbatim in the JSON body under the keys
        ``query``, ``embedding_type``, ``operation``, ``embedding_model`` and
        ``layer``; their interpretation is up to the service.

        Returns:
            The JSON-decoded response body.
        """
        body = json.dumps({
            "query": text,
            "embedding_type": embedding_type,
            "operation": operation,
            "embedding_model": embedding_model,
            "layer": layer,
        })
        return self._post(EMBEDDINGS_URL, body).json()

    def get_answer(self, payload="", prompt="", llama=False):
        """Request a generated answer from the answer endpoint.

        Args:
            payload: Request body sent as-is when truthy (it is passed
                straight to ``data=``, so it is presumably already a JSON
                string -- verify against callers).
            prompt: Used only when ``payload`` is falsy; wrapped as
                ``{"prompt": prompt}`` and serialized to JSON.
            llama: When true, return the raw response text instead of the
                JSON-decoded body.

        Returns:
            ``response.text`` if ``llama`` is true, else the decoded JSON.
        """
        # NOTE(review): the debug prints below are kept exactly as they were.
        # The PROMPT / response.text prints are assumed to belong to the
        # prompt branch -- confirm against the original file's indentation.
        if payload:
            print("PAYLOAD: ", payload, flush=True)
            response = self._post(ANSWER_URL, payload)
            print(response)
        else:
            payload = json.dumps({
                "prompt": prompt
            })
            response = self._post(ANSWER_URL, payload)
            print("PROMPT: ", prompt)
            print(response.text)

        if llama:
            return response.text
        return response.json()

    def rerank_documents_gpt(self, payload):
        """POST ``payload`` unchanged to the reranking endpoint.

        Returns:
            The JSON-decoded response body.
        """
        return self._post(RERANKING_URL, payload).json()

    def _call(self, url, method):
        """Issue a body-less request with ``method`` to ``url``.

        Returns:
            The JSON-decoded response body.
        """
        response = requests.request(method, url, timeout=self.timeout)
        return response.json()
if __name__ == "__main__":
    # Manual smoke test: fetch embeddings for a throwaway string and show
    # whatever the service returns.
    service = EmbeddingServiceCaller()
    sample_result = service.get_embeddings("Hallsdfasdf Hallsdfasdf Hallsdfasdf")
    print(sample_result)