BA-Chatbot/data_service/api/embeddingsServiceCaller.py

43 lines
1.5 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
import requests
import json
import os
# Location of the model service. Host and port can both be overridden via the
# environment; unset, they default to a local service on port 5000.
host = os.environ.get("MODEL_SERVICE_HOST", "127.0.0.1")
# The port was previously hard-coded; "5000" keeps the old default.
port = os.environ.get("MODEL_SERVICE_PORT", "5000")
BASE_URL = f"http://{host}:{port}/"
ANSWER_URL = f"{BASE_URL}generate_answer"
EMBEDDINGS_URL = f"{BASE_URL}generate_embeddings"
class EmbeddingServiceCaller:
    """Thin HTTP client for the model service's embedding and answer endpoints.

    Each public method POSTs a JSON body to a module-level endpoint URL and
    returns the decoded JSON response. HTTP error statuses are not raised:
    whatever JSON body the service returns is passed through, and a non-JSON
    body makes ``Response.json()`` raise.
    """

    # Forwarded to every ``requests`` call; ``None`` means "no timeout" (the
    # ``requests`` default). Declared on the class so subclasses that override
    # ``__init__`` without calling ``super().__init__()`` still have it.
    timeout = None

    def __init__(self, timeout=None) -> None:
        """Create a caller.

        Args:
            timeout: Optional seconds (or a ``(connect, read)`` tuple) passed
                to ``requests``. Without one, a stalled model service blocks
                the caller indefinitely; the default ``None`` keeps that
                previous behaviour.
        """
        self.timeout = timeout

    def get_embeddings(
        self,
        text,
        embedding_type="input_embeddings",
        operation="mean",
        embedding_model="llama",
        layer=-1,
    ):
        """Request embeddings for *text* from ``EMBEDDINGS_URL``.

        Every argument is forwarded verbatim in the JSON body. Its meaning
        (e.g. which pooling ``operation`` or which ``layer`` is used) is
        defined by the model service, not by this client.

        Args:
            text: Sent as the ``"query"`` field.
            embedding_type: Sent as ``"embedding_type"``.
            operation: Sent as ``"operation"``.
            embedding_model: Sent as ``"embedding_model"``.
            layer: Sent as ``"layer"``.

        Returns:
            The decoded JSON response body.
        """
        body = json.dumps({
            "query": text,
            "embedding_type": embedding_type,
            "operation": operation,
            "embedding_model": embedding_model,
            "layer": layer,
        })
        return self._call(EMBEDDINGS_URL, "POST", data=body)

    def get_answer(self, payload="", prompt=""):
        """Request a generated answer from ``ANSWER_URL``.

        Args:
            payload: A ready-made request body. A ``str``/``bytes`` body is
                sent unchanged; a ``dict``/``list`` is JSON-encoded first.
                If falsy, ``prompt`` is used instead.
            prompt: Sent as ``{"prompt": prompt}`` when ``payload`` is falsy.

        Returns:
            The decoded JSON response body.
        """
        return self._call(ANSWER_URL, "POST", data=self._answer_body(payload, prompt))

    @staticmethod
    def _answer_body(payload, prompt):
        """Return the request body string that ``get_answer`` sends."""
        if not payload:
            return json.dumps({"prompt": prompt})
        if isinstance(payload, (dict, list)):
            # Passing a dict straight to ``data=`` would make requests
            # form-encode it while the header still claims JSON.
            return json.dumps(payload)
        return payload

    def _call(self, url, method, data=None):
        """Send one HTTP request and return the decoded JSON response.

        Args:
            url: Target URL.
            method: HTTP method name, e.g. ``"POST"``.
            data: Optional pre-serialized JSON body. When given, the request
                carries ``Content-Type: application/json``. When omitted, no
                body and no extra headers are sent.

        Returns:
            ``response.json()`` of the reply.
        """
        headers = {"Content-Type": "application/json"} if data is not None else None
        response = requests.request(
            method, url, headers=headers, data=data, timeout=self.timeout
        )
        return response.json()
if __name__ == "__main__":
    # Manual smoke test: requires a reachable model service and prints the
    # decoded response for a throwaway input string.
    sample_text = "Hallsdfasdf Hallsdfasdf Hallsdfasdf"
    embeddings = EmbeddingServiceCaller().get_embeddings(sample_text)
    print(embeddings)