In [10]:
%load_ext autoreload
%autoreload 2

from vflow import Vset, build_vset, init_args, dict_to_df, perturbation_stats
from vflow.pipeline import build_graph

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

from sklearn.impute import KNNImputer, SimpleImputer

import numpy as np
import pandas as pd

import ray

pd.options.display.max_rows = 8
np.random.seed(31415)

Feature Importance Stability

In this example, we probe the stability of a random forest's permutation feature importances with respect to data resampling, data preprocessing, and model hyperparameter perturbations.

vflow supports automatic parallelization using Ray. To compute downstream results in parallel with Ray, start Ray and then set is_async=True when constructing a Vset:

In [11]:
ray.init(num_cpus=4)
2022-11-15 12:59:28,038	INFO worker.py:1528 -- Started a local Ray instance.
Out[11]:

Ray

Python version: 3.10.6
Ray version: 2.1.0
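
As a rough analogy for what is_async=True buys you, the sketch below fans independent pipeline combinations out to a pool of workers and gathers the results as they finish. It uses the standard library's concurrent.futures (threads) in place of Ray, and fit_one is a hypothetical stand-in for a model fit; this is not how vflow dispatches work internally.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import product


def fit_one(boot_rep: int, preproc: str) -> tuple:
    """Hypothetical stand-in for fitting one model on one boot/preproc combo."""
    score = boot_rep * 10 + len(preproc)  # placeholder "result"
    return (boot_rep, preproc), score


def run_all(reps: int, preprocs: list, max_workers: int = 4) -> dict:
    combos = list(product(range(reps), preprocs))
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # submit every combination up front; each one runs independently
        futures = [pool.submit(fit_one, rep, name) for rep, name in combos]
        # collect results in completion order, keyed by the combination
        for fut in as_completed(futures):
            key, value = fut.result()
            results[key] = value
    return results


results = run_all(reps=10, preprocs=["mean", "med", "knn"])
print(len(results))  # 30
```

Because the combinations share no state, the order in which they complete does not matter; keying results by the combination is what makes the gathered output well defined.
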

Define stability target

Below, we create one Vset that applies three data preprocessing (missing-value imputation) methods, and another that computes the permutation feature importance via sklearn.inspection.permutation_importance.

In [12]:
from vflow import Vset
from sklearn.impute import KNNImputer, SimpleImputer
from sklearn.inspection import permutation_importance

preproc_list = [SimpleImputer(strategy='mean'),
                SimpleImputer(strategy='median'),
                KNNImputer()]

# create a Vset which varies over the list of preprocessing methods
preproc_set = Vset("preproc", preproc_list, ['mean', 'med', 'knn'], 
                   output_matching=True, lazy=True)

# create the feature importance Vset
feat_imp_set = build_vset('feat_imp', permutation_importance, 
                          n_repeats=3, is_async=True)
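
For intuition about the stability target, permutation importance measures how much a fitted model's score drops when one feature's values are shuffled, breaking that feature's link to the target. The numpy sketch below is a simplified illustration (R² as the score, a least-squares model as a stand-in for the random forest); the function name permutation_importance_sketch is hypothetical, and sklearn's implementation handles more options such as custom scorers and sample weights.

```python
import numpy as np


def permutation_importance_sketch(predict, X, y, n_repeats=3, seed=0):
    """Drop in R^2 when each column of X is shuffled (simplified sketch)."""
    rng = np.random.default_rng(seed)

    def r2(y_true, y_pred):
        ss_res = np.sum((y_true - y_pred) ** 2)
        ss_tot = np.sum((y_true - y_true.mean()) ** 2)
        return 1.0 - ss_res / ss_tot

    baseline = r2(y, predict(X))
    drops = np.empty((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # shuffle column j only, breaking its relationship with y
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - r2(y, predict(X_perm))
    return {"importances_mean": drops.mean(axis=1),
            "importances_std": drops.std(axis=1),
            "importances": drops}


# toy data: only feature 0 drives y
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
y = 2.0 * X[:, 0]
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
result = permutation_importance_sketch(lambda Z: Z @ coef, X, y)
print(np.round(result["importances_mean"], 2))
```

The returned keys mirror the names that show up later in this notebook (importances_mean, importances_std, importances), which is why the output can be split into those three columns.
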

Define model hyperparameter perturbations

We can also specify modeling perturbations, both within a single class of models (hyperparameter perturbations) and across different classes. Here we'll use the helper build_vset to create hyperparameter perturbations for random forest.

In [13]:
from sklearn.ensemble import RandomForestRegressor as RF
from vflow import build_vset

# hyperparameters to try
RF_params = {
    'n_estimators': [100, 300],
    'min_samples_split': [2, 10]
}

# we could instead pass a list of distinct models and corresponding param dicts
RF_set = build_vset('RF', RF, RF_params, is_async=True)
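
Passing a parameter dict to build_vset yields one model per combination of values; with two values for each of two hyperparameters that is 2 × 2 = 4 random forests. The stdlib sketch below shows that Cartesian expansion; it illustrates the idea and is not vflow's actual implementation.

```python
from itertools import product

RF_params = {
    'n_estimators': [100, 300],
    'min_samples_split': [2, 10],
}

# one dict of keyword arguments per point in the Cartesian product
names = list(RF_params)
configs = [dict(zip(names, values)) for values in product(*RF_params.values())]

for cfg in configs:
    print(cfg)
```

Each dict could then be unpacked into the estimator, e.g. RF(**cfg), to build one model per configuration.
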

Define data perturbations

For stability analysis, it is often useful to add data perturbations such as the bootstrap, to assess how results change under resampling of the data. Setting lazy=True when constructing a Vset makes the bootstrap lazy, so data are not resampled until needed:

In [14]:
from sklearn.utils import resample

# create a Vset for bootstrapping from data 10 times
# we use lazy=True so that the data will not be resampled until needed
boot_set = build_vset('boot', resample, reps=10, lazy=True)
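
A minimal sketch of what "lazy" means here, assuming a simple cache-on-first-access design: each bootstrap replicate is drawn only when it is first requested. The LazyBootstrap class is hypothetical and uses the standard library's random module rather than sklearn.utils.resample; vflow's own lazy mechanism may differ.

```python
import random


class LazyBootstrap:
    """Hypothetical sketch: each replicate is drawn only when first accessed."""

    def __init__(self, data, reps, seed=0):
        self.data = list(data)
        self.reps = reps
        self.seed = seed
        self._cache = {}

    def __len__(self):
        return self.reps

    def __getitem__(self, rep):
        if not 0 <= rep < self.reps:
            raise IndexError(rep)
        if rep not in self._cache:
            # seed per replicate so a given rep is reproducible in any access order
            rng = random.Random(self.seed + rep)
            self._cache[rep] = rng.choices(self.data, k=len(self.data))
        return self._cache[rep]


boots = LazyBootstrap(range(100), reps=10)
print(len(boots._cache))  # 0: nothing has been resampled yet
sample = boots[3]
print(len(boots._cache), len(sample))  # 1 100
```

Caching the replicate after the first draw means downstream steps that touch the same bootstrap sample see identical data.
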

Fit all models for all combinations of resampling and preprocessing

Now we can generate our data and fit each of the four random forest models to the 30 combinations of resampled training data and preprocessing functions (10 bootstrap samples × 3 imputers).

In [15]:
X, y = make_regression(n_samples=1000, n_features=3, n_informative=1)

# set roughly 20% of X entries to NaN (indices are drawn with replacement, so some repeat)
i = np.random.randint(X.shape[0], size=round(X.shape[0]*X.shape[1] * 0.2))
j = np.random.randint(X.shape[1], size=i.size)
X[i, j] = np.nan

X_trainval, X_test, y_trainval, y_test = train_test_split(X, y)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval)

X_train, y_train = init_args([X_train, y_train], names=['X_train', 'y_train'])
X_val, y_val = init_args([X_val, y_val], names=['X_val', 'y_val'])
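
The masking step above draws row and column indices independently and with replacement, so some cells are hit more than once and the realized missing fraction comes out somewhat below 20%. If an exact fraction matters, one alternative (not what this notebook ran) is to sample flat indices without replacement; mask_fraction is a hypothetical helper name.

```python
import numpy as np


def mask_fraction(X, frac, seed=0):
    """Return a copy of X with exactly round(frac * X.size) entries set to NaN."""
    rng = np.random.default_rng(seed)
    X = np.array(X, dtype=float)  # copy; NaN requires a float dtype
    n_missing = round(X.size * frac)
    flat = rng.choice(X.size, size=n_missing, replace=False)  # distinct cells
    rows, cols = np.unravel_index(flat, X.shape)
    X[rows, cols] = np.nan
    return X


X_demo = np.zeros((1000, 3))
X_masked = mask_fraction(X_demo, 0.2)
print(np.isnan(X_masked).mean())  # 0.2
```
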
In [16]:
# bootstrap from the training data by calling boot_set
X_trains, y_trains = boot_set(X_train, y_train)

# apply three preprocessing methods to each bootstrap sample
X_trains = preproc_set.fit_transform(X_trains)
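
For intuition about two of the three preprocessing branches: mean and median imputation each replace a column's missing entries with a statistic learned from that column's observed values at fit time. The numpy sketch below uses hypothetical helper names and is not sklearn's SimpleImputer, which also handles edge cases such as columns that are entirely missing.

```python
import numpy as np


def fit_imputer(X, strategy="mean"):
    """Learn one fill value per column, ignoring NaNs."""
    stat = {"mean": np.nanmean, "median": np.nanmedian}[strategy]
    return stat(X, axis=0)


def transform_imputer(X, fill_values):
    """Replace NaNs in each column with that column's learned fill value."""
    X = np.array(X, dtype=float)  # work on a copy
    rows, cols = np.where(np.isnan(X))
    X[rows, cols] = fill_values[cols]
    return X


X = np.array([[1.0, np.nan],
              [2.0, 4.0],
              [np.nan, 5.0],
              [9.0, 6.0]])
print(transform_imputer(X, fit_imputer(X, "mean")))
print(transform_imputer(X, fit_imputer(X, "median")))
```

Because the fill values depend on the data they were fit on, each bootstrap sample yields slightly different imputations, which is one of the sources of variation this notebook is probing.
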
In [17]:
%%time

# this results in fitting all 4 RF models to each of the 30 boot/preproc combos
RF_set.fit(X_trains, y_trains)
CPU times: user 1.19 s, sys: 230 ms, total: 1.42 s
Wall time: 8.57 s
Out[17]:
<vflow.vset.Vset at 0x7f2cd8267040>
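
The bookkeeping behind this step is a Cartesian product: 10 bootstrap samples × 3 imputers gives 30 training sets, and fitting 4 random forest configurations to each gives 120 fitted models, which matches the 120 rows of the importance table further down. A stdlib sketch of that count, with labels loosely modeled on the keys vflow prints:

```python
from itertools import product

boot_reps = [f"rep={r}" for r in range(10)]                  # 10 bootstrap samples
preprocs = ["mean", "med", "knn"]                            # 3 imputers
rf_configs = [(n, m) for n in (100, 300) for m in (2, 10)]   # 4 RF settings

data_paths = list(product(boot_reps, preprocs))
model_fits = list(product(boot_reps, preprocs, rf_configs))
print(len(data_paths), len(model_fits))  # 30 120
```
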

We can examine the pipeline graph to see what happened so far using the utility function build_graph:

In [18]:
from vflow import build_graph

# examine the pipeline graph
build_graph(RF_set)
Out[18]:
<networkx.classes.digraph.DiGraph at 0x7f2d223866b0>
[Image: rendered pipeline graph of RF_set]
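
Conceptually, the pipeline graph records which stage feeds which. A minimal stdlib sketch of such a dependency structure, using graphlib.TopologicalSorter (Python 3.9+) to recover a valid execution order; the node names are illustrative, and this is not the object build_graph returns (that is a networkx DiGraph).

```python
from graphlib import TopologicalSorter

# each stage maps to the set of stages it depends on (hypothetical labels)
pipeline = {
    "boot": {"X_train", "y_train"},
    "preproc": {"boot"},
    "RF": {"preproc", "boot"},  # RF uses imputed X and bootstrapped y
}

order = list(TopologicalSorter(pipeline).static_order())
print(order)
```

The relative order of independent inputs (X_train and y_train) is not fixed, but every stage always appears after everything it depends on.
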

Calculate feature importances and perturbation statistics

Finally, we calculate the importance metric and examine its mean and standard deviation across bootstrap perturbations for each combination of data preprocessing and modeling hyperparameters. This allows us to assess the stability of the feature importances conditioned on different pipeline paths:

In [19]:
%%time

from vflow import dict_to_df, perturbation_stats

# calculate importances
importances = feat_imp_set.evaluate(RF_set.fitted_vfuncs, preproc_set.fit_transform(X_val), y_val)
CPU times: user 1.58 s, sys: 0 ns, total: 1.58 s
Wall time: 4.52 s
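
The aggregation described above amounts to dropping the bootstrap component of each result's key and summarizing what remains per pipeline path. A stdlib sketch on made-up scores (the keys and values are hypothetical; in this notebook the work is done by dict_to_df and perturbation_stats). Note that statistics.stdev is the sample standard deviation (n − 1 denominator), the same default as pandas' std.

```python
from collections import defaultdict
from statistics import mean, stdev

# hypothetical results keyed by (bootstrap rep, preprocessing, model)
results = {
    (0, "mean", "RF_100"): 1.40, (1, "mean", "RF_100"): 1.50, (2, "mean", "RF_100"): 1.45,
    (0, "knn", "RF_100"): 1.30, (1, "knn", "RF_100"): 1.38, (2, "knn", "RF_100"): 1.34,
}

by_path = defaultdict(list)
for (rep, preproc, model), value in results.items():
    by_path[(preproc, model)].append(value)  # pool values over bootstrap reps

for path, values in by_path.items():
    print(path, len(values), round(mean(values), 3), round(stdev(values), 3))
```
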

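When a function returns several outputs in a dict, dict_to_df can split them into separate columns via param_key='out'. Below, this splits the permutation importance result into out-importances_mean, out-importances_std, and out-importances (the raw per-repeat scores):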

In [20]:
# the helper dict_to_df converts the output to a pandas.DataFrame
# using param_key='out' separates the importance dict into multiple cols
importances_df = dict_to_df(importances, param_key='out')
importances_df
Out[20]:
| | init-boot | boot | init-preproc | preproc | init-RF | RF | init-feat_imp | init-feat_imp | init-feat_imp | feat_imp | out | out-importances_mean | out-importances_std | out-importances |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | X_train | (rep=0, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [-0.009932776917497721, -... | [-0.009932776917497721, -0.04070469572987978, ... | [0.02807632315272223, 0.0207214174976573, 0.17... | [[-0.043482796313648375, -0.011548445063171542... |
| 1 | X_train | (rep=1, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.00160611171388166, -0.... | [0.00160611171388166, -0.008456642182130492, 1... | [0.006726329191110503, 0.0022083045870299653, ... | [[0.000407542364921909, -0.0059669857746053045... |
| 2 | X_train | (rep=2, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.055381280590333426, 0.... | [0.055381280590333426, 0.02134881478117941, 1.... | [0.040643441849098585, 0.012224976089989572, 0... | [[0.04922666330000902, 0.10795025078083897, 0.... |
| 3 | X_train | (rep=3, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.011021345773666757, 0.... | [0.011021345773666757, 0.00853453502359014, 1.... | [0.016820928151262623, 0.021516585395633524, 0... | [[-0.012280288405732698, 0.02681834414642581, ... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 116 | X_train | (rep=6, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.007412483189494686, -0... | [0.007412483189494686, -0.007532937920816736, ... | [0.01412085232161868, 0.023437001839893418, 0.... | [[-0.002858726987412852, 0.027379621798113818,... |
| 117 | X_train | (rep=7, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [-0.011173182297858797, 0... | [-0.011173182297858797, 0.0028564120945974527,... | [0.02007232767747927, 0.0018915591706521987, 0... | [[-0.014608678746148351, 0.014947343036760374,... |
| 118 | X_train | (rep=8, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.004123171923337147, -0... | [0.004123171923337147, -0.03592749164760323, 1... | [0.008854347365617389, 0.0032036541951960813, ... | [[0.015223634066613423, 0.0035913077235798507,... |
| 119 | X_train | (rep=9, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.0185297391983702, -0.0... | [0.0185297391983702, -0.015749514121496316, 1.... | [0.012357649757708106, 0.010695031046462512, 0... | [[0.0265756397455037, 0.027942371330881066, 0.... |

120 rows × 14 columns
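
The column names above (out-importances_mean and so on) follow a simple flattening rule: each key of the dict stored in out becomes its own column, prefixed with out-, while the original out column is kept. A stdlib sketch of that rule on hypothetical records; it reproduces the naming pattern only and is not vflow's dict_to_df.

```python
def split_out_column(records, param_key="out"):
    """Copy each record and add one '<param_key>-<k>' column per key of its dict."""
    flat = []
    for rec in records:
        row = dict(rec)
        for k, v in rec[param_key].items():
            row[f"{param_key}-{k}"] = v
        flat.append(row)
    return flat


records = [
    {"preproc": "mean", "out": {"importances_mean": [0.01, -0.04, 1.4],
                                "importances_std": [0.03, 0.02, 0.2]}},
]
rows = split_out_column(records)
print(sorted(rows[0]))
# ['out', 'out-importances_mean', 'out-importances_std', 'preproc']
```
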

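To compute statistics on a single iterable output column, pass its name via wrt (here wrt='out-importances_mean') together with split=True. Each element of the vector (one per feature, labeled with the given prefix) then gets its own mean and std columns, alongside a shared count column: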

In [21]:
# get count, mean, and std of importances
perturbation_stats(importances_df, 'preproc', 'RF', wrt='out-importances_mean', prefix='X', split=True)
Out[21]:
| | preproc | RF | X-count | X0-mean | X0-std | X1-mean | X1-std | X2-mean | X2-std |
|---|---|---|---|---|---|---|---|---|---|
| 0 | knn | (func=RandomForestRegressor, n_estimators=100,... | 10 | 0.004863 | 0.010259 | -0.018839 | 0.026793 | 1.354068 | 0.080347 |
| 1 | knn | (func=RandomForestRegressor, n_estimators=100,... | 10 | 0.007366 | 0.012449 | -0.022356 | 0.031232 | 1.382703 | 0.125563 |
| 2 | knn | (func=RandomForestRegressor, n_estimators=300,... | 10 | 0.002487 | 0.013445 | -0.017375 | 0.021144 | 1.366897 | 0.054302 |
| 3 | knn | (func=RandomForestRegressor, n_estimators=300,... | 10 | 0.002941 | 0.009406 | -0.023427 | 0.026766 | 1.408808 | 0.098899 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 8 | med | (func=RandomForestRegressor, n_estimators=100,... | 10 | -0.000721 | 0.018412 | -0.008172 | 0.025700 | 1.483120 | 0.074220 |
| 9 | med | (func=RandomForestRegressor, n_estimators=100,... | 10 | 0.001824 | 0.017672 | -0.008696 | 0.028163 | 1.496330 | 0.081837 |
| 10 | med | (func=RandomForestRegressor, n_estimators=300,... | 10 | -0.000831 | 0.018208 | -0.006548 | 0.018171 | 1.493522 | 0.100362 |
| 11 | med | (func=RandomForestRegressor, n_estimators=300,... | 10 | 0.005192 | 0.011888 | -0.008944 | 0.017066 | 1.492519 | 0.075081 |

12 rows × 9 columns
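
With split=True, each position of the importances_mean vector gets its own statistics, named with the prefix (X0, X1, X2 above). A numpy sketch of that per-feature computation for a single pipeline path, on made-up vectors from three bootstrap reps; the column names imitate the table above, and the use of the sample standard deviation (ddof=1) is an assumption about the convention, not something confirmed from vflow's source.

```python
import numpy as np

# hypothetical importances_mean vectors for one (preproc, RF) path, one row per bootstrap rep
vectors = np.array([[0.01, -0.02, 1.35],
                    [0.00, -0.01, 1.40],
                    [0.02, -0.03, 1.30]])

prefix = "X"
stats = {f"{prefix}-count": vectors.shape[0]}
for j in range(vectors.shape[1]):
    col = vectors[:, j]  # feature j across bootstrap reps
    stats[f"{prefix}{j}-mean"] = col.mean()
    stats[f"{prefix}{j}-std"] = col.std(ddof=1)  # sample std (assumed convention)

print({k: round(float(v), 4) for k, v in stats.items()})
```
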

From here, we can (optionally) filter over the data preprocessing and modeling perturbations via the helper filter_vset_by_metric to select the top combinations in terms of stability (or another metric of interest) and continue our analysis on a held-out test set.
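
filter_vset_by_metric is vflow's own helper and its signature is not shown in this notebook. As a hypothetical illustration of the selection step only, the stdlib sketch below ranks the pipeline paths displayed in the perturbation_stats table above by X2-std (the bootstrap spread of the dominant feature's importance; smaller means more stable) and keeps the top two. Rows 4-7 were elided from the displayed table, so they are omitted here.

```python
# X2-std values for the displayed rows of the table above, keyed by (table row, preproc)
x2_std = {
    (0, "knn"): 0.080347, (1, "knn"): 0.125563, (2, "knn"): 0.054302, (3, "knn"): 0.098899,
    (8, "med"): 0.074220, (9, "med"): 0.081837, (10, "med"): 0.100362, (11, "med"): 0.075081,
}


def top_k_stable(metric_by_path, k):
    """Return the k paths with the smallest metric value (most stable first)."""
    return sorted(metric_by_path, key=metric_by_path.get)[:k]


print(top_k_stable(x2_std, 2))  # [(2, 'knn'), (8, 'med')]
```

Any other metric of interest (for example, validation accuracy) could be ranked the same way, flipping the sort order when larger is better.
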

In [22]:
ray.shutdown()