113 lines
3.6 KiB
Python
113 lines
3.6 KiB
Python
import torch
|
|
import nltk
|
|
|
|
import time
|
|
import json
|
|
import os
|
|
import re
|
|
|
|
def get_device(verbose=False, include_mps=False):
    """
    Get the torch.device to run on.

    Prefers CUDA when available, otherwise CPU. When ``include_mps`` is
    True, Apple's MPS backend takes precedence over the CUDA/CPU choice
    if it is available.

    Args:
        verbose: If True, print the selected device.
        include_mps: If True, also consider the MPS backend.

    Returns:
        torch.device: the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if include_mps and torch.backends.mps.is_available():
        device = torch.device("mps")
    # Print only after the final choice so the message matches the device
    # actually returned (previously the print ran before the MPS override
    # and could report the wrong device when include_mps was set).
    if verbose:
        print('Using device:', device)
    return device
|
|
|
|
def save_model_and_hyperparams(model, model_prefix_name, rmse, hyperparameters, timestamp=None):
    """
    Save a model's state dict and its hyperparameters under ``models/``.

    Args:
        model: torch module whose ``state_dict()`` is saved.
        model_prefix_name: filename prefix for both artifacts.
        rmse: evaluation metric; rounded to 4 decimals, embedded in both
            filenames, and stored in the hyperparameter JSON.
        hyperparameters: dict of hyperparameters to save. NOTE: mutated
            in place — the rounded ``rmse`` is added under key ``'rmse'``.
        timestamp: optional timestamp string used in the filenames;
            defaults to the current time formatted as ``%Y%m%d-%H%M%S``.
    """
    # Create a timestamp
    if timestamp is None:
        timestamp = time.strftime("%Y%m%d-%H%M%S")

    rmse = round(rmse, 4)

    # Ensure the target directory exists; the original crashed with
    # FileNotFoundError on a fresh checkout without a models/ folder.
    os.makedirs('models', exist_ok=True)

    # Save the model state dictionary
    model_path = f'models/{model_prefix_name}_acc_{rmse}_{timestamp}.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}.")

    # Save the hyperparameters as a JSON file
    hyperparameters['rmse'] = rmse
    hyperparameters_path = f'models/{model_prefix_name}_para_acc_{rmse}_{timestamp}.json'
    with open(hyperparameters_path, 'w') as f:
        json.dump(hyperparameters, f)
    print(f"Hyperparameters saved to {hyperparameters_path}.")
|
|
|
|
def get_newest_file(path, name=None, extension=".pth", ensemble=False):
    """
    Get the newest file in a directory.

    Args:
        path: directory to search.
        name: optional substring that filenames must contain.
        extension: required filename suffix (default ".pth").
        ensemble: if False, return the single most recently modified
            matching path; if True, restrict to filenames containing
            "ensemble" and return a list of paths (see notes below).

    Returns:
        A single path string (ensemble=False), a list of path strings
        (ensemble=True), or None when nothing matched.
    """
    # List all files in the directory with the wanted extension
    files = [f for f in os.listdir(path) if f.endswith(extension)]
    # List all files with name in it
    if name:
        files = [f for f in files if name in f]

    if ensemble:
        files = [f for f in files if "ensemble" in f]

    # Sort files by modification time, newest first
    files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x)), reverse=True)

    # Get the newest file
    if files:
        if not ensemble:
            newest_model_path = os.path.join(path, files[0])
            return newest_model_path
        else:
            # Extract timestamp from the newest file's filename.
            # NOTE(review): this pattern expects "YYYY-MM-DD_HH-MM-SS",
            # but save_model_and_hyperparams() in this file writes
            # "%Y%m%d-%H%M%S" timestamps — files saved by that helper
            # will never match. Confirm which naming scheme ensemble
            # checkpoints actually use.
            regex = r"(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})"
            newest_stamp = None
            ret_files = []
            for file in files:
                match = re.search(regex, file)
                if match:
                    newest_timestamp = match.group(1)
                    # Keep the lexicographically greatest stamp seen so
                    # far (valid because the format is zero-padded).
                    if not newest_stamp or newest_timestamp > newest_stamp:
                        newest_stamp = newest_timestamp
                # NOTE(review): once ANY stamp has been seen, every
                # subsequent file is appended regardless of its own
                # stamp — presumably the intent was to collect only the
                # files sharing the newest stamp. Verify against callers.
                if newest_stamp:
                    ret_files.append(os.path.join(path, file))
            if ret_files:
                return ret_files
            else:
                print("No File found in the directory")
                return None
    else:
        print("No File found in the directory")
        return None
|
|
|
|
|
|
def main():
    """
    Set up the environment: fetch NLTK tokenizer data and report the
    CUDA status of the current PyTorch installation.
    """
    # Tokenizer resources needed by the rest of the project.
    for resource in ('punkt', 'punkt_tab'):
        nltk.download(resource)

    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")

    if not cuda_ok:
        print("CUDA is not available. Please check your CUDA installation and PyTorch configuration.")
        return

    # Report which CUDA device is active and its name.
    device_index = torch.cuda.current_device()
    print(f"Current CUDA device: {device_index}")
    print(f"CUDA device name: {torch.cuda.get_device_name(device_index)}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run environment setup only when executed as a script, not on import.
    main()