41 lines
2.6 KiB
Python
41 lines
2.6 KiB
Python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer
|
|
|
|
class EMGerman:
|
|
def __init__(self, model_path='./models/em_german_7b_v01' ) -> None:
|
|
self.model=AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True)
|
|
self.tokenizer=AutoTokenizer.from_pretrained(model_path)
|
|
self.tokenizer.pad_token_id=self.tokenizer.eos_token_id
|
|
self.generation_config=GenerationConfig(max_new_tokens=500,
|
|
temperature=0.4,
|
|
top_p=0.95,
|
|
top_k=40,
|
|
repetition_penalty=1.2,
|
|
bos_token_id=self.tokenizer.bos_token_id,
|
|
eos_token_id=self.tokenizer.eos_token_id,
|
|
do_sample=True,
|
|
use_cache=True,
|
|
output_attentions=False,
|
|
output_hidden_states=False,
|
|
output_scores=False,
|
|
remove_invalid_values=True
|
|
)
|
|
self.streamer = TextStreamer(self.tokenizer)
|
|
|
|
|
|
def ask_model(self,instruction, system='Du bist ein hilfreicher Assistent.'):
|
|
prompt=f"{system} USER: {instruction} ASSISTANT:"
|
|
input_tokens=self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
|
output_tokens=self.model.generate(**input_tokens, generation_config=self.generation_config, streamer=self.streamer)[0]
|
|
answer=self.tokenizer.decode(output_tokens, skip_special_tokens=True)
|
|
return answer
|
|
|
|
def retrieval_qa(self, question:str, references):
|
|
retrieval_system="Du bist ein hilfreicher Assistent. Für die folgende Aufgabe stehen dir zwischen den tags BEGININPUT und ENDINPUT mehrere Quellen zur Verfügung. Metadaten zu den einzelnen Quellen wie Autor, URL o.ä. sind zwischen BEGINCONTEXT und ENDCONTEXT zu finden, danach folgt der Text der Quelle. Die eigentliche Aufgabe oder Frage ist zwischen BEGININSTRUCTION und ENDINCSTRUCTION zu finden. Beantworte diese wortwörtlich mit einem Zitat aus den Quellen. Sollten diese keine Antwort enthalten, antworte, dass auf Basis der gegebenen Informationen keine Antwort möglich ist!"
|
|
retrieval_question=f"""\
|
|
BEGININPUT
|
|
BEGINCONTEXT
|
|
ENDCONTEXT
|
|
{references}
|
|
ENDINPUT
|
|
BEGININSTRUCTION {question} ENDINSTRUCTION"""
|
|
return self.ask_model(retrieval_question, system=retrieval_system) |