enhanced get_newest_file function to support ensemble file retrieval
parent
4469f55889
commit
282cb128e3
25
ml_helper.py
25
ml_helper.py
|
|
@ -4,6 +4,7 @@ import nltk
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
def get_device(verbose=False, include_mps=False):
|
def get_device(verbose=False, include_mps=False):
|
||||||
"""
|
"""
|
||||||
|
|
@ -39,7 +40,7 @@ def save_model_and_hyperparams(model, model_prefix_name, rmse, hyperparameters,
|
||||||
json.dump(hyperparameters, f)
|
json.dump(hyperparameters, f)
|
||||||
print(f"Hyperparameters saved to {hyperparameters_path}.")
|
print(f"Hyperparameters saved to {hyperparameters_path}.")
|
||||||
|
|
||||||
def get_newest_file(path, name=None, extension=".pth"):
|
def get_newest_file(path, name=None, extension=".pth", ensemble=False):
|
||||||
"""
|
"""
|
||||||
Get the newest file in a directory.
|
Get the newest file in a directory.
|
||||||
"""
|
"""
|
||||||
|
|
@ -49,13 +50,35 @@ def get_newest_file(path, name=None, extension=".pth"):
|
||||||
if name:
|
if name:
|
||||||
files = [f for f in files if name in f]
|
files = [f for f in files if name in f]
|
||||||
|
|
||||||
|
if ensemble:
|
||||||
|
files = [f for f in files if "ensemble" in f]
|
||||||
|
|
||||||
# Sort files by modification time
|
# Sort files by modification time
|
||||||
files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x)), reverse=True)
|
files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x)), reverse=True)
|
||||||
|
|
||||||
# Get the newest file
|
# Get the newest file
|
||||||
if files:
|
if files:
|
||||||
|
if not ensemble:
|
||||||
newest_model_path = os.path.join(path, files[0])
|
newest_model_path = os.path.join(path, files[0])
|
||||||
return newest_model_path
|
return newest_model_path
|
||||||
|
else:
|
||||||
|
# Extract timestamp from the newest file's filename
|
||||||
|
regex = r"(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})"
|
||||||
|
newest_stamp = None
|
||||||
|
ret_files = []
|
||||||
|
for file in files:
|
||||||
|
match = re.search(regex, file)
|
||||||
|
if match:
|
||||||
|
newest_timestamp = match.group(1)
|
||||||
|
if not newest_stamp or newest_timestamp > newest_stamp:
|
||||||
|
newest_stamp = newest_timestamp
|
||||||
|
if newest_stamp:
|
||||||
|
ret_files.append(os.path.join(path, file))
|
||||||
|
if ret_files:
|
||||||
|
return ret_files
|
||||||
|
else:
|
||||||
|
print("No File found in the directory")
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
print("No File found in the directory")
|
print("No File found in the directory")
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
|
@ -50,7 +50,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# load latest data if keyword is in the file name\n",
|
"# load latest data if keyword is in the file name\n",
|
||||||
"hist_file_name = ml_helper.get_newest_file('histories/', name='CNN', extension=\".json\")\n",
|
"hist_file_name = ml_helper.get_newest_file('histories/', name='CNN', extension=\".json\", ensemble=False)\n",
|
||||||
"print(f\"Loading {hist_file_name}\")"
|
"print(f\"Loading {hist_file_name}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue