57 lines
2.0 KiB
Python
57 lines
2.0 KiB
Python
|
import requests
|
||
|
import json
|
||
|
import os
|
||
|
# Address of the model service. Both parts can be overridden from the
# environment; the defaults reproduce the original http://127.0.0.1:5000/.
host = os.environ.get("MODEL_SERVICE_HOST", "127.0.0.1")
port = os.environ.get("MODEL_SERVICE_PORT", "5000")

BASE_URL = f"http://{host}:{port}/"

# Endpoint URLs POSTed to by EmbeddingServiceCaller.
ANSWER_URL = f"{BASE_URL}generate_answer"
EMBEDDINGS_URL = f"{BASE_URL}generate_embeddings"
RERANKING_URL = f"{BASE_URL}rerank_documents"
|
||
|
class EmbeddingServiceCaller:
    """HTTP client for the model service's answer, embedding and reranking endpoints.

    Each public method POSTs a JSON body to one of the module-level URLs
    (``EMBEDDINGS_URL``, ``ANSWER_URL``, ``RERANKING_URL``) via ``requests``
    and returns the response body (decoded JSON, or raw text for
    ``get_answer(..., llama=True)``).
    """

    # Header sent with every JSON POST; shared instead of rebuilt per method.
    _JSON_HEADERS = {"Content-Type": "application/json"}

    # ``timeout`` forwarded to ``requests``; ``None`` (the default) means the
    # call may wait indefinitely, which matches the original behaviour.
    timeout = None

    def __init__(self, timeout=None) -> None:
        """Create a caller.

        Args:
            timeout: Optional ``requests`` timeout (seconds, or a
                ``(connect, read)`` tuple) applied to every request.
                ``None`` disables the timeout.
        """
        self.timeout = timeout

    @staticmethod
    def _encode_payload(payload):
        """Return *payload* in a form suitable for ``requests``' ``data=``.

        Dicts and lists are serialized with ``json.dumps``. Passed to
        ``data=`` directly, ``requests`` would form-encode them, contradicting
        the ``application/json`` header. Everything else (already-serialized
        str/bytes, ``None``, file-like objects) is returned unchanged.
        """
        if isinstance(payload, (dict, list)):
            return json.dumps(payload)
        return payload

    def _post(self, url, payload):
        """POST *payload* to *url* with a JSON content type; return the ``Response``."""
        return requests.request(
            "POST",
            url,
            headers=dict(self._JSON_HEADERS),  # copy so the shared dict is never mutated
            data=self._encode_payload(payload),
            timeout=self.timeout,
        )

    def get_embeddings(self, text, embedding_type="input_embeddings", operation="mean",
                       embedding_model="llama", layer=-1):
        """Request embeddings for *text* from ``EMBEDDINGS_URL``.

        Args:
            text: Sent as the ``"query"`` field.
            embedding_type, operation, embedding_model, layer: Forwarded verbatim
                as same-named JSON fields. Their accepted values are defined by
                the model service, not in this file.

        Returns:
            The decoded JSON response body.
        """
        body = {
            "query": text,
            "embedding_type": embedding_type,
            "operation": operation,
            "embedding_model": embedding_model,
            "layer": layer,
        }
        return self._post(EMBEDDINGS_URL, body).json()

    def get_answer(self, payload="", prompt="", llama=False):
        """Request a generated answer from ``ANSWER_URL``.

        Args:
            payload: Pre-built request body. A JSON str/bytes is sent as-is and
                a dict/list is serialized to JSON. When falsy, the body
                ``{"prompt": prompt}`` is sent instead.
            prompt: Prompt text; used only when *payload* is falsy.
            llama: If true, return ``response.text`` instead of decoding JSON.

        Returns:
            The raw response text when *llama* is true, otherwise the decoded
            JSON body.
        """
        if payload:
            print("PAYLOAD: ", payload, flush=True)
            response = self._post(ANSWER_URL, payload)
            print(response)
        else:
            response = self._post(ANSWER_URL, {"prompt": prompt})
            print("PROMPT: ", prompt)
            print(response.text)

        # NOTE(review): the source's indentation was lost; ``llama`` is assumed
        # to apply to both branches above. Confirm against the original repo.
        if llama:
            return response.text
        return response.json()

    def rerank_documents_gpt(self, payload):
        """POST *payload* to ``RERANKING_URL`` and return the decoded JSON response.

        *payload* may be a JSON str/bytes (sent as-is) or a dict/list
        (serialized to JSON). Its schema is defined by the service.
        """
        return self._post(RERANKING_URL, payload).json()

    def _call(self, url, method):
        """Send a bodyless *method* request to *url* and return the decoded JSON."""
        return requests.request(method, url, timeout=self.timeout).json()
|
||
|
|
||
|
if __name__ == "__main__":
    # Manual smoke test: fetch embeddings for a throwaway string and print them.
    service = EmbeddingServiceCaller()
    embeddings = service.get_embeddings("Hallsdfasdf Hallsdfasdf Hallsdfasdf")
    print(embeddings)
|