refracter gat experiment

This commit is contained in:
Varuna Jayasiri
2021-07-27 19:09:18 +05:30
parent 79505c4a89
commit c39e805448
2 changed files with 58 additions and 235 deletions

View File

@ -17,7 +17,7 @@ import torch
from torch import nn
from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
@ -194,26 +194,6 @@ class Configs(BaseConfigs):
# Optimizer
optimizer: torch.optim.Adam
def initialize(self):
"""
Initialize
"""
# Create the dataset
self.dataset = CoraDataset(self.include_edges)
# Get the number of classes
self.n_classes = len(self.dataset.classes)
# Number of features in the input
self.in_features = self.dataset.features.shape[1]
# Create the model
self.model = GAT(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
# Move the model to the device
self.model.to(self.device)
# Configurable optimizer, so that we can set the configurations
# such as learning rate by passing the dictionary later.
optimizer_conf = OptimizerConfigs()
optimizer_conf.parameters = self.model.parameters()
self.optimizer = optimizer_conf
def run(self):
"""
### Training loop
@ -276,6 +256,38 @@ class Configs(BaseConfigs):
tracker.save()
@option(Configs.dataset)
def cora_dataset(c: Configs):
"""
Create Cora dataset
"""
return CoraDataset(c.include_edges)
# Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
# Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
@option(Configs.model)
def gat_model(c: Configs):
"""
Create GAT model
"""
return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
@option(Configs.optimizer)
def _optimizer(c: Configs):
"""
Create configurable optimizer
"""
opt_conf = OptimizerConfigs()
opt_conf.parameters = c.model.parameters()
return opt_conf
def main():
# Create configurations
conf = Configs()
@ -288,8 +300,6 @@ def main():
'optimizer.learning_rate': 5e-3,
'optimizer.weight_decay': 5e-4,
})
# Initialize
conf.initialize()
# Start and watch the experiment
with experiment.start():