import json
import os
import re
import time

import nltk
import torch


def get_device(verbose=False, include_mps=False):
    """
    Get the device (CUDA GPU, Apple MPS, or CPU) PyTorch should use.

    Args:
        verbose: if True, print the selected device.
        include_mps: if True, use the MPS backend whenever
            ``torch.backends.mps.is_available()`` reports it; otherwise only
            CUDA (when ``torch.cuda.is_available()``) or CPU is chosen.

    Returns:
        torch.device: the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if include_mps and torch.backends.mps.is_available():
        device = torch.device("mps")
    # Report only after the final choice so the printed device is the returned one.
    if verbose:
        print('Using device:', device)
    return device


def save_model_and_hyperparams(model, model_prefix_name, rmse, hyperparameters, timestamp=None):
    """
    Save a model's state dict and its hyperparameters under ``models/``.

    Two files are written, sharing the rounded RMSE and the timestamp:
    ``models/{model_prefix_name}_acc_{rmse}_{timestamp}.pth`` (the state dict) and
    ``models/{model_prefix_name}_para_acc_{rmse}_{timestamp}.json`` (the
    hyperparameters plus an ``'rmse'`` entry).

    Args:
        model: object with a ``state_dict()`` method (e.g. a ``torch.nn.Module``).
        model_prefix_name: prefix used for both filenames.
        rmse: evaluation RMSE; converted to float and rounded to 4 decimals.
        hyperparameters: dictionary of JSON-serializable hyperparameters.
            NOTE: it is modified in place; ``hyperparameters['rmse']`` is set.
        timestamp: string embedded in both filenames; defaults to the current
            local time formatted as ``%Y%m%d-%H%M%S``.
    """
    if timestamp is None:
        timestamp = time.strftime("%Y%m%d-%H%M%S")
    # float() first so a 0-d tensor / Decimal yields a plain number in the
    # filename and stays JSON-serializable.
    rmse = round(float(rmse), 4)

    # torch.save and open() do not create missing parent directories.
    os.makedirs('models', exist_ok=True)

    model_path = f'models/{model_prefix_name}_acc_{rmse}_{timestamp}.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}.")

    hyperparameters['rmse'] = rmse
    hyperparameters_path = f'models/{model_prefix_name}_para_acc_{rmse}_{timestamp}.json'
    with open(hyperparameters_path, 'w', encoding='utf-8') as f:
        json.dump(hyperparameters, f)
    print(f"Hyperparameters saved to {hyperparameters_path}.")


def get_newest_file(path, name=None, extension=".pth", ensemble=False,
                    timestamp_pattern=r"(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})"):
    """
    Get the most recently modified file(s) in a directory.

    Candidates are the entries of ``path`` whose names end with ``extension``,
    optionally restricted to names containing ``name`` and, when ``ensemble``
    is True, to names containing ``"ensemble"``.

    Args:
        path: directory to search.
        name: optional substring a filename must contain.
        extension: required filename suffix.
        ensemble: if False, return the single newest candidate by modification
            time. If True, return every candidate whose embedded timestamp
            equals the newest such timestamp (ordered newest-mtime first).
        timestamp_pattern: regex whose FIRST capture group extracts the
            timestamp from a filename (ensemble mode only). Timestamps are
            compared as strings, so the format must be fixed-width with the
            most significant field first.
            NOTE(review): the default ``YYYY-MM-DD_HH-MM-SS`` pattern does not
            match the ``%Y%m%d-%H%M%S`` stamps written by
            save_model_and_hyperparams — confirm which format ensemble files use.

    Returns:
        A path string (``ensemble=False``), a list of path strings
        (``ensemble=True``), or None when nothing matches; a message is
        printed in that case.
    """
    files = [f for f in os.listdir(path) if f.endswith(extension)]
    if name:
        files = [f for f in files if name in f]
    if ensemble:
        files = [f for f in files if "ensemble" in f]

    # Newest modification time first.
    files.sort(key=lambda f: os.path.getmtime(os.path.join(path, f)), reverse=True)

    if not files:
        print("No File found in the directory")
        return None

    if not ensemble:
        return os.path.join(path, files[0])

    regex = re.compile(timestamp_pattern)
    stamps = {}
    for file in files:
        match = regex.search(file)
        if match:
            stamps[file] = match.group(1)

    if not stamps:
        print("No File found in the directory")
        return None

    # Only members of the latest ensemble share the newest timestamp; files
    # without a timestamp (stamps.get -> None) are excluded.
    newest_stamp = max(stamps.values())
    return [os.path.join(path, f) for f in files if stamps.get(f) == newest_stamp]


def main():
    """
    Set up the environment: download NLTK data and report CUDA availability.
    """
    # Fetch the NLTK 'punkt' and 'punkt_tab' resources.
    nltk.download('punkt')
    nltk.download('punkt_tab')

    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")

    if cuda_available:
        current_device = torch.cuda.current_device()
        print(f"Current CUDA device: {current_device}")

        device_name = torch.cuda.get_device_name(current_device)
        print(f"CUDA device name: {device_name}")
    else:
        print("CUDA is not available. Please check your CUDA installation and PyTorch configuration.")


if __name__ == "__main__":
    main()