Files
Varuna Jayasiri 07065dea92 CFR (#60)
2021-06-21 17:04:20 +05:30

67 lines
1.8 KiB
Python

from typing import List
import altair as alt
import numpy as np
from labml import analytics
from labml.analytics import IndicatorCollection
def calculate_percentages(means: List[np.ndarray], names: List[List[str]]):
normalized = []
for i in range(len(means)):
total = np.zeros_like(means[i])
for j, n in enumerate(names):
if n[-1][:-1] == names[i][-1][:-1]:
total += means[j]
normalized.append(means[i] / (total + np.finfo(float).eps))
return normalized
def plot_infosets(indicators: IndicatorCollection, *,
is_normalize: bool = True,
width: int = 600,
height: int = 300):
data, names = analytics.indicator_data(indicators)
step = [d[:, 0] for d in data]
means = [d[:, 5] for d in data]
if is_normalize:
normalized = calculate_percentages(means, names)
else:
normalized = means
common = names[0][-1]
for i, n in enumerate(names):
n = n[-1]
if len(n) < len(common):
common = common[:len(n)]
for j in range(len(common)):
if common[j] != n[j]:
common = common[:j]
break
table = []
for i, n in enumerate(names):
for j, v in zip(step[i], normalized[i]):
table.append({
'series': n[-1][len(common):],
'step': j,
'value': v
})
table = alt.Data(values=table)
selection = alt.selection_multi(fields=['series'], bind='legend')
return alt.Chart(table).mark_line().encode(
alt.X('step:Q'),
alt.Y('value:Q'),
alt.Color('series:N', scale=alt.Scale(scheme='tableau20')),
opacity=alt.condition(selection, alt.value(1), alt.value(0.0001))
).add_selection(
selection
).properties(width=width, height=height)