import dataclasses
from abc import ABC
from typing import Any

import torch
from labml import tracker


class StateModule:
    """Base class for modules whose running state lives in a separate object.

    A subclass builds a fresh state object in ``create_state``, is handed a
    state object through ``set_state``, and reacts to epoch boundaries in
    ``on_epoch_start`` / ``on_epoch_end``. Every hook below raises
    ``NotImplementedError`` here and must be overridden by subclasses.
    """

    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        """Process one batch; must be overridden.

        Accepts arbitrary arguments so the base signature is compatible with
        subclass overrides such as ``Accuracy.__call__(output, target)``.
        """
        raise NotImplementedError

    def create_state(self) -> Any:
        """Return a new state object for this module; must be overridden."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """Attach ``data`` as this module's state; must be overridden.

        ``data`` is presumably an object produced by ``create_state``.
        """
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook run at the start of an epoch; must be overridden."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook run at the end of an epoch; must be overridden."""
        raise NotImplementedError


class Metric(StateModule, ABC):
    """A ``StateModule`` that can publish its accumulated value via ``track``."""

    def track(self):
        """Publish the current metric value; a no-op in this base class."""


@dataclasses.dataclass
class AccuracyState:
    """Running counters used to compute accuracy as ``correct / samples``."""

    # Number of predictions counted so far.
    samples: int = 0
    # How many of the counted predictions matched their target.
    correct: int = 0

    def reset(self):
        """Set both counters back to zero, in place."""
        self.samples = self.correct = 0


class Accuracy(Metric):
    """Classification accuracy accumulated over batches.

    Calling the instance with class scores and integer targets updates the
    counters in ``self.data``. Positions whose target equals ``ignore_index``
    are excluded from both the correct count and the sample count.

    ``self.data`` is not created in ``__init__``: it must be supplied through
    ``set_state`` (e.g. with the object returned by ``create_state``) before
    the instance is called or any epoch hook runs.
    """

    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        """
        Args:
            ignore_index: target value marking positions to skip.
        """
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Add one batch to the running counters.

        Args:
            output: scores whose last dimension indexes classes; all leading
                dimensions are flattened together and the prediction is the
                ``argmax`` over the last dimension.
            target: class indices, flattened; expected to hold one entry per
                row of the flattened ``output``.
        """
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. transposed logits, where view would raise.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        pred = output.argmax(dim=-1)
        # Only positions whose target is not ignore_index are counted.
        valid = target != self.ignore_index
        self.data.correct += (pred.eq(target) & valid).sum().item()
        self.data.samples += valid.sum().item()

    def create_state(self) -> AccuracyState:
        """Return a fresh, zeroed set of counters."""
        return AccuracyState()

    def set_state(self, data: Any):
        """Use ``data`` (an ``AccuracyState``) as the counters to update."""
        self.data = data

    def on_epoch_start(self):
        """Zero the counters at the beginning of an epoch."""
        self.data.reset()

    def on_epoch_end(self):
        """Publish the epoch's accuracy via ``track``."""
        self.track()

    def track(self):
        """Log ``correct / samples`` to the labml tracker if any samples were seen."""
        # Guard against division by zero when nothing was counted.
        if self.data.samples == 0:
            return
        # NOTE(review): the trailing "." in the key looks deliberate (labml
        # naming convention?) — confirm before changing it.
        tracker.add("accuracy.", self.data.correct / self.data.samples)


class AccuracyDirect(Accuracy):
    """Accuracy for outputs that are already predicted class indices.

    No ``argmax`` is taken: ``output`` is compared element-wise with
    ``target``. Positions whose target equals ``ignore_index`` (set by
    ``Accuracy.__init__``) are excluded, consistent with ``Accuracy``.
    """

    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Add one batch of predicted indices to the running counters.

        Args:
            output: predicted class indices, any shape; flattened.
            target: true class indices, flattened; expected to have as many
                elements as ``output``.
        """
        # reshape tolerates non-contiguous tensors where view would raise.
        output = output.reshape(-1)
        target = target.reshape(-1)
        # Skip ignored targets so the inherited ignore_index actually applies.
        valid = target != self.ignore_index
        self.data.correct += (output.eq(target) & valid).sum().item()
        self.data.samples += valid.sum().item()