Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 08:41:23 +08:00)
Refactor GAT experiment
@@ -17,7 +17,7 @@ import torch
 from torch import nn
 
 from labml import lab, monit, tracker, experiment
-from labml.configs import BaseConfigs
+from labml.configs import BaseConfigs, option, calculate
 from labml.utils import download
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
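
The newly imported `option` and `calculate` are labml's config helpers. Below is a minimal sketch of how they are used, with made-up names and values purely for illustration; it assumes, as the hunks further down rely on, that labml resolves such attributes lazily, running the registered function the first time the value is needed:

    from labml.configs import BaseConfigs, option, calculate


    class DemoConfigs(BaseConfigs):
        # A plain value with a default
        n_hidden: int = 64
        # These two have no default; they are produced by the functions below
        n_layers: int
        total_units: int


    # `@option` registers a function that computes a config value on demand
    @option(DemoConfigs.n_layers)
    def _n_layers(c: DemoConfigs):
        return 2


    # `calculate` is the shorthand form used later in this diff
    calculate(DemoConfigs.total_units, lambda c: c.n_hidden * c.n_layers)


    conf = DemoConfigs()
    # Reading conf.total_units is expected to invoke the lambda above, which in
    # turn reads conf.n_layers and runs _n_layers (lazy, dependency-ordered).

In the refactored experiment these registered functions are expected to fire the first time `run()` touches `self.dataset`, `self.model` and `self.optimizer` after `experiment.configs(...)` has been called.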
@@ -194,26 +194,6 @@ class Configs(BaseConfigs):
     # Optimizer
     optimizer: torch.optim.Adam
 
-    def initialize(self):
-        """
-        Initialize
-        """
-        # Create the dataset
-        self.dataset = CoraDataset(self.include_edges)
-        # Get the number of classes
-        self.n_classes = len(self.dataset.classes)
-        # Number of features in the input
-        self.in_features = self.dataset.features.shape[1]
-        # Create the model
-        self.model = GAT(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
-        # Move the model to the device
-        self.model.to(self.device)
-        # Configurable optimizer, so that we can set the configurations
-        # such as learning rate by passing the dictionary later.
-        optimizer_conf = OptimizerConfigs()
-        optimizer_conf.parameters = self.model.parameters()
-        self.optimizer = optimizer_conf
-
     def run(self):
         """
         ### Training loop
@@ -276,6 +256,38 @@ class Configs(BaseConfigs):
             tracker.save()
 
 
+@option(Configs.dataset)
+def cora_dataset(c: Configs):
+    """
+    Create Cora dataset
+    """
+    return CoraDataset(c.include_edges)
+
+
+# Get the number of classes
+calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
+# Number of features in the input
+calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
+
+
+@option(Configs.model)
+def gat_model(c: Configs):
+    """
+    Create GAT model
+    """
+    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
+
+
+@option(Configs.optimizer)
+def _optimizer(c: Configs):
+    """
+    Create configurable optimizer
+    """
+    opt_conf = OptimizerConfigs()
+    opt_conf.parameters = c.model.parameters()
+    return opt_conf
+
+
 def main():
     # Create configurations
     conf = Configs()
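
For comparison, the `calculate(...)` calls above appear to be a shorthand for registering a decorated option; a hedged sketch of the equivalent decorator form for `n_classes`, relying on `Configs` and `option` already defined in this file:

    # Hypothetical equivalent of calculate(Configs.n_classes, ...) above,
    # written with the @option decorator instead of the lambda shorthand.
    @option(Configs.n_classes)
    def _n_classes(c: Configs):
        return len(c.dataset.classes)

Either way, asking for `c.model` pulls in `c.in_features` and `c.n_classes`, which in turn build the dataset, so nothing needs to be constructed eagerly.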
@@ -288,8 +300,6 @@ def main():
         'optimizer.learning_rate': 5e-3,
         'optimizer.weight_decay': 5e-4,
     })
-    # Initialize
-    conf.initialize()
 
     # Start and watch the experiment
     with experiment.start():
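
Putting the hunks together, `main()` after this change would look roughly as below. This is a sketch assembled from the context lines above; the experiment name, the `experiment.create` call, the `'optimizer.optimizer'` key, and the final `conf.run()` are assumptions not shown in this diff:

    def main():
        # Create configurations
        conf = Configs()
        # Create an experiment (name assumed, not shown in the diff)
        experiment.create(name='gat')
        # Override nested optimizer settings; keys with an 'optimizer.' prefix
        # address the OptimizerConfigs returned by _optimizer above.
        experiment.configs(conf, {
            'optimizer.optimizer': 'Adam',  # assumed, not in the diff context
            'optimizer.learning_rate': 5e-3,
            'optimizer.weight_decay': 5e-4,
        })
        # No conf.initialize() any more: dataset, model and optimizer are
        # resolved lazily when run() first accesses them.
        with experiment.start():
            conf.run()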