Source code for yellowbrick.classifier.class_balance

# yellowbrick.classifier.class_balance
# Class balance visualizer for showing per-class support.
#
# Author:   Rebecca Bilbro <[email protected]>
# Author:   Benjamin Bengfort <[email protected]>
# Author:   Neal Humphrey
# Created:  Wed May 18 12:39:40 2016 -0400
#
# Copyright (C) 2017 District Data Labs
# For license information, see LICENSE.txt
#
# ID: class_balance.py [5388065] [email protected] $

"""
Class balance visualizer for showing per-class support.
"""

##########################################################################
## Imports
##########################################################################

import numpy as np

from .base import ClassificationScoreVisualizer

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support


##########################################################################
## Class Balance Chart
##########################################################################

class ClassBalance(ClassificationScoreVisualizer):
    """
    Class balance chart that shows the support for each class in the
    fitted classification model displayed as a bar plot. It is initialized
    with a fitted model and generates a class balance chart on draw.

    Parameters
    ----------
    ax: axes
        the axis to plot the figure on.

    model: estimator
        Scikit-Learn estimator object. Should be an instance of a classifier,
        else ``__init__()`` will raise an exception.

    classes: list
        A list of class names for the legend.
        If classes is None and a y value is passed to fit then the classes
        are selected from the target vector.

    kwargs: dict
        Keyword arguments passed to the super class. Here, used
        to colorize the bars in the histogram.

    Notes
    -----
    These parameters can be influenced later on in the visualization
    process, but can and should be set as early as possible.
    """

    def score(self, X, y=None, **kwargs):
        """
        Generates the Scikit-Learn precision_recall_fscore_support and
        draws the per-class support.

        The full result tuple is stored on ``self.scores``; its last
        element (the support) is paired positionally with
        ``self.classes_`` and stored as a dict on ``self.support``.

        Parameters
        ----------
        X : ndarray or DataFrame of shape n x m
            A matrix of n instances with m features

        y : ndarray or Series of length n
            An array or series of target or class values

        Returns
        -------
        ax : the axis with the plotted figure
        """
        y_pred = self.predict(X)
        self.scores = precision_recall_fscore_support(y, y_pred)

        # The support is the last element of the returned tuple.
        self.support = dict(zip(self.classes_, self.scores[-1]))
        return self.draw()

    def draw(self):
        """
        Renders the class balance chart across the axis.

        Returns
        -------
        ax : the axis with the plotted figure
        """
        # TODO: Would rather not have to set the colors with this method.
        # Refactor to make better use of yb_palettes module?
        colors = self.colors[0:len(self.classes_)]

        # dict.values() is a view on Python 3, not a sequence; materialize
        # it so the bar heights are passed as a plain list.
        heights = list(self.support.values())

        self.ax.bar(
            np.arange(len(self.support)), heights,
            color=colors, align='center', width=0.5
        )

        return self.ax

    def finalize(self, **kwargs):
        """
        Finalize executes any subclass-specific axes finalization steps.
        The user calls poof and poof calls finalize.

        Parameters
        ----------
        kwargs: generic keyword arguments.
        """
        # Set the title
        self.set_title('Class Balance for {}'.format(self.name))

        # Set the x ticks with the class names; the keys view is turned
        # into a list so the labels are passed as a real sequence.
        self.ax.set_xticks(np.arange(len(self.support)))
        self.ax.set_xticklabels(list(self.support.keys()))

        # Compute the ceiling for the y limit: 10% headroom above the
        # tallest bar.
        cmax = max(self.support.values())
        self.ax.set_ylim(0, cmax + cmax * 0.1)
def class_balance(model, X, y=None, ax=None, classes=None, **kwargs):
    """Quick method:

    Displays the support for each class in the fitted classification model
    displayed as a bar plot.

    This helper function is a quick wrapper to utilize the ClassBalance
    ScoreVisualizer for one-off analysis.

    Parameters
    ----------
    model : the Scikit-Learn estimator (should be a classifier)

    X : ndarray or DataFrame of shape n x m
        A matrix of n instances with m features.

    y : ndarray or Series of length n
        An array or series of target or class values.

    ax : matplotlib axes
        The axes to plot the figure on.

    classes : list of strings
        The names of the classes in the target

    kwargs : dict
        Extra keyword arguments; forwarded both to the ClassBalance
        constructor and to its ``fit`` call.

    Returns
    -------
    ax : matplotlib axes
        Returns the axes that the class balance plot was drawn on.
    """
    # Build the visualizer around the supplied estimator.
    viz = ClassBalance(model, ax, classes, **kwargs)

    # Hold out 20% of the data for scoring; the rest is used for fitting.
    parts = train_test_split(X, y, test_size=0.2)
    X_train, X_test, y_train, y_test = parts

    # Fit on the training portion, then score on the held-out portion
    # (scoring is what triggers the drawing).
    viz.fit(X_train, y_train, **kwargs)
    viz.score(X_test, y_test)

    # Hand back the axes held by the visualizer.
    return viz.ax