# forked from 1827133/BA-Chatbot
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class GPTQ:
    """Wrapper around a GPTQ-quantized causal language model loaded via transformers."""

    def __init__(self, model_path):
        # device_map="auto" places the model weights across available GPUs
        # (and CPU, if necessary) automatically.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=False,
            revision="main",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

    def generate_answer(self, prompt):
        # A text-generation pipeline handles tokenization, generation, and
        # decoding in one call, so no separate tokenize/generate/decode pass
        # is needed.
        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            repetition_penalty=1.1,
        )
        return pipe(prompt)
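

# --- Usage sketch (not part of the original file) ---
# A minimal example, assuming a GPTQ checkpoint is available locally or on
# the Hugging Face Hub; the model id below is illustrative only. Loading a
# GPTQ checkpoint through AutoModelForCausalLM also requires a GPTQ backend
# package (e.g. optimum with auto-gptq) to be installed, and generate_answer
# as written assumes a CUDA-capable GPU via device_map="auto".
if __name__ == "__main__":
    bot = GPTQ("TheBloke/Llama-2-7B-Chat-GPTQ")  # hypothetical model id
    result = bot.generate_answer("Explain GPTQ quantization in one sentence.")
    # The text-generation pipeline returns a list of dicts, each with a
    # "generated_text" key holding the prompt plus the completion.
    print(result[0]["generated_text"])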