enhanced get_newest_file function to support ensemble file retrieval
parent
4469f55889
commit
282cb128e3
29
ml_helper.py
29
ml_helper.py
|
|
@ -4,6 +4,7 @@ import nltk
|
|||
import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
def get_device(verbose=False, include_mps=False):
|
||||
"""
|
||||
|
|
@ -39,7 +40,7 @@ def save_model_and_hyperparams(model, model_prefix_name, rmse, hyperparameters,
|
|||
json.dump(hyperparameters, f)
|
||||
print(f"Hyperparameters saved to {hyperparameters_path}.")
|
||||
|
||||
def get_newest_file(path, name=None, extension=".pth"):
|
||||
def get_newest_file(path, name=None, extension=".pth", ensemble=False):
|
||||
"""
|
||||
Get the newest file in a directory.
|
||||
"""
|
||||
|
|
@ -49,13 +50,35 @@ def get_newest_file(path, name=None, extension=".pth"):
|
|||
if name:
|
||||
files = [f for f in files if name in f]
|
||||
|
||||
if ensemble:
|
||||
files = [f for f in files if "ensemble" in f]
|
||||
|
||||
# Sort files by modification time
|
||||
files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x)), reverse=True)
|
||||
|
||||
# Get the newest file
|
||||
if files:
|
||||
newest_model_path = os.path.join(path, files[0])
|
||||
return newest_model_path
|
||||
if not ensemble:
|
||||
newest_model_path = os.path.join(path, files[0])
|
||||
return newest_model_path
|
||||
else:
|
||||
# Extract timestamp from the newest file's filename
|
||||
regex = r"(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})"
|
||||
newest_stamp = None
|
||||
ret_files = []
|
||||
for file in files:
|
||||
match = re.search(regex, file)
|
||||
if match:
|
||||
newest_timestamp = match.group(1)
|
||||
if not newest_stamp or newest_timestamp > newest_stamp:
|
||||
newest_stamp = newest_timestamp
|
||||
if newest_stamp:
|
||||
ret_files.append(os.path.join(path, file))
|
||||
if ret_files:
|
||||
return ret_files
|
||||
else:
|
||||
print("No File found in the directory")
|
||||
return None
|
||||
else:
|
||||
print("No File found in the directory")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
|
@ -50,7 +50,7 @@
|
|||
],
|
||||
"source": [
|
||||
"# load latest data if keyword is in the file name\n",
|
||||
"hist_file_name = ml_helper.get_newest_file('histories/', name='CNN', extension=\".json\")\n",
|
||||
"hist_file_name = ml_helper.get_newest_file('histories/', name='CNN', extension=\".json\", ensemble=False)\n",
|
||||
"print(f\"Loading {hist_file_name}\")"
|
||||
]
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue