forked from 1827133/BA-Chatbot
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class OpenBuddy:
    """Thin wrapper around a causal LM checkpoint for single-prompt generation."""

    def __init__(self, model_path):
        # Load the weights in half precision and let Accelerate place them
        # on the available devices (device_map="auto").
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.eval()

    def generate_answer(self, prompt):
        # Move the encoded prompt to the model's device instead of
        # hard-coding 'cuda', so the wrapper also runs on CPU-only hosts.
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=100,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        # Decode the full sequence (prompt + completion), dropping special tokens.
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
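A minimal usage sketch, assuming a hypothetical OpenBuddy checkpoint ID; substitute the Hugging Face Hub ID or local path the project actually uses:

if __name__ == "__main__":
    # Illustrative Hub ID, not necessarily the checkpoint this repo ships with.
    bot = OpenBuddy("OpenBuddy/openbuddy-llama2-13b-v8.1-fp16")
    print(bot.generate_answer("User: What is a chatbot?\nAssistant:"))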