# Source listing metadata: 140 lines · 4.0 KiB · Python
import json
import os
import shutil
import subprocess
import sys
import threading

import requests
from flask import Flask, request, jsonify
from flask_cors import CORS

from extractSpacy import extract, load_model


# Module-wide training state: run_training sets "running" to True for the
# duration of a training run and back to False once it has finished.
training_status = dict(running=False)


# Flask application. CORS(app) is called without options, i.e. with the
# flask-cors defaults for cross-origin access.
app = Flask(__name__)
CORS(app)

# Downstream service endpoints; each can be overridden through an
# environment variable of the same name.
# NOTE(review): the validate-service default points at localhost while the
# coordinator default uses a container hostname — confirm this is intended.
COORDINATOR_URL = os.environ.get("COORDINATOR_URL", "http://coordinator:5000")
VALIDATE_SERVICE_URL = os.environ.get(
    "VALIDATE_SERVICE_URL", "http://localhost:5054/validate"
)


@app.route("/extract", methods=["POST"])
def extract_pdf():
    """Run spaCy entity extraction on a pitchbook and forward the result.

    Expects a JSON object with ``id`` and ``extracted_text_per_page``. The
    pages are passed to ``extract``; its result (a JSON string or an already
    decoded object) is POSTed to ``VALIDATE_SERVICE_URL`` together with the
    pitchbook id. Failures of that downstream call are only logged.

    Returns:
        400 if the body is not a JSON object containing both required keys;
        otherwise 200, even when the validate service rejects the payload
        or cannot be reached.
    """
    json_data = request.get_json()

    # Reject malformed requests explicitly instead of letting a KeyError /
    # TypeError escape and surface as an opaque 500.
    if (
        not isinstance(json_data, dict)
        or "id" not in json_data
        or "extracted_text_per_page" not in json_data
    ):
        return (
            jsonify(
                {
                    "error": "Ungültiges Format – 'id' und "
                    "'extracted_text_per_page' erforderlich."
                }
            ),
            400,
        )

    pitchbook_id = json_data["id"]
    pages_data = json_data["extracted_text_per_page"]

    entities_json = extract(pages_data)
    # extract() may hand back either serialized JSON or a Python object.
    entities = (
        json.loads(entities_json) if isinstance(entities_json, str) else entities_json
    )

    validate_payload = {"id": pitchbook_id, "service": "spacy", "entities": entities}

    print(f"[SPACY] Sending to validate service: {VALIDATE_SERVICE_URL}")
    print(f"[SPACY] Payload: {validate_payload} entities for pitchbook {pitchbook_id}")

    try:
        response = requests.post(
            VALIDATE_SERVICE_URL, json=validate_payload, timeout=600
        )
        print(f"[SPACY] Validate service response: {response.status_code}")
        if response.status_code != 200:
            print(f"[SPACY] Validate service error: {response.text}")
    except Exception as e:
        # Best-effort forwarding: the caller still gets 200 below.
        print(f"[SPACY] Error sending to validate service: {e}")

    return jsonify("Sent to validate-service"), 200


# Serializes the read-modify-write of the annotation file across concurrently
# handled requests so one request cannot overwrite another's addition.
_annotation_lock = threading.Lock()


@app.route("/append-training-entry", methods=["POST"])
def append_training_entry():
    """Append one annotated example to the spaCy training data file.

    Expects a JSON object with at least ``text`` and ``entities`` keys. The
    entry is stored as-is in ``spacy_training/annotation_data.json`` (a JSON
    list, created if missing); an identical entry already present is skipped.

    Returns:
        400 if the body is not a JSON object with ``text`` and ``entities``;
        200 if the entry was stored or already existed;
        500 if the data file could not be read or written.
    """
    entry = request.get_json()

    # isinstance guard: a list or string body would otherwise pass the
    # membership checks below ("text" in "some text" is True).
    if not isinstance(entry, dict) or "text" not in entry or "entities" not in entry:
        return (
            jsonify(
                {"error": "Ungültiges Format – 'text' und 'entities' erforderlich."}
            ),
            400,
        )

    path = os.path.join("spacy_training", "annotation_data.json")

    try:
        with _annotation_lock:
            if os.path.exists(path):
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            else:
                data = []

            # Skip exact duplicates.
            if entry in data:
                return jsonify({"message": "Eintrag existiert bereits."}), 200

            data.append(entry)

            # Write to a sibling temp file and atomically swap it in, so an
            # error mid-dump leaves the previous data file intact instead of
            # a truncated one. The fixed temp name is safe under the lock.
            tmp_path = path + ".tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, path)

        return jsonify({"message": "Eintrag erfolgreich gespeichert."}), 200
    except Exception as e:
        print(f"[ERROR] Fehler beim Schreiben der Datei: {e}")
        return jsonify({"error": "Interner Fehler beim Schreiben."}), 500


# Makes the "is a run in progress?" check and the flag update one atomic step.
_training_lock = threading.Lock()


@app.route("/train", methods=["POST"])
def trigger_training():
    """Start ``run_training`` in a background thread unless a run is active.

    Returns:
        200 when a new training run was started;
        409 when ``training_status`` reports a run already in progress.
    """
    # Without this guard, repeated requests launch overlapping training
    # subprocesses that all write to the same output directory.
    with _training_lock:
        if training_status["running"]:
            return jsonify({"error": "Training läuft bereits"}), 409
        training_status["running"] = True

    try:
        threading.Thread(target=run_training).start()
    except RuntimeError:
        # The thread never started, so nothing will clear the flag for us.
        training_status["running"] = False
        raise

    return jsonify({"message": "Training gestartet"}), 200


@app.route("/reload-model", methods=["POST"])
def reload_model():
    """Reload the spaCy model in-process by calling ``load_model``.

    Returns:
        200 with a confirmation message on success, or 500 with the
        exception text under ``details`` if ``load_model`` raises.
    """
    try:
        load_model()
    except Exception as exc:
        failure = {"error": "Fehler beim Neuladen des Modells", "details": str(exc)}
        return jsonify(failure), 500
    else:
        return jsonify({"message": "Modell wurde erfolgreich neu geladen."}), 200


def run_training():
    """Back up the current model, retrain it in a subprocess, then reload it.

    Intended to run in the background thread started by the ``/train``
    route. ``training_status["running"]`` is set and the coordinator is
    notified for the duration of the run; both are reset in ``finally`` so
    they are cleared no matter how training ends. Training errors are logged,
    not raised.
    """
    model_dir = "output/model-last"
    backup_dir = "output/model-backup"

    training_status["running"] = True
    notify_coordinator(True)

    try:
        if os.path.exists(model_dir):
            # Replace rather than merge the previous backup, so it cannot keep
            # stale files left over from an older model version.
            shutil.rmtree(backup_dir, ignore_errors=True)
            shutil.copytree(model_dir, backup_dir, dirs_exist_ok=True)

        # sys.executable runs the trainer under this service's interpreter
        # (same environment / installed packages) instead of whatever
        # "python" resolves to on PATH.
        subprocess.run([sys.executable, "spacy_training/ner_trainer.py"], check=True)
        load_model()
    except Exception as e:
        print("Training failed:", e)
    finally:
        training_status["running"] = False
        notify_coordinator(False)


def notify_coordinator(running: bool, timeout: float = 10.0) -> None:
    """Report the current training state to the coordinator service.

    POSTs ``{"running": running}`` to
    ``{COORDINATOR_URL}/api/spacy/training/status``. This is best-effort:
    any error is logged and swallowed so a coordinator problem never aborts
    a training run.

    Args:
        running: Whether a training run is currently in progress.
        timeout: Seconds to wait for the coordinator (passed to requests).
            Without a timeout requests can block indefinitely, which would
            stall the training thread and the final status reset.
    """
    url = f"{COORDINATOR_URL}/api/spacy/training/status"
    try:
        response = requests.post(url, json={"running": running}, timeout=timeout)
        print(
            f"[SPACY] Coordinator: running = {running}, Status = {response.status_code}"
        )
    except Exception as e:
        print(f"[SPACY] Fehler beim Senden des Trainingsstatus: {e}")


if __name__ == "__main__":
    # The Werkzeug interactive debugger permits arbitrary code execution, so
    # it must not be on by default for a server bound to 0.0.0.0. Opt in for
    # local development with FLASK_DEBUG=1.
    debug = os.getenv("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=5052, debug=debug)