import numpy as np
import os
import torch
from ray import tune
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from utils.train import Trainer
from models.cnn import GetCNN

Check if a GPU is available

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

num_samples = 40  # number of trials to run
max_num_epochs = 25
gpus_per_trial = 1

CIFAR-10 dataset location

The code is adapted from the official Ray Tune scheduler documentation:

ASHA: https://docs.ray.io/en/master/tune/api_docs/schedulers.html#tune-scheduler-hyperband

PBT: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-scheduler-pbt

data_dir = './data/Cifar10'

config is a dict of hyperparameter search spaces.

The hyperparameters selected for tuning:

l1 : number of units in the first fully connected layer
l2 : number of units in the second fully connected layer
lr : learning rate
decay : decay rate for regularization
batch_size : batch size of the train and test data

config = {
    "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),  # e.g. 4, 8, 16, ..., 256
    "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),  # e.g. 4, 8, 16, ..., 256
    "lr": tune.loguniform(1e-4, 1e-1),  # sampled from a log-uniform distribution
    "decay": tune.sample_from(lambda _: 10 ** np.random.randint(-7, -3)),  # e.g. 1e-7, 1e-6, 1e-5, 1e-4
    "batch_size": tune.choice([32, 64, 128, 256])
}
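As an aside, the same power-of-two spaces can be expressed with tune.choice over explicit value lists instead of sample_from; a minimal, equivalent sketch (alt_config is an illustrative name and is not used below):

# Equivalent search space written with tune.choice over explicit lists
alt_config = {
    "l1": tune.choice([2 ** i for i in range(2, 9)]),          # 4, 8, ..., 256
    "l2": tune.choice([2 ** i for i in range(2, 9)]),          # 4, 8, ..., 256
    "lr": tune.loguniform(1e-4, 1e-1),
    "decay": tune.choice([10.0 ** i for i in range(-7, -3)]),  # 1e-7, ..., 1e-4
    "batch_size": tune.choice([32, 64, 128, 256]),
}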

Calling the trainer and creating the ASHA (Asynchronous Successive Halving Algorithm) scheduler:

max_t : maximum number of units per trial (can be time or epochs)
grace_period : minimum number of units a trial runs before it can be stopped for performing poorly (can be time or epochs)
reduction_factor : sets the halving rate

trainer = Trainer(device=device)

scheduler = ASHAScheduler(
    max_t=max_num_epochs,
    grace_period=4,
    reduction_factor=4)
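With these settings the successive-halving rungs fall at grace_period * reduction_factor^k epochs, i.e. 4, 16, 64, ..., capped at max_t = 25, so in practice at epochs 4 and 16; at each rung only roughly the top 1/reduction_factor = 1/4 of trials are promoted, and the rest are stopped early.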

Population Based Training scheduler:

time_attr : the attribute measuring training progress (can be time or epochs)
metric : objective of training (loss or accuracy)
perturbation_interval : perturbations occur after the specified number of units (can be time or epochs)
hyperparam_mutations : hyperparameters to mutate

# Uncomment to use PBT instead of ASHA; PBT defines its own metric and mode,
# so also comment out the metric/mode arguments of tune.run below.
# scheduler = PopulationBasedTraining(
#     time_attr="training_iteration",  # epochs
#     metric="loss",  # the loss is the objective function
#     mode="min",  # minimizing the loss is the objective of training
#     perturbation_interval=5.0,  # perturb after every 5 epochs
#     hyperparam_mutations={
#         "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],  # choose from the given learning rates
#         "batch_size": [64, 128, 256],  # choose from the given batch sizes
#         "decay": tune.uniform(1e-8, 1e-4)  # sample from a uniform distribution
#     })
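For orientation, here is a rough sketch of the shape tune.run expects trainer.Train_ray to have. Everything in it is an assumption (the real implementation lives in utils/train.py), including the use of the CIFAR-10 test split as a stand-in validation set and the omission of checkpoint restoring; it only illustrates the functional-trainable pattern of reporting metrics via tune.report and saving (model_state, optimizer_state) checkpoints:

# Hypothetical sketch of a Train_ray-style trainable; not the project's real code
def train_ray_sketch(config, checkpoint_dir=None, data_dir=None):
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    model = GetCNN(config["l1"], config["l2"]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"],
                                momentum=0.9, weight_decay=config["decay"])

    transform = transforms.ToTensor()
    trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True,
                                            download=True, transform=transform)
    valset = torchvision.datasets.CIFAR10(root=data_dir, train=False,
                                          download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True)
    valloader = DataLoader(valset, batch_size=config["batch_size"], shuffle=False)

    for epoch in range(max_num_epochs):
        model.train()
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        # Validation pass: average loss and top-1 accuracy
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item() * targets.size(0)
                correct += (outputs.argmax(1) == targets).sum().item()
                total += targets.size(0)

        # Checkpoint in the (model_state, optimizer_state) format the loading
        # code further below expects
        with tune.checkpoint_dir(epoch) as ckpt_dir:
            torch.save((model.state_dict(), optimizer.state_dict()),
                       os.path.join(ckpt_dir, "checkpoint"))

        # Report the "loss"/"accuracy" keys that the schedulers and analysis use
        tune.report(loss=val_loss / total, accuracy=correct / total)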
result = tune.run(
    tune.with_parameters(trainer.Train_ray, data_dir=data_dir),
    name="ray_test_basic-CNN",  # name for identifying models (checkpoints)
    scheduler=scheduler,  # the selected scheduler, ASHA or PBT
    resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},  # CPUs and GPUs allocated per trial
    config=config,  # config dict of hyperparameter search spaces
    stop={
        "training_iteration": max_num_epochs,  # stopping criterion
    },
    metric="loss",  # comment out when using the PBT scheduler
    mode="min",  # comment out when using the PBT scheduler
    num_samples=num_samples,
    verbose=True,  # keep True to check how training progresses
    fail_fast=True,  # stop the experiment on the first error
    keep_checkpoints_num=5,  # number of checkpoints to keep per trial
)
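Before pulling out the single best trial, the ExperimentAnalysis object returned by tune.run can also be inspected as a whole; a minimal sketch, assuming the trainable reported "loss" and "accuracy" as above:

# Dump all trial results into a pandas DataFrame and rank by validation loss
df = result.dataframe()
print(df.sort_values("loss")[["loss", "accuracy", "training_iteration"]].head())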
best_trial = result.get_best_trial("loss", "min", "last")
print("Best configuration: {}".format(best_trial.config))
print("Best validation loss: {}".format(best_trial.last_result["loss"]))
print("Best validation accuracy: {}".format(
    best_trial.last_result["accuracy"]))


best_trained_model = GetCNN(best_trial.config["l1"], best_trial.config["l2"])
best_trained_model.to(device)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path, map_location=device)
best_trained_model.load_state_dict(model_state)

Check the accuracy of the best model

test_acc = trainer.Test(best_trained_model, save=data_dir)
print("Best test accuracy: {}".format(test_acc))
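For completeness, a minimal sketch of what a Test-style evaluation could look like. The name test_sketch and its body are assumptions (the real method lives in utils/train.py), including the guess that the save argument points at the CIFAR-10 data directory:

# Hypothetical sketch of a Test-style evaluation; not the project's real code
def test_sketch(model, save):
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    testset = torchvision.datasets.CIFAR10(root=save, train=False, download=True,
                                           transform=transforms.ToTensor())
    testloader = DataLoader(testset, batch_size=256, shuffle=False)

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            correct += (model(inputs).argmax(1) == targets).sum().item()
            total += targets.size(0)
    return correct / total  # top-1 accuracy on the test split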