mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
import os
from typing import List

import altair as alt
import numpy as np

from labml import analytics
from labml.analytics import IndicatorCollection
|
|
|
|
|
|
def calculate_percentages(means: List[np.ndarray], names: List[List[str]]):
    """
    Divide each series by the element-wise sum of all series in its group.

    Two series are in the same group when the last component of their names
    is equal once its final character is dropped (e.g. ``'x.a'`` and ``'x.b'``
    share the key ``'x.'``). Series in a group are summed element-wise, so
    they should have the same shape.

    :param means: one array of values per series
    :param names: name components for each series, parallel to ``means``
    :return: list parallel to ``means``; each array divided element-wise by
        its group total (plus machine epsilon)
    :raises ValueError: if ``means`` and ``names`` have different lengths
    """
    if len(means) != len(names):
        raise ValueError(f'Got {len(means)} series but {len(names)} names')

    keys = [n[-1][:-1] for n in names]

    # Build each group total once instead of re-summing the whole group for
    # every member; members are added in input order, as before.
    totals = {}
    for key, m in zip(keys, means):
        if key not in totals:
            totals[key] = np.zeros_like(m)
        totals[key] += m

    # eps avoids division by zero for groups whose total is zero.
    eps = np.finfo(float).eps
    return [m / (totals[key] + eps) for key, m in zip(keys, means)]
|
|
|
|
|
|
def _label_prefix(labels: List[str]) -> str:
    """Return the prefix shared by all ``labels``, shortened so no label is stripped to ''."""
    if not labels:
        return ''
    prefix = os.path.commonprefix(labels)
    # With a single series (or a label equal to the common prefix) the full
    # prefix would leave an empty legend entry; keep at least one character
    # of the shortest label.
    return prefix[:max(min(map(len, labels)) - 1, 0)]


def plot_infosets(indicators: IndicatorCollection, *,
                  is_normalize: bool = True,
                  width: int = 600,
                  height: int = 300):
    """
    Plot one line per indicator series against its step column.

    :param indicators: indicators loaded through ``analytics.indicator_data``
    :param is_normalize: if ``True``, divide each series by its group total
        (see ``calculate_percentages``) before plotting
    :param width: chart width
    :param height: chart height
    :return: an Altair line chart; clicking legend entries selects series
    """
    data, names = analytics.indicator_data(indicators)
    # Column 0 is used as the x-axis step and column 5 as the plotted value.
    # NOTE(review): column meanings follow labml's indicator_data layout — confirm.
    steps = [d[:, 0] for d in data]
    means = [d[:, 5] for d in data]

    values = calculate_percentages(means, names) if is_normalize else means

    # Legend labels are the last name component with the shared prefix removed.
    labels = [n[-1] for n in names]
    prefix = _label_prefix(labels)

    table = []
    for label, step, value in zip(labels, steps, values):
        series = label[len(prefix):]
        for s, v in zip(step, value):
            # float() turns numpy scalars into plain Python floats so the
            # chart spec stays JSON-serializable.
            table.append({'series': series, 'step': float(s), 'value': float(v)})

    # NOTE(review): selection_multi/add_selection are deprecated in Altair 5
    # (selection_point/add_params); kept for compatibility with older Altair.
    selection = alt.selection_multi(fields=['series'], bind='legend')

    # Series outside the legend selection are drawn almost transparent.
    return alt.Chart(alt.Data(values=table)).mark_line().encode(
        alt.X('step:Q'),
        alt.Y('value:Q'),
        alt.Color('series:N', scale=alt.Scale(scheme='tableau20')),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.0001))
    ).add_selection(
        selection
    ).properties(width=width, height=height)
|