import dataclasses
from abc import ABC
from typing import Any

import torch
from labml import tracker


class StateModule:
    """Base class for modules whose per-epoch state lives in a separate object.

    The intended protocol, as used by subclasses such as ``Accuracy``:
    ``create_state`` builds a state object, ``set_state`` attaches it, and the
    epoch hooks reset / report it. Every method other than ``__init__`` raises
    ``NotImplementedError`` here and must be overridden.
    """

    def __init__(self):
        pass

    def __call__(self):
        raise NotImplementedError

    def create_state(self) -> Any:
        """Return a new state object for this module."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """Attach ``data`` (presumably produced by ``create_state``) to this module."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook for the start of an epoch (presumably invoked by the training loop)."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook for the end of an epoch (presumably invoked by the training loop)."""
        raise NotImplementedError


class Metric(StateModule, ABC):
    """A ``StateModule`` that can publish its accumulated value.

    ``track`` does nothing by default; concrete metrics override it.
    """

    def track(self):
        """Report the current metric value. The base version is a no-op."""


@dataclasses.dataclass
class AccuracyState:
    """Running accuracy counters: predictions counted and predictions correct."""

    samples: int = 0  # number of predictions counted so far
    correct: int = 0  # how many of those matched their target

    def reset(self):
        """Set both counters back to zero in place."""
        self.samples, self.correct = 0, 0


class Accuracy(Metric):
    """Classification accuracy accumulated across calls into ``self.data``.

    Each call takes class scores, predicts the ``argmax`` over the last
    dimension, and adds the number of correct predictions and of counted
    positions to the attached ``AccuracyState``. Positions whose target equals
    ``ignore_index`` are excluded from both counts.
    """

    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        """
        Args:
            ignore_index: target value marking positions that are not scored.
        """
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Accumulate accuracy counts for one batch.

        Args:
            output: scores with classes on the last dimension; all leading
                dimensions are flattened into rows.
            target: class indices, flattened to one entry per row of
                ``output`` (assumes matching element counts; not checked).
        """
        # ``reshape`` instead of ``view``: it also accepts non-contiguous
        # tensors (e.g. after ``transpose``/``permute``) and is a plain view
        # whenever ``view`` would have succeeded.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        pred = output.argmax(dim=-1)
        # Only positions not marked with ``ignore_index`` take part.
        valid = target != self.ignore_index
        self.data.correct += (pred.eq(target) & valid).sum().item()
        self.data.samples += valid.sum().item()

    def create_state(self):
        """Return a fresh, zeroed ``AccuracyState``."""
        return AccuracyState()

    def set_state(self, data: Any):
        """Attach the state object that ``__call__`` and ``track`` update/read."""
        self.data = data

    def on_epoch_start(self):
        """Zero the counters."""
        self.data.reset()

    def on_epoch_end(self):
        """Report the accumulated accuracy via ``track``."""
        self.track()

    def track(self):
        """Send ``correct / samples`` to the labml tracker; skip if nothing was counted."""
        if self.data.samples == 0:
            return
        # NOTE(review): the key keeps its trailing "." as before; presumably a
        # labml tracker naming convention — confirm before changing it.
        tracker.add("accuracy.", self.data.correct / self.data.samples)


class AccuracyDirect(Accuracy):
    """Accuracy for outputs that already hold predicted class indices.

    No ``argmax`` is taken: ``output`` is compared element-wise with
    ``target``. As in ``Accuracy``, positions whose target equals the
    constructor's ``ignore_index`` are excluded from both counts.
    """

    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Accumulate accuracy counts for one batch of predictions.

        Args:
            output: predicted class indices, any shape (flattened).
            target: true class indices with the same number of elements
                as ``output`` (not checked).
        """
        # ``reshape`` also handles non-contiguous tensors, unlike ``view``.
        output = output.reshape(-1)
        target = target.reshape(-1)
        # Respect the inherited ``ignore_index`` the same way the parent does.
        valid = target != self.ignore_index
        self.data.correct += (output.eq(target) & valid).sum().item()
        self.data.samples += valid.sum().item()