# Source code for rarity.features.feat_similarities_counter_factuals

# Copyright 2021 AI Singapore. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, List
import pandas as pd

import dash
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
import dash_core_components as dcc
import dash_html_components as html
import dash_bootstrap_components as dbc

from rarity.app import app
from rarity.data_loader import CSVDataLoader, DataframeLoader
from rarity.interpreters.structured_data import IntSimilaritiesCounterFactuals
from rarity.visualizers import shared_viz_component as viz_shared
from rarity.utils import style_configs
from rarity.utils.common_functions import is_regression, is_classification, get_max_value_on_slider, \
                                            detected_bimodal, detected_invalid_index_inputs


def generate_similarities(data_loader: Union[CSVDataLoader, DataframeLoader], user_defined_idx, feature_to_exclude=None,
                          top_n=3):
    '''
    Tapout table collating feature info corresponding to user defined index and top N index based on distance score.
    Applicable to both regression and classification

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module
        user_defined_idx (int):
            Index of the data point of interest specified by user
        feature_to_exclude (:obj:`List[str]`, `optional`):
            A list of features to be excluded from the ranking and similarities distance calculation
        top_n (int):
            Number indicating the max limit of records to be displayed based on the distance ranking

    Returns:
        :obj:`~dash_table.DataTable`: table object outlining the dataframe content with dynamic-conditional styles.
        The first column, ``category``, labels the first row ``User_defined_idx`` and the following rows
        ``Top_1``, ``Top_2``, ...
    '''
    df_viz, idx_for_top_n, calculated_distance = IntSimilaritiesCounterFactuals(data_loader).xform(
        user_defined_idx, feature_to_exclude, top_n)
    df_top_n = _base_df_by_calculated_distance(df_viz, idx_for_top_n, calculated_distance)

    # len(df_top_n) - 1 : the 1st row is the user_defined_idx, every following row is a ranked neighbour
    category_top_n = [f'Top_{i + 1}' for i in range(len(df_top_n) - 1)]
    df_top_n.insert(0, 'category', ['User_defined_idx'] + category_top_n)

    feature_cols = list(data_loader.get_features().columns)
    analysis_type = data_loader.get_analysis_type()

    if is_regression(analysis_type):
        # Skip the first 3 columns ('category', 'index', 'calculated_distance') and round the rest to 2 dp
        # to limit table cells having values with long decimals, for better viz purpose
        value_cols = list(df_top_n.columns)[3:]
        try:
            df_top_n[value_cols] = df_top_n[value_cols].round(2)
        except TypeError:
            # Best effort only: keep the unrounded values rather than failing to render the table
            pass
    elif is_classification(analysis_type):
        for col in df_top_n.columns:
            if 'float' in str(df_top_n[col].dtypes):
                df_top_n[col] = df_top_n[col].apply(lambda x: round(x, 4))

    table_obj = viz_shared.reponsive_table_to_filtered_datapoints_similaritiesCF(
        df_top_n, list(df_top_n.columns), feature_cols, style_configs.DEFAULT_HEADER_STYLE, 'csv')
    return table_obj
def generate_counterfactuals(data_loader: Union[CSVDataLoader, DataframeLoader], user_defined_idx,
                             feature_to_exclude=None, top_n=3):
    '''
    Tapout one table per model collating feature info corresponding to user defined index and top N index based
    on distance score, keeping only index whose true label (``yTrue``) equals that of the user defined index while
    the model's prediction label (``yPred_<model>``) differs from the model's prediction for the user defined index.
    Applicable to classification only

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module
        user_defined_idx (int):
            Index of the data point of interest specified by user
        feature_to_exclude (:obj:`List[str]`, `optional`):
            A list of features to be excluded from the ranking and similarities distance calculation
        top_n (int):
            Number indicating the max limit of counter-factual records to be displayed per model,
            based on the distance ranking

    Returns:
        :obj:`List[~dash_table.DataTable]`: one table object per model, in the order of
        ``data_loader.get_model_list()``, each outlining the dataframe content with dynamic-conditional styles
    '''
    # Ask for a ranking over the full dataset size (not top_n) because rows are filtered by label below
    # before the first top_n survivors are taken.
    org_data_size = len(data_loader.get_features())
    df_viz, idx_sorted_by_distance, calculated_distance = IntSimilaritiesCounterFactuals(data_loader).xform(
        user_defined_idx, feature_to_exclude, org_data_size)
    df_top_n = _base_df_by_calculated_distance(df_viz, idx_sorted_by_distance, calculated_distance)

    # Model-independent values, computed once instead of once per model
    feature_cols = list(data_loader.get_features().columns)
    df_user_idx = df_top_n.loc[df_top_n['index'] == user_defined_idx, :]

    table_objs = []
    for model in data_loader.get_model_list():
        pred_col = f'yPred_{model}'
        usr_pred_label = df_user_idx[pred_col].values[0]
        usr_true_label = df_user_idx['yTrue'].values[0]

        # Same ground truth, different prediction; the user row itself is excluded because its own
        # prediction equals usr_pred_label.
        is_counterfactual = (df_top_n['yTrue'] == usr_true_label) & (df_top_n[pred_col] != usr_pred_label)
        df_filtered_cf = df_top_n[is_counterfactual].head(top_n)
        df_top_n_cf = pd.concat([df_user_idx, df_filtered_cf], axis=0)

        # len(df_top_n_cf) - 1 : the 1st row is the user_defined_idx
        category_top_n = [f'Top_CounterFactual_{i + 1}' for i in range(len(df_top_n_cf) - 1)]
        df_top_n_cf.insert(0, 'category', ['User_defined_idx'] + category_top_n)

        for col in df_top_n_cf.columns:
            if 'float' in str(df_top_n_cf[col].dtypes):
                df_top_n_cf[col] = df_top_n_cf[col].apply(lambda x: round(x, 4))

        table_obj = viz_shared.reponsive_table_to_filtered_datapoints_similaritiesCF(
            df_top_n_cf, list(df_top_n_cf.columns), feature_cols, style_configs.DEFAULT_HEADER_STYLE, 'csv')
        table_objs.append(table_obj)
    return table_objs
def _base_df_by_calculated_distance(df: pd.DataFrame, idx_sorted_by_distance: List[int],
                                    calculated_distance: List[float]):
    '''
    Setup new dataframe storing calculated distance info: one ``index`` / ``calculated_distance`` pair per row,
    in the order given, left-merged with ``df`` on the ``index`` column (so that order is preserved)
    '''
    df_top_n = pd.DataFrame()
    df_top_n['index'] = idx_sorted_by_distance
    df_top_n['calculated_distance'] = calculated_distance
    df_top_n = df_top_n.merge(df, how='left', on='index')
    return df_top_n


def _parse_user_indices(user_defined_idx: Union[int, str]) -> List[int]:
    '''
    Convert a single index or a comma-separated string of indices (e.g. '1, 12, 123') into a list of ints.
    Spaces are ignored and empty entries (e.g. from a trailing comma) are skipped.
    '''
    return [int(idx) for idx in str(user_defined_idx).replace(' ', '').split(',') if idx]


def _counterfactual_section(model, table_obj, table_id: str):
    '''
    Layout rows for one model's counter-factual table: a title row naming the model,
    followed by a centred row holding the table
    '''
    return [
        dbc.Row([html.H5('Counter-Factuals for model : ', id='title-after-topn-bottomn-reg',
                         className='h5__counterfactuals-section-title'),
                 html.H5(f'{model}', className='h5__counterfactual-model')]),
        dbc.Row(html.Div(table_obj, id=table_id, className='div__table-proba-misspred'), justify='center')]


def _table_objs_similarities(data_loader: Union[CSVDataLoader, DataframeLoader], user_defined_idx: Union[int, str],
                             feature_to_exclude: List[str], top_n: int):
    '''
    List collating layouts for similarities table based on user index/indices

    ``user_defined_idx`` may be a single index or a comma-separated string of indices;
    one centred table row is produced per index, in input order.
    '''
    table_objs_similarities = []
    for idx in _parse_user_indices(user_defined_idx):
        table_obj_similarities = generate_similarities(data_loader, idx, feature_to_exclude, top_n)
        table_objs_similarities.append(
            dbc.Row(html.Div(table_obj_similarities, className='div__table-proba-misspred'), justify='center'))
    return table_objs_similarities


def _table_objs_counterfactuals(data_loader: Union[CSVDataLoader, DataframeLoader], user_defined_idx: Union[int, str],
                                feature_to_exclude: List[str], top_n: int):
    '''
    List collating layouts for counterfactual tables based on user index/indices

    ``user_defined_idx`` may be a single index or a comma-separated string of indices. For each index,
    a section is emitted for the first model, plus a second section for the second model when
    ``detected_bimodal`` reports two models.
    '''
    models = data_loader.get_model_list()
    is_bimodal = detected_bimodal(models)

    viz_table_objs_counterfactuals = []
    for idx in _parse_user_indices(user_defined_idx):
        table_objs_cf = generate_counterfactuals(data_loader, idx, feature_to_exclude, top_n)
        # NOTE(review): the table ids below (and the title id, also used by the similarities heading in
        # SimilaritiesCF.show) repeat for every requested index; Dash expects component ids to be unique
        # within a layout — confirm whether per-index ids are needed.
        if is_bimodal:
            viz_table_objs_counterfactuals += (
                _counterfactual_section(models[0], table_objs_cf[0], 'table-obj-similarities-bm-1')
                + [html.Br()]
                + _counterfactual_section(models[1], table_objs_cf[1], 'table-obj-similarities-bm-2'))
        else:
            viz_table_objs_counterfactuals += _counterfactual_section(
                models[0], table_objs_cf[0], 'table-obj-similarities-sm')
    return viz_table_objs_counterfactuals
class SimilaritiesCF:
    '''
    Main integration for feature component on Similarities-CounterFactuals

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Important Attributes:
        analysis_type (str):
            Analysis type defined by user during initial inputs preparation via data_loader stage.
        df_features (:obj:`~pandas.DataFrame`):
            Dataframe storing all features used in dataset
        feature_to_exclude (:obj:`List[str]`, `optional`):
            A list of features to be excluded from the ranking and similarities distance calculation
        user_defined_idx (str):
            Index (or comma-separated indices) of the data point(s) of interest; defaults to '1'
        top_n (int):
            Number indicating the max limit of records to be displayed based on the distance ranking
    '''

    def __init__(self, data_loader: Union[CSVDataLoader, DataframeLoader]):
        self.data_loader = data_loader
        self.analysis_type = data_loader.get_analysis_type()
        self.df_features = data_loader.get_features()
        self.feature_to_exclude = []
        self.user_defined_idx = '1'
        self.top_n = 3
        # Initial similarities tables rendered before the user clicks Update
        self.table_objs_similarities = _table_objs_similarities(self.data_loader, self.user_defined_idx,
                                                                self.feature_to_exclude, self.top_n)

    def show(self):
        '''
        Method to tapout styled html for Similarities-CounterFactuals

        Returns:
            :obj:`~dash_bootstrap_components.Container`: styled dash components displaying graph and/or table objects

        Raises:
            ValueError: if the analysis type is neither regression nor classification
        '''
        options_feature_ls = [{'label': f'{col}', 'value': f'{col}'} for col in self.df_features.columns]
        shared_layout_reg_cls = [
            html.Div([
                dbc.Row([
                    dbc.Col([
                        dbc.Row(html.Div(html.H6('Specify data index ( default index: 1 ) :'),
                                         className='h6__similaritiesCF-index')),
                        dbc.Row(dbc.Input(id='input-specific-idx-similaritiesCF',
                                          placeholder='example: 1 OR 12, 123, 1234',
                                          type='text', value=None)),
                        dbc.Row(html.Div(html.Pre(style_configs.input_range_subnote(self.df_features),
                                                  className='text__range-header-kldiv-featfist')))],
                        width=6),
                    dbc.Col([
                        dbc.Row(html.Div(html.H6('Select feature to exclude from similarities calculation '
                                                 '( if applicable ) :'),
                                         className='h6__feature-to-exclude')),
                        dbc.Row(dbc.Col(dcc.Dropdown(id='select-feature-to-exclude-similaritiesCF',
                                                     options=options_feature_ls,
                                                     value=[], multi=True)))],
                        width=6)]),
                html.Br(),
                dbc.Row([
                    dbc.Col([
                        dbc.Row(html.Div(html.H6('Select number of records to display'),
                                         className='h6__display-top-bottom-n')),
                        dbc.Row(html.Div(html.Span('( ranked by calculated distance on overall feature similarites '
                                                   'referencing to the index defined above ):'),
                                         className='text__display-header-similaritiesCF')),
                        dcc.Slider(id='select-slider-top-n-similaritiesCF',
                                   min=1,
                                   max=get_max_value_on_slider(self.df_features, 'similaritiesCF'),  # max at 10
                                   step=1,
                                   value=3,
                                   marks=style_configs.DEFAULT_SLIDER_RANGE)],
                        width=10),
                    dbc.Col(dbc.Row(
                        dcc.Loading(id='loading-output-similaritiesCF', type='circle', color='#a80202'),
                        justify='left', className='loading__similaritiesCF'),
                        width=1),
                    dbc.Col(dbc.Row(
                        dbc.Button("Update", id='button-similaritiesCF-update', n_clicks=0,
                                   className='button__update-dataset'),
                        justify='right'))])
            ], className='border__select-dataset'),
            html.Div(id='alert-index-input-error-similaritiesCF'),
            html.Br(),
            dbc.Row(html.H5('Comparison based on Feature Similarities', id='title-after-topn-bottomn-reg',
                            className='h5__counterfactuals-section-title'))
        ] + [html.Div(self.table_objs_similarities, id='table-objs-similarities-shared')]

        if is_regression(self.analysis_type):
            # No counter-factual section for regression; the empty Div keeps the callback's
            # 'table-objs-counter-factuals' Output target present in the layout.
            counterfactual_layout = [html.Div(id='table-objs-counter-factuals')]
        elif is_classification(self.analysis_type):
            self.models = self.data_loader.get_model_list()
            self.table_objs_counterfactuals = _table_objs_counterfactuals(self.data_loader, self.user_defined_idx,
                                                                          self.feature_to_exclude, self.top_n)
            counterfactual_layout = [html.Br(),
                                     html.Div(self.table_objs_counterfactuals, id='table-objs-counter-factuals')]
        else:
            # Previously this fell through and failed with an UnboundLocalError
            raise ValueError(f'Unsupported analysis type for Similarities-CounterFactuals: {self.analysis_type}')

        similaritiesCF = dbc.Container(shared_layout_reg_cls + counterfactual_layout, fluid=True)
        return similaritiesCF

    def callbacks(self):
        '''
        Register the Dash callback that rebuilds the similarities (and, for classification,
        counter-factual) tables when the Update button is clicked
        '''
        @app.callback(
            Output('loading-output-similaritiesCF', 'children'),
            Output('alert-index-input-error-similaritiesCF', 'children'),
            Output('table-objs-similarities-shared', 'children'),
            Output('table-objs-counter-factuals', 'children'),
            Input('button-similaritiesCF-update', 'n_clicks'),
            State('input-specific-idx-similaritiesCF', 'value'),
            State('select-feature-to-exclude-similaritiesCF', 'value'),
            State('select-slider-top-n-similaritiesCF', 'value'))
        def generate_table_objs_based_on_user_selected_params(click_count, specific_idx, feature_to_exclude, top_n):
            '''
            Callbacks functionalities on params related to index input, features to exclude and top-n,
            on both regression and classification tasks

            Returns a 4-tuple matching the declared Outputs: loading spinner children, index-input alert,
            similarities tables, counter-factual tables ('' for regression).
            On invalid index input, only the alert is updated.
            '''
            if not click_count:
                raise PreventUpdate

            if specific_idx is None:  # during first spin-up
                specific_idx = self.user_defined_idx  # default value

            idx_input_err_alert = style_configs.no_error_alert()
            if detected_invalid_index_inputs(specific_idx, self.df_features):
                idx_input_err_alert = style_configs.activate_invalid_index_input_alert(self.df_features)
                return '', idx_input_err_alert, dash.no_update, dash.no_update

            if is_regression(self.analysis_type):
                table_objs_similarities = _table_objs_similarities(self.data_loader, specific_idx,
                                                                   feature_to_exclude, top_n)
                return '', idx_input_err_alert, table_objs_similarities, ''
            if is_classification(self.analysis_type):
                table_objs_similarities = _table_objs_similarities(self.data_loader, specific_idx,
                                                                   feature_to_exclude, top_n)
                table_objs_counterfactuals = _table_objs_counterfactuals(self.data_loader, specific_idx,
                                                                         feature_to_exclude, top_n)
                return '', idx_input_err_alert, table_objs_similarities, table_objs_counterfactuals
            # Unsupported analysis type: an implicit None would not satisfy the four declared Outputs
            raise PreventUpdate