cleanup hook model outputs

This commit is contained in:
Varuna Jayasiri
2025-07-20 09:02:34 +05:30
parent 5bdedcffec
commit a713c92b82
12 changed files with 36 additions and 142 deletions

View File

@ -75,43 +75,6 @@ class Accuracy(Metric):
tracker.add("accuracy.", self.data.correct / self.data.samples)
class AccuracyMovingAvg(Metric):
def __init__(self, ignore_index: int = -1, queue_size: int = 5):
super().__init__()
self.ignore_index = ignore_index
tracker.set_queue('accuracy.*', queue_size, is_print=True)
def __call__(self, output: torch.Tensor, target: torch.Tensor):
output = output.view(-1, output.shape[-1])
target = target.view(-1)
pred = output.argmax(dim=-1)
mask = target == self.ignore_index
pred.masked_fill_(mask, self.ignore_index)
n_masked = mask.sum().item()
if len(target) - n_masked > 0:
tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked))
def create_state(self):
return None
def set_state(self, data: any):
pass
def on_epoch_start(self):
pass
def on_epoch_end(self):
pass
class BinaryAccuracy(Accuracy):
def __call__(self, output: torch.Tensor, target: torch.Tensor):
pred = output.view(-1) > 0
target = target.view(-1)
self.data.correct += pred.eq(target).sum().item()
self.data.samples += len(target)
class AccuracyDirect(Accuracy):
data: AccuracyState