Files
2025-07-20 09:02:34 +05:30

86 lines
1.9 KiB
Python

import dataclasses
from abc import ABC
from typing import Any

import torch
from labml import tracker
class StateModule:
    """Base interface for modules whose state is created and attached externally.

    Every hook below raises ``NotImplementedError`` until a subclass
    overrides it. ``__call__`` is not declared here; callable subclasses
    define it with their own signature.
    """

    def __init__(self):
        pass

    def create_state(self) -> Any:
        """Return a new state object for this module.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def set_state(self, data: Any):
        """Attach the state object ``data`` to this module.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook for the start of an epoch.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook for the end of an epoch.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError
class Metric(StateModule, ABC):
    """A state module that can additionally report its value via ``track``."""

    def track(self):
        """Report the metric; this base implementation does nothing."""
        return None
@dataclasses.dataclass
class AccuracyState:
    """Running counters from which an accuracy ratio is derived."""

    # Number of samples counted so far.
    samples: int = 0
    # Number of counted samples whose prediction matched the target.
    correct: int = 0

    def reset(self):
        """Set both counters back to zero."""
        self.samples = self.correct = 0
class Accuracy(Metric):
    """Accuracy of arg-max predictions, accumulated across calls.

    The counters live in an ``AccuracyState`` that must be attached with
    ``set_state`` before the metric is called.

    Args:
        ignore_index: Target value to exclude; positions whose target equals
            it count neither as samples nor as correct predictions.
    """

    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Add one batch of predictions to the running counters.

        Args:
            output: Scores whose last dimension is arg-maxed to obtain the
                prediction; all leading dimensions are flattened.
            target: Expected labels; flattened to one dimension.
        """
        # `reshape` instead of `view`: `view` raises on non-contiguous
        # tensors (e.g. after a transpose); `reshape` copies when it has to.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        pred = output.argmax(dim=-1)

        # Count only the positions that are not ignored. This matches the
        # fill-then-subtract approach but leaves `pred` unmodified.
        keep = target != self.ignore_index
        self.data.correct += (pred.eq(target) & keep).sum().item()
        self.data.samples += keep.sum().item()

    def create_state(self) -> AccuracyState:
        """Return a new, zeroed ``AccuracyState``."""
        return AccuracyState()

    def set_state(self, data: Any):
        """Use ``data`` as the counter storage for this metric."""
        self.data = data

    def on_epoch_start(self):
        """Reset the counters."""
        self.data.reset()

    def on_epoch_end(self):
        """Report the accumulated accuracy."""
        self.track()

    def track(self):
        """Send ``correct / samples`` to the tracker; skipped when no samples were counted."""
        if self.data.samples == 0:
            # Avoid dividing by zero when nothing has been counted.
            return
        tracker.add("accuracy.", self.data.correct / self.data.samples)
class AccuracyDirect(Accuracy):
    """Accuracy for outputs that are compared to the targets element-wise.

    Unlike ``Accuracy``, no arg-max is taken: ``output`` is flattened and
    compared directly with the flattened ``target``.

    NOTE(review): the inherited ``ignore_index`` is not consulted here, so
    every target element counts as a sample. Confirm this is intended.
    """

    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """Add one batch of direct predictions to the running counters.

        Args:
            output: Predicted values; flattened to one dimension.
            target: Expected values; flattened to one dimension.
        """
        # `reshape` instead of `view`: `view` raises on non-contiguous
        # tensors; `reshape` copies when it has to.
        output = output.reshape(-1)
        target = target.reshape(-1)
        self.data.correct += output.eq(target).sum().item()
        self.data.samples += len(target)