BA-Chatbot/model_service/gptq.py

25 lines
1.1 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class GPTQ:
    """Text-generation wrapper around a causal language model.

    The model and tokenizer are loaded from ``model_path`` with the
    ``transformers`` Auto classes. Judging by the class name the checkpoint is
    presumably GPTQ-quantized -- NOTE(review): confirm against the deployment
    config.
    """

    def __init__(self, model_path):
        """Load the model and tokenizer from ``model_path``.

        Args:
            model_path: Model id or local directory passed to
                ``from_pretrained`` for both the model and the tokenizer.
        """
        # device_map="auto" delegates weight placement to transformers, so the
        # rest of this class must not hard-code a device (e.g. ``.cuda()``).
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=False,
            revision="main",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        # The generation pipeline is created on first use and then reused,
        # instead of being rebuilt for every request.
        self._pipe = None

    def _get_pipeline(self):
        """Return the cached text-generation pipeline, creating it on first call."""
        if self._pipe is None:
            self._pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                top_k=40,
                repetition_penalty=1.1,
            )
        return self._pipe

    def generate_answer(self, prompt):
        """Generate a sampled continuation of ``prompt``.

        Sampling settings: up to 512 new tokens, temperature 0.7, top-p 0.95,
        top-k 40, repetition penalty 1.1.

        Args:
            prompt: Input text to continue.

        Returns:
            The text-generation pipeline's raw output for ``prompt``. Per the
            transformers pipeline docs, for a single string prompt this is a
            list of dicts with a ``"generated_text"`` key.
        """
        # Previously a separate ``model.generate`` + ``decode`` pass ran here
        # and its result was discarded; only the pipeline output is returned.
        return self._get_pipeline()(prompt)