Source code for sfa.plot.si


from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.ticker import FormatStrFormatter, FuncFormatter
from matplotlib import rcParams

# Default typeface for every figure this module draws: Arial via the
# sans-serif family.
# NOTE(review): this mutates matplotlib's process-wide rcParams as a side
# effect of importing the module.
rcParams['font.sans-serif'] = ['Arial']
rcParams['font.family'] = 'sans-serif'

def _select_si_data(df_splo, df_inf, output, min_splo, max_splo, thr_inf):
    """Build the per-source table plotted by :func:`siplot`.

    Returns a DataFrame with a ``'Source'`` column (the source names), the
    columns of ``df_inf`` and a ``'SPLO'`` column. Only sources that satisfy
    all of the following are kept:

    * their SPLO lies in ``[min_splo, max_splo]``;
    * they also appear in ``df_inf``;
    * they are not ``output`` itself;
    * their absolute influence on ``output`` is above ``thr_inf``.
    """
    if min_splo is None:
        min_splo = df_splo.min()
    if max_splo is None:
        max_splo = df_splo.max()

    in_range = (min_splo <= df_splo) & (df_splo <= max_splo)
    splo = pd.DataFrame(df_splo[in_range])
    splo.columns = ['SPLO']
    if output in splo.index:
        splo = splo.drop(output)

    index_common = splo.index.intersection(df_inf.index)
    inf = pd.DataFrame(df_inf.loc[index_common])
    # Keep only non-negligible influences. NaN compares False here and is
    # dropped as well: an unknown influence cannot be drawn as a bar.
    inf = inf[inf[output].abs() > thr_inf]

    df_si = inf.join(splo)
    df_si.index.name = 'Source'
    return df_si.reset_index()


def _zero_as_plain_0(fmt):
    """Return a ``(value, pos)`` tick-format function for ``FuncFormatter``.

    Labels are rendered as ``fmt % value``. Any label that parses to zero
    (e.g. ``'0.000000'`` or ``'-0.00'``) is shown simply as ``'0'``.
    """
    def _format(value, pos=None):
        text = fmt % value
        try:
            is_zero = float(text) == 0
        except ValueError:
            # Labels that are not plain numbers (units, '%', ...) stay as-is.
            is_zero = False
        return '0' if is_zero else text
    return _format


def siplot(df_splo, df_inf, output,
           min_splo=None, max_splo=None,
           thr_inf=1e-10, fmt_inf='%f',
           fig=None, cnt_max=None, ncol=4,
           designated=None,
           color='silver', dcolor='red', zcolor='red',
           alpha=0.7, xfontsize=8, yfontsize=8):
    """Plot each source's influence on ``output``, one subplot per SPLO value.

    Sources are grouped by SPLO. Each group is drawn as a horizontal bar chart
    sorted by influence, with the largest value at the top. Subplots are laid
    out on a grid with ``ncol`` columns.

    Parameters
    ----------
    df_splo : pandas.Series
        SPLO value per node, indexed by node name.
        NOTE(review): presumably the shortest path length from each node to
        ``output``; confirm against callers.
    df_inf : pandas.DataFrame
        Influence values indexed by source name. It must contain a column
        named ``output``.
    output : hashable
        Name of the output node. It is a column of ``df_inf`` and is excluded
        from the sources.
    min_splo, max_splo : scalar, optional
        Inclusive SPLO range to plot. ``None`` means the minimum/maximum of
        ``df_splo``; an explicit 0 is honoured.
    thr_inf : float
        Sources with ``|influence| <= thr_inf`` (or NaN influence) are omitted.
    fmt_inf : str
        %-style format for the x tick labels. Values that format to zero are
        shown as ``'0'``.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into. A new pyplot figure is created if ``None``.
    cnt_max : int, optional
        Number of bar slots per subplot. ``None`` or 0 means the size of the
        largest SPLO group. Larger groups are truncated to the ``cnt_max``
        sources with the largest ``|influence|``.
    ncol : int
        Number of subplot columns.
    designated : collection of str, optional
        Source names to highlight. When given, other bars are drawn in
        ``color``; designated bars and their tick labels are drawn in
        ``dcolor``.
    color, dcolor : color
        Bar colours used only when ``designated`` is given.
    zcolor : color
        Colour of the vertical zero line. The line is drawn only in subplots
        that contain both positive and negative influences.
    alpha : float
        Bar transparency.
    xfontsize, yfontsize : float
        Tick-label font sizes.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn into.

    Raises
    ------
    ValueError
        If no sources remain after the SPLO range and threshold filters.
    """
    df_si = _select_si_data(df_splo, df_inf, output,
                            min_splo, max_splo, thr_inf)

    cnt_splo = Counter(df_si['SPLO'])
    if not cnt_splo:
        raise ValueError(
            "siplot: no sources left to plot for output %r after applying "
            "the SPLO range and influence threshold." % (output,))
    if not cnt_max:
        # Enough bar slots for the largest SPLO group.
        cnt_max = max(cnt_splo.values())
    splos = sorted(cnt_splo)
    nrow = -(-len(splos) // ncol)  # exact integer ceil(len / ncol)

    if fig is None:
        fig = plt.figure()

    gs = gridspec.GridSpec(nrow, ncol)
    yvals = np.arange(1, cnt_max + 1)

    for pos, splo in enumerate(splos):
        row, col = divmod(pos, ncol)
        ax = fig.add_subplot(gs[row, col])

        df_sub = df_si[df_si['SPLO'] == splo]
        if len(df_sub) > cnt_max:
            # More sources than slots: keep the most influential ones rather
            # than failing on the array assignment below.
            df_sub = df_sub.loc[df_sub[output].abs().nlargest(cnt_max).index]
        df_sub = df_sub.sort_values(by=output)

        # Unused slots go to the lowest y positions with zero width and an
        # empty label, so every subplot has cnt_max rows.
        num_empty = cnt_max - len(df_sub)
        influence = np.zeros(cnt_max)
        influence[num_empty:] = df_sub[output].values
        names = [''] * num_empty + df_sub['Source'].tolist()

        # Draw on this axes explicitly (not via pyplot's "current" axes), so a
        # caller-supplied figure that is not pyplot's current one still works.
        bars = ax.barh(yvals, influence, align='center', alpha=alpha)
        ax.set_title('SPLO=%d' % splo)
        ax.set_xlabel('')
        # A formatter (rather than frozen tick labels) keeps labels correct
        # if the x limits change later.
        ax.xaxis.set_major_formatter(FuncFormatter(_zero_as_plain_0(fmt_inf)))
        ax.tick_params(axis='x', which='major', labelsize=xfontsize)
        ax.set_ylabel('')
        ax.yaxis.set_ticks_position('right')
        ax.tick_params(axis='y', which='major', labelsize=yfontsize)
        ax.set_yticks(yvals)
        ax.set_yticklabels(names)

        # A zero line is only informative when bars point both ways.
        if (influence < 0).any() and (influence > 0).any():
            ax.vlines(x=0.0, ymin=0, ymax=cnt_max + 1, color=zcolor)

        ax.set_ylim(0, cnt_max + 1)

        if designated:
            # Grey out every bar, then highlight the designated sources.
            for bar, name in zip(bars, names):
                bar.set_color(dcolor if name in designated else color)
            for label in ax.get_yticklabels():
                if label.get_text() in designated:
                    label.set_color(dcolor)

    return fig