BA-Chatbot/model_service/em_german.py

41 lines
2.6 KiB
Python

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer
class EMGerman:
def __init__(self, model_path='./models/em_german_7b_v01' ) -> None:
self.model=AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True)
self.tokenizer=AutoTokenizer.from_pretrained(model_path)
self.tokenizer.pad_token_id=self.tokenizer.eos_token_id
self.generation_config=GenerationConfig(max_new_tokens=500,
temperature=0.4,
top_p=0.95,
top_k=40,
repetition_penalty=1.2,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
do_sample=True,
use_cache=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
remove_invalid_values=True
)
self.streamer = TextStreamer(self.tokenizer)
def ask_model(self,instruction, system='Du bist ein hilfreicher Assistent.'):
prompt=f"{system} USER: {instruction} ASSISTANT:"
input_tokens=self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
output_tokens=self.model.generate(**input_tokens, generation_config=self.generation_config, streamer=self.streamer)[0]
answer=self.tokenizer.decode(output_tokens, skip_special_tokens=True)
return answer
def retrieval_qa(self, question:str, references):
retrieval_system="Du bist ein hilfreicher Assistent. Für die folgende Aufgabe stehen dir zwischen den tags BEGININPUT und ENDINPUT mehrere Quellen zur Verfügung. Metadaten zu den einzelnen Quellen wie Autor, URL o.ä. sind zwischen BEGINCONTEXT und ENDCONTEXT zu finden, danach folgt der Text der Quelle. Die eigentliche Aufgabe oder Frage ist zwischen BEGININSTRUCTION und ENDINCSTRUCTION zu finden. Beantworte diese wortwörtlich mit einem Zitat aus den Quellen. Sollten diese keine Antwort enthalten, antworte, dass auf Basis der gegebenen Informationen keine Antwort möglich ist!"
retrieval_question=f"""\
BEGININPUT
BEGINCONTEXT
ENDCONTEXT
{references}
ENDINPUT
BEGININSTRUCTION {question} ENDINSTRUCTION"""
return self.ask_model(retrieval_question, system=retrieval_system)