# Source code for rarity.features.feat_general_metrics

# Copyright 2021 AI Singapore. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import dash_core_components as dcc
import dash_html_components as html
import dash_bootstrap_components as dbc

from rarity.data_loader import CSVDataLoader, DataframeLoader
from rarity.interpreters.structured_data import IntGeneralMetrics
from rarity.visualizers import general_metrics as viz_general
from rarity.utils.common_functions import is_regression, is_classification


def fig_confusion_matrix(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Create confusion matrix

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`List[~plotly.graph_objects.Figure]`: list of figures displaying confusion matrix details
    """
    # 'confMat' selects the confusion-matrix transform, which yields labels,
    # predictions and the model names to plot.
    y_true, y_pred, model_names = IntGeneralMetrics(data_loader, 'confMat').xform()
    # The visualizer returns a sequence of figures; callers index into it
    # (e.g. GeneralMetrics.show uses len() and [0]/[1]).
    fig_objs = viz_general.plot_confusion_matrix(y_true, y_pred, model_names)
    return fig_objs
def fig_classification_report(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Build the classification report rendered as table figures

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`List[~plotly.graph_objects.Figure]`: list of tables displaying classification report details
    """
    interpreter = IntGeneralMetrics(data_loader, 'classRpt')
    actual, predicted, models = interpreter.xform()
    return viz_general.plot_classification_report(actual, predicted, models)
def fig_roc_curve(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Plot ROC curves so that several models can be compared side by side

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`~plotly.graph_objects.Figure`: figure displaying line curves comparing roc-auc score for various models
    """
    interpreter = IntGeneralMetrics(data_loader, 'rocAuc')
    actual, predicted, models = interpreter.xform()
    return viz_general.plot_roc_curve(actual, predicted, models)
def fig_precisionRecall_curve(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Plot precision-recall curves so that several models can be compared side by side

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`~plotly.graph_objects.Figure`: figure displaying line curves comparing precision-recall for various models
    """
    interpreter = IntGeneralMetrics(data_loader, 'precRecall')
    actual, predicted, models = interpreter.xform()
    return viz_general.plot_precisionRecall_curve(actual, predicted, models)
def fig_prediction_actual_comparison(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Plot a scatter chart of actual values (yTrue) against predicted values (yPred)

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`~plotly.graph_objects.Figure`: figure displaying scatter plot comparing actual values vs prediction values
    """
    # No visualization key is passed here; the interpreter's default transform is used.
    transformed = IntGeneralMetrics(data_loader).xform()
    return viz_general.plot_prediction_vs_actual(transformed)
def fig_prediction_offset_overview(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Plot a scatter chart giving an overview of prediction offset values

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`~plotly.graph_objects.Figure`: figure displaying scatter plot outlining overview on prediction offset values
    """
    # No visualization key is passed here; the interpreter's default transform is used.
    transformed = IntGeneralMetrics(data_loader).xform()
    return viz_general.plot_prediction_offset_overview(transformed)
def fig_standard_error_metrics(data_loader: Union[CSVDataLoader, DataframeLoader]):
    """
    Build a table comparing standard error metrics for a regression task

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Returns:
        :obj:`~dash_table.DataTable`: table object comparing various standard metrics for regression task
    """
    metrics_df = IntGeneralMetrics(data_loader, 'stdErr').xform()
    return viz_general.plot_std_error_metrics(metrics_df)
class GeneralMetrics:
    '''
    Main integration for feature component on General Metrics.

        - On Regression: ``Prediction vs Actual``, ``Prediction vs Offset`` and a standard error metrics table
        - On Classification: ``Confusion Matrix``, ``Classification Report``, ``ROC``, ``Precision-Recall``

    Arguments:
        data_loader (:class:`~rarity.data_loader.CSVDataLoader` or :class:`~rarity.data_loader.DataframeLoader`):
            Class object from data_loader module

    Call :meth:`show` to build the layout. It returns a
    :obj:`~dash_bootstrap_components.Container` of styled dash components
    displaying graph and/or table objects.
    '''

    def __init__(self, data_loader: Union[CSVDataLoader, DataframeLoader]):
        self.data_loader = data_loader
        self.analysis_type = data_loader.get_analysis_type()

    def show(self):
        '''
        Build the General Metrics layout for the loader's analysis type.

        The generated figures are also stored on the instance
        (``pred_actual``, ``pred_offset``, ``std_error_metrics`` for regression;
        ``conf_matrix``, ``cls_report``, ``roc``, ``prec_recall`` for classification).

        Returns:
            :obj:`~dash_bootstrap_components.Container`: styled dash components displaying graph and/or table objects

        Raises:
            ValueError: if the analysis type is neither regression nor classification,
                or if no confusion matrix figures were produced for a classification task.
        '''
        if is_regression(self.analysis_type):
            return self._show_regression()
        if is_classification(self.analysis_type):
            return self._show_classification()
        # Previously this fell through to `return gen_metrics` and raised an
        # opaque UnboundLocalError; fail with an explicit message instead.
        raise ValueError(
            'GeneralMetrics does not support analysis type: {!r}'.format(self.analysis_type))

    def _show_regression(self):
        # Build the regression layout: two scatter plots side by side above the error-metrics table.
        self.pred_actual = fig_prediction_actual_comparison(self.data_loader)
        self.pred_offset = fig_prediction_offset_overview(self.data_loader)
        self.std_error_metrics = fig_standard_error_metrics(self.data_loader)

        return dbc.Container([
            dbc.Row([
                dbc.Col([
                    dcc.Graph(figure=self.pred_actual),
                ], className="border__common-gen-metrics-reg-left"),
                dbc.Col([
                    dcc.Graph(figure=self.pred_offset),
                ], className="border__common-gen-metrics-reg-right"),
            ]),
            # The error-metrics object is a table component, not a figure, so it is
            # placed directly in a Div rather than wrapped in dcc.Graph.
            html.Div(html.Div(self.std_error_metrics, className="div__std-err")),
            html.Br(),
        ], fluid=True)

    def _show_classification(self):
        # Build the classification layout: per-model confusion matrix / report panels, then ROC and PR curves.
        self.conf_matrix = fig_confusion_matrix(self.data_loader)
        self.cls_report = fig_classification_report(self.data_loader)
        self.roc = fig_roc_curve(self.data_loader)
        self.prec_recall = fig_precisionRecall_curve(self.data_loader)

        if len(self.conf_matrix) == 0:
            # Previously this also ended in an UnboundLocalError on `gen_metrics`.
            raise ValueError('No confusion matrix figures were produced; nothing to display.')

        curve_panels = [
            self._curve_panel(self.roc),
            self._curve_panel(self.prec_recall),
        ]

        if len(self.conf_matrix) > 1:
            # Two-model layout: each column stacks one model's confusion matrix
            # over its classification report. Only the first two models are shown.
            model_row = dbc.Row([
                dbc.Col([
                    dcc.Graph(figure=self.conf_matrix[0]),
                    dcc.Graph(figure=self.cls_report[0]),
                ], className="border__common-left"),
                dbc.Col([
                    dcc.Graph(figure=self.conf_matrix[1]),
                    dcc.Graph(figure=self.cls_report[1]),
                ], className="border__common-right"),
            ])
            return dbc.Container([model_row] + curve_panels, fluid=True)

        # Single-model layout: confusion matrix and report side by side.
        model_row = html.Div(dbc.Row([
            dbc.Col(dcc.Graph(figure=self.conf_matrix[0]), className="border__common-left"),
            dbc.Col(dcc.Graph(figure=self.cls_report[0]), className="border__common-right"),
        ]), className="boundary__common")
        return dbc.Container([model_row] + curve_panels, fluid=True, className="boundary__common")

    @staticmethod
    def _curve_panel(figure):
        # Wrap a ROC / precision-recall figure in the shared bordered panel styling.
        return html.Div(
            html.Div(dcc.Graph(figure=figure), className="fig__roc-prec-recall"),
            className="border__common")