import os

import numpy as np
import torch
from ray import tune
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining

from models.cnn import GetCNN
from utils.train import Trainer
# Check if GPU is available: use the first CUDA device when present, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:  " + str(device))

# Search budget handed to tune.run() / the schedulers below.
num_samples = 40      # for multiple trials (number of sampled configurations)
max_num_epochs = 25   # cap on training iterations (epochs) per trial
gpus_per_trial = 1    # GPUs reserved for each trial

# Code has been referenced from the official Ray Tune documentation:
#   ASHA: https://docs.ray.io/en/master/tune/api_docs/schedulers.html#tune-scheduler-hyperband
#   PBT:  https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-scheduler-pbt

# CIFAR-10 dataset location
data_dir = './data/Cifar10'

# config - returns a dict of hyperparameters
# Selecting different hyperparameters for tuning:
#   l1         : number of units in the first fully connected layer
#   l2         : number of units in the second fully connected layer
#   lr         : learning rate
#   decay      : decay rate for regularization
#   batch_size : batch size of the train and test data
config = {
    # Powers of two 2**2 .. 2**8, i.e. 4, 8, 16, ..., 256
    # (np.random.randint excludes its upper bound).
    "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
    # Sampled from a log-uniform distribution over [1e-4, 1e-1].
    "lr": tune.loguniform(1e-4, 1e-1),
    # Powers of ten 1e-7, 1e-6, 1e-5, 1e-4. The base must be a float: NumPy
    # raises ValueError when an integer is raised to a negative integer power,
    # so the previous `10 ** np.random.randint(...)` crashed during sampling.
    "decay": tune.sample_from(lambda _: 10.0 ** np.random.randint(-7, -3)),
    "batch_size": tune.choice([32, 64, 128, 256]),
}

# ASHA (Asynchronous Successive Halving Algorithm) scheduler:
#   max_t            : maximum number of units per trial (can be time or epochs)
#   grace_period     : trials are only stopped early once they have run at least
#                      this many units (can be time or epochs)
#   reduction_factor : sets the halving rate
# calling trainer
trainer = Trainer(device=device)

# Select the trial scheduler: "pbt" (Population Based Training) or "asha".
# Only one scheduler is built; previously ASHA was constructed and then silently
# overwritten by PBT. Neither scheduler sets `metric`/`mode` itself: tune.run()
# passes metric="loss", mode="min", and Ray Tune raises a ValueError when both
# the scheduler and tune.run() specify them.
scheduler_name = "pbt"

if scheduler_name == "asha":
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,  # at most this many training iterations (epochs) per trial
        grace_period=4,  # a trial can be stopped only after 4 iterations
        reduction_factor=4,  # halving rate between rungs
    )
elif scheduler_name == "pbt":
    # Population based training scheduler:
    #   time_attr             : can be time or epochs
    #   perturbation_interval : perturbation occurs after this many units
    #   hyperparam_mutations  : hyperparameters to mutate
    # The objective (minimise loss) comes from tune.run(metric=..., mode=...).
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",  # epochs
        perturbation_interval=5.0,  # perturb after every 5 epochs
        hyperparam_mutations={
            # Choose from the given learning rates (no duplicates, so each is equally likely).
            "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
            "batch_size": [64, 128, 256],  # choose from the given batch sizes
            "decay": tune.uniform(10**-8, 10**-4),  # sample from a uniform distribution
        },
    )
else:
    raise ValueError(
        "Unknown scheduler_name {!r}; expected 'pbt' or 'asha'".format(scheduler_name))
81
# Launch the hyperparameter search. Each trial calls trainer.Train_ray with a
# sampled `config`; data_dir is bound once via tune.with_parameters.
result = tune.run(
    tune.with_parameters(trainer.Train_ray, data_dir=data_dir),
    config=config,  # search space defined above
    num_samples=num_samples,  # how many configurations to sample
    scheduler=scheduler,  # PBT or ASHA
    # Objective: minimise the reported "loss". Set metric/mode either here or
    # on the scheduler, not both; Ray Tune raises a ValueError otherwise.
    metric="loss",
    mode="min",
    # Stop every trial after max_num_epochs training iterations.
    stop=dict(training_iteration=max_num_epochs),
    resources_per_trial=dict(cpu=8, gpu=gpus_per_trial),  # CPUs/GPUs per trial
    name="ray_test_basic-CNN",  # experiment name used for results/checkpoints
    keep_checkpoints_num=5,  # checkpoints retained per trial
    fail_fast=True,  # abort the whole run on the first trial error
    verbose=True,  # report training progress
)
99
# Pick the trial whose most recently reported loss is lowest, then summarise it.
best_trial = result.get_best_trial(metric="loss", mode="min", scope="last")

print("Best configuration: {}".format(best_trial.config))
print("Best validation loss: {}".format(best_trial.last_result["loss"]))
print("Best validation accuracy: {}".format(best_trial.last_result["accuracy"]))
105
106
# Rebuild the network with the best trial's layer sizes and restore its weights.
best_trained_model = GetCNN(best_trial.config["l1"], best_trial.config["l2"])
best_trained_model.to(device)

# The checkpoint file is unpacked as a (model_state, optimizer_state) pair;
# presumably it is written that way by Trainer.Train_ray — confirm there.
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
# map_location puts the stored tensors on `device`, so a checkpoint saved on a
# GPU worker can still be restored when this process runs on the CPU.
model_state, optimizer_state = torch.load(checkpoint_path, map_location=device)
best_trained_model.load_state_dict(model_state)

# Check accuracy of best model
test_acc = trainer.Test(best_trained_model, save=data_dir)
print("Best Test accuracy: {}".format(test_acc))