mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-30 10:18:50 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			67 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import List
 | |
| 
 | |
| import altair as alt
 | |
| import numpy as np
 | |
| 
 | |
| from labml import analytics
 | |
| from labml.analytics import IndicatorCollection
 | |
| 
 | |
| 
 | |
def calculate_percentages(means: List[np.ndarray], names: List[List[str]]):
    """Express every series as a fraction of its group's element-wise sum.

    Two series belong to the same group when the last component of their
    names is equal after dropping that component's final character.
    (Presumably the final character identifies an action within an
    information set -- confirm against the code that names the indicators.)

    Args:
        means: One array of values per series.  Arrays that fall in the same
            group must be broadcast-compatible.
        names: One name (a list of string components) per series, in the same
            order as ``means``.

    Returns:
        A list with one array per series: ``means[i]`` divided by the sum of
        all arrays in its group.  Machine epsilon is added to the divisor so
        an all-zero group yields zeros instead of a division-by-zero.
    """
    # Group key: last name component without its final character.
    keys = [name[-1][:-1] for name in names]

    # Accumulate each group's total once, rather than re-scanning every
    # series for every series.  A fresh zero array is used as the
    # accumulator so the caller's arrays are never mutated.
    totals = {}
    for key, mean in zip(keys, means):
        if key not in totals:
            totals[key] = np.zeros_like(mean)
        totals[key] += mean

    eps = np.finfo(float).eps
    return [mean / (totals[key] + eps) for key, mean in zip(keys, means)]
 | |
| 
 | |
| 
 | |
def plot_infosets(indicators: IndicatorCollection, *,
                  is_normalize: bool = True,
                  width: int = 600,
                  height: int = 300):
    """Build an interactive Altair line chart with one line per indicator.

    Args:
        indicators: Collection passed to ``analytics.indicator_data``.
        is_normalize: When true, the plotted values are first passed through
            ``calculate_percentages``; otherwise they are plotted as-is.
        width: Chart width passed to ``Chart.properties``.
        height: Chart height passed to ``Chart.properties``.

    Returns:
        An Altair chart.  Clicking entries in the legend toggles which
        series are drawn opaque; unselected series are almost transparent.
    """
    data, names = analytics.indicator_data(indicators)
    # Column 0 supplies the x value and column 5 the plotted value.
    # NOTE(review): column 5 is presumably the mean -- confirm against
    # the layout returned by `analytics.indicator_data`.
    steps = [d[:, 0] for d in data]
    means = [d[:, 5] for d in data]

    values = calculate_percentages(means, names) if is_normalize else means

    # Length of the character prefix shared by the last component of every
    # name; it is stripped so the legend only shows the distinguishing part.
    # `zip` stops at the shortest label and yields nothing for an empty
    # collection, so no special-casing is needed.
    # NOTE: with a single series the whole label is "common", so its legend
    # entry is the empty string.
    labels = [n[-1] for n in names]
    prefix_len = 0
    for chars in zip(*labels):
        if any(c != chars[0] for c in chars):
            break
        prefix_len += 1

    rows = []
    for label, series_steps, series_values in zip(labels, steps, values):
        series = label[prefix_len:]
        rows.extend(
            {'series': series, 'step': s, 'value': v}
            for s, v in zip(series_steps, series_values)
        )

    table = alt.Data(values=rows)

    # Altair 5 replaced `selection_multi` / `add_selection` with
    # `selection_point` / `add_params`; use whichever pair is available.
    has_params_api = hasattr(alt, 'selection_point')
    if has_params_api:
        selection = alt.selection_point(fields=['series'], bind='legend')
    else:
        selection = alt.selection_multi(fields=['series'], bind='legend')

    chart = alt.Chart(table).mark_line().encode(
        alt.X('step:Q'),
        alt.Y('value:Q'),
        alt.Color('series:N', scale=alt.Scale(scheme='tableau20')),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.0001))
    )
    if has_params_api:
        chart = chart.add_params(selection)
    else:
        chart = chart.add_selection(selection)

    return chart.properties(width=width, height=height)
 | 
