# Source code for rarity.visualizers.miss_predictions

# Copyright 2021 AI Singapore. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import dash_table


def plot_probabilities_spread_pattern(df_specific_label: pd.DataFrame):
    '''
    Display a scatter plot comparing the predicted probabilities of correctly predicted data points
    against miss-predicted data points, for a single class label

    Arguments:
        df_specific_label (:obj:`~pd.DataFrame`):
            dataframe of 1 specific label of 1 model type, output from int_miss_predictions.
            The second column is treated as the probability column to plot; a ``model`` column
            (only its first value is read) and a ``pred_state`` column are also required.

    Returns:
        :obj:`~plotly.graph_objects.Figure`:
            scatter plot of probability (y axis, fixed to [0, 1]) against dataframe index (x axis),
            coloured by ``pred_state`` ("correct" / "miss-predict"), with a dotted line at y = 0.5
    '''
    # The probability column is picked positionally (2nd column), not by name.
    label = list(df_specific_label.columns)[1]
    model_name = df_specific_label['model'].values[0]

    fig = px.scatter(
        df_specific_label,
        x=list(df_specific_label.index),
        y=df_specific_label[label],
        color='pred_state',
        category_orders={"pred_state": ["correct", "miss-predict"]},
        color_discrete_sequence=px.colors.qualitative.D3,
    )

    layout = dict(
        title=f'<b>Class {label}<br>[ {model_name} ]</b><br>',
        title_x=0.6,
        yaxis_title=f"Probability is_class_{label} ",
        yaxis_showgrid=False,
        xaxis_title="data_point index",
        xaxis_showgrid=False,
        legend_title="",
        legend=dict(orientation="h", yanchor="bottom", y=1.03, xanchor="right", x=0.8),
        width=250,
        height=600,
        margin=dict(t=170, b=0, l=12, r=12, pad=10),
    )
    fig.update_layout(**layout)
    fig.update_yaxes(range=[0, 1])
    fig.update_xaxes(rangemode="tozero")
    fig.add_hline(y=0.5, line_dash="dot")

    # Fixed marker colour + hover text per prediction state. Colours are forced explicitly so a figure
    # that contains only one of the two states still uses the same colour for it as every other class.
    state_styles = {
        'correct': (
            '#1f77b4',
            "<b>Index %{x}</b><br>" "<b>[ correct ]</b><br><br>" "probability: %{y:.4f}<br>" "<extra></extra>",
        ),
        'miss-predict': (
            '#FF7F0E',
            "<b>Index %{x}</b><br>" "<b>[ miss-predict ]</b><br><br>" "probability: %{y:.4f}<br>" "<extra></extra>",
        ),
    }
    for trace in fig.data:
        style = state_styles.get(trace['legendgroup'])
        if style is None:
            continue
        marker_colour, hover_template = style
        trace['marker']['color'] = marker_colour
        trace['hovertemplate'] = hover_template
    return fig
def plot_simple_probs_spread_overview(df_label_state: pd.DataFrame):
    '''
    Display data table listing simple stats on ss, % correct, % wrong, accuracy for each label class

    Arguments:
        df_label_state (:obj:`~pd.DataFrame`):
            dataframe containing info on simple stats, output from int_miss_predictions.
            Every column becomes a table column and every row a table row.

    Returns:
        :obj:`~dash_table.DataTable`:
            headerless table object outlining simple stats on ss, % correct, % wrong, accuracy for each label class
    '''
    fig = dash_table.DataTable(
        id='table',
        columns=[{'id': c, 'name': c} for c in df_label_state.columns],
        style_cell={'font-family': 'verdana', 'font-size': '14px', 'border': 'none', 'minWidth': '100px'},
        # header row is hidden; only the data rows are rendered
        style_header={'display': 'none'},
        # CSS lengths need a unit: string style values are passed to the browser unchanged, and a bare
        # '550' is an invalid width that gets ignored (which also defeats the 'margin: auto' centering)
        style_table={'width': '550px', 'margin': 'auto'},
        style_data={'lineHeight': '15px'},
        # left-align the 'index' column, right-align the 'state_value' column
        style_data_conditional=[{'if': {'column_id': 'index'}, 'textAlign': 'left'},
                                {'if': {'column_id': 'state_value'}, 'textAlign': 'right'}],
        data=df_label_state.to_dict('records'))
    return fig
def plot_prediction_offset_overview(df: pd.DataFrame):
    '''
    Display scatter plot for overview on prediction offset values

    The input dataframe is left unmodified: the helper ``index`` column used for plotting is added
    to a copy, so the function can be called repeatedly with the same dataframe.

    Arguments:
        df (:obj:`~pd.DataFrame`):
            dataframe containing calculated offset values, output from int_miss_predictions.
            Columns containing ``yPred_`` name the prediction columns (their suffix is used as the legend
            name) and columns containing ``offset_`` hold the offsets to plot; at most the first two
            offset columns are drawn.

    Returns:
        :obj:`~plotly.graph_objects.Figure`:
            figure displaying scatter plot outlining overview on prediction offset values by index
    '''
    pred_cols = [col for col in df.columns if 'yPred_' in col]
    offset_cols = [col for col in df.columns if 'offset_' in col]
    corrected_legend_names = [col.replace('yPred_', '') for col in pred_cols]

    # Work on a copy: inserting into the caller's dataframe would mutate it, and a second call with the
    # same dataframe would then fail with "cannot insert index, already exists".
    df_plot = df.copy()
    df_plot.insert(0, 'index', list(df_plot.index))

    # the datapoint index is also attached as custom_data, presumably so click/select callbacks can read it
    fig = px.scatter(df_plot, x='index', y=offset_cols[0], custom_data=['index'],
                     color_discrete_sequence=px.colors.qualitative.D3)
    fig.data[0].name = corrected_legend_names[0]
    fig.update_traces(showlegend=True, hovertemplate="Data Index : %{x}<br>Prediction Offset : %{y}")

    if len(pred_cols) > 1:  # a second prediction column is present: add its offsets as another trace
        fig.add_trace(go.Scatter(
            x=df_plot['index'],
            y=df_plot[offset_cols[1]],
            name=corrected_legend_names[1],
            mode='markers',
            marker=dict(color='#FF7F0E'),
            hovertemplate="Data Index : %{x}<br>Prediction Offset : %{y}"))

    # add reference baseline [mainly to have baseline included in legend]
    fig.add_trace(go.Scatter(
        x=[0, len(df_plot)],
        y=[0] * 2,
        name="Baseline [Prediction - Actual]",
        visible=True,
        hoverinfo='skip',
        mode='lines',
        line=dict(color="green", dash="dot")))
    # reference baseline [mainly for the dotted line in graph, but no legend generated]
    fig.add_hline(y=0, line_dash="dot")

    fig.update_layout(
        title='<b>Prediction Offset Overview by Datapoint Index</b>',
        xaxis_title='Datapoint Index',
        yaxis_title='Offset from baseline',
        title_x=0.5,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        width=1000,
        height=550,
        margin=dict(t=110),
        clickmode='event+select')
    return fig