BA-Chatbot/model_service/openbuddy.py

19 lines
927 B
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
class OpenBuddy:
    """Wrapper around a Hugging Face causal LM used to answer text prompts.

    Loads the model and tokenizer found at ``model_path`` once, then serves
    single-prompt generation through :meth:`generate_answer`.
    """

    def __init__(self, model_path):
        """Load model weights (fp16) and tokenizer from ``model_path``.

        Args:
            model_path: Hub id or local directory passed to ``from_pretrained``.
        """
        # device_map="auto" lets accelerate decide where the weights live;
        # generate_answer therefore must not assume a particular device.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()

    def generate_answer(self, prompt, max_new_tokens=100):
        """Generate a continuation of ``prompt``.

        Args:
            prompt: Input text to condition generation on.
            max_new_tokens: Maximum number of tokens to generate (default 100,
                the previously hard-coded value).

        Returns:
            The decoded first output sequence with special tokens removed.
            For a decoder-only model this is the prompt text followed by the
            generated continuation.
        """
        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, so generate() does not have to guess it.
        encoded = self.tokenizer(prompt, return_tensors="pt")
        # Put inputs on the model's device instead of hard-coding "cuda",
        # which fails on CPU-only hosts and ignores device_map placement.
        device = self.model.device
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # Same fallback generate() applies on its own, minus the warning.
            pad_id = eos_id

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                eos_token_id=eos_id,
                pad_token_id=pad_id,
            )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)