import json
import os
import time

import nltk
import torch


def get_device(verbose=False):
    """Return the torch device to use: CUDA if available, otherwise CPU.

    Args:
        verbose: If True, print the selected device.

    Returns:
        torch.device: The selected compute device.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if verbose:
        print('Using device:', device)
    return device


def save_model_and_hyperparameters(model, model_prefix_name, accuracy,
                                   timestamp=None, **kwargs):
    """Save the model state dict and its hyperparameters under ``models/``.

    Args:
        model: Trained ``torch.nn.Module`` whose ``state_dict`` is saved.
        model_prefix_name: Prefix used in both output filenames.
        accuracy: Model accuracy; rounded to 4 decimals and embedded in the
            filenames and the JSON payload.
        timestamp: Optional timestamp string; defaults to the current local
            time formatted as ``YYYYMMDD-HHMMSS``.
        **kwargs: Hyperparameters to serialize to JSON alongside the model.
    """
    # Create a timestamp if the caller did not supply one.
    if timestamp is None:
        timestamp = time.strftime("%Y%m%d-%H%M%S")
    accuracy = round(accuracy, 4)

    # Ensure the output directory exists; torch.save/open do not create it
    # and would raise FileNotFoundError on a fresh checkout.
    os.makedirs('models', exist_ok=True)

    # Save the model state dictionary.
    model_path = f'models/{model_prefix_name}_acc_{accuracy}_{timestamp}.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}.")

    # Save the hyperparameters as a JSON file. Copy kwargs so the caller's
    # dict is not mutated by the 'accuracy' insertion below.
    hyperparameters = dict(kwargs)
    hyperparameters['accuracy'] = accuracy
    hyperparameters_path = (
        f'models/{model_prefix_name}_para_acc_{accuracy}_{timestamp}.json'
    )
    with open(hyperparameters_path, 'w') as f:
        json.dump(hyperparameters, f)
    print(f"Hyperparameters saved to {hyperparameters_path}.")


def get_newest_model_path(path, name=None, extension=".pth"):
    """Return the path of the most recently modified matching file.

    Args:
        path: Directory to search.
        name: Optional substring that the filename must contain.
        extension: Required filename suffix (default ``".pth"``).

    Returns:
        str | None: Full path of the newest matching file, or None if no
        file matches (a message is printed in that case).
    """
    # List candidate files: correct extension, optionally containing `name`.
    files = [f for f in os.listdir(path) if f.endswith(extension)]
    if name:
        files = [f for f in files if name in f]

    if not files:
        print("No File found in the directory")
        return None

    # Newest first, by modification time.
    files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x)),
               reverse=True)
    return os.path.join(path, files[0])


def main():
    """Set up the environment: fetch NLTK data and report CUDA status."""
    # Download the NLTK tokenizer data used elsewhere in the project.
    nltk.download('punkt')
    nltk.download('punkt_tab')

    # Check if CUDA is available and report the active device.
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        # Print the current CUDA device index and its human-readable name.
        current_device = torch.cuda.current_device()
        print(f"Current CUDA device: {current_device}")
        device_name = torch.cuda.get_device_name(current_device)
        print(f"CUDA device name: {device_name}")
    else:
        print("CUDA is not available. Please check your CUDA installation and PyTorch configuration.")


if __name__ == "__main__":
    main()