mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 22:38:36 +08:00 
			
		
		
		
	group norm cleanup
This commit is contained in:
		@ -56,14 +56,14 @@ class GroupNorm(Module):
 | 
				
			|||||||
        assert self.channels == x.shape[1]
 | 
					        assert self.channels == x.shape[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Reshape into `[batch_size, channels, n]`
 | 
					        # Reshape into `[batch_size, channels, n]`
 | 
				
			||||||
        x = x.view(batch_size, self.groups, self.channels // self.groups, -1)
 | 
					        x = x.view(batch_size, self.groups, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Calculate the mean across first and last dimension;
 | 
					        # Calculate the mean across first and last dimension;
 | 
				
			||||||
        # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
 | 
					        # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
 | 
				
			||||||
        mean = x.mean(dim=[2, 3], keepdims=True)
 | 
					        mean = x.mean(dim=[2], keepdims=True)
 | 
				
			||||||
        # Calculate the squared mean across first and last dimension;
 | 
					        # Calculate the squared mean across first and last dimension;
 | 
				
			||||||
        # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
 | 
					        # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
 | 
				
			||||||
        mean_x2 = (x ** 2).mean(dim=[2, 3], keepdims=True)
 | 
					        mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
 | 
				
			||||||
        # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
 | 
					        # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
 | 
				
			||||||
        var = mean_x2 - mean ** 2
 | 
					        var = mean_x2 - mean ** 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -40,19 +40,23 @@ class Model(Module):
 | 
				
			|||||||
        return self.fc(x)
 | 
					        return self.fc(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@option(CIFAR10Configs.model)
 | 
					class Configs(CIFAR10Configs):
 | 
				
			||||||
def model(c: CIFAR10Configs):
 | 
					    groups: int = 16
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@option(Configs.model)
 | 
				
			||||||
 | 
					def model(c: Configs):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    ### Create model
 | 
					    ### Create model
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return Model().to(c.device)
 | 
					    return Model(c.groups).to(c.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    # Create experiment
 | 
					    # Create experiment
 | 
				
			||||||
    experiment.create(name='cifar10', comment='group norm')
 | 
					    experiment.create(name='cifar10', comment='group norm')
 | 
				
			||||||
    # Create configurations
 | 
					    # Create configurations
 | 
				
			||||||
    conf = CIFAR10Configs()
 | 
					    conf = Configs()
 | 
				
			||||||
    # Load configurations
 | 
					    # Load configurations
 | 
				
			||||||
    experiment.configs(conf, {
 | 
					    experiment.configs(conf, {
 | 
				
			||||||
        'optimizer.optimizer': 'Adam',
 | 
					        'optimizer.optimizer': 'Adam',
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user