%load_ext autoreload
%autoreload 2
from vflow import Vset, build_vset, init_args, dict_to_df, perturbation_stats
from vflow.pipeline import build_graph
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.impute import KNNImputer, SimpleImputer
import numpy as np
import pandas as pd
import ray
pd.options.display.max_rows = 8
np.random.seed(31415)
Feature Importance Stability¶
In this example, we probe the stability of the permutation feature importance metric for random forests with respect to data resampling, data preprocessing, and model hyperparameter perturbations.
vflow supports automatic parallelization using ray. We can use ray to compute downstream results by setting is_async=True when constructing a Vset:
ray.init(num_cpus=4)
2022-11-15 12:59:28,038 INFO worker.py:1528 -- Started a local Ray instance.
Ray (Python version: 3.10.6, Ray version: 2.1.0)
Define stability target¶
Below, we create a Vset that applies three data preprocessing functions, and another Vset that calculates the permutation importance metric via sklearn.inspection.permutation_importance.
from vflow import Vset
from sklearn.impute import KNNImputer, SimpleImputer
from sklearn.inspection import permutation_importance
preproc_list = [SimpleImputer(strategy='mean'),
SimpleImputer(strategy='median'),
KNNImputer()]
# create a Vset which varies over the list of preprocessing methods
preproc_set = Vset("preproc", preproc_list, ['mean', 'med', 'knn'],
output_matching=True, lazy=True)
# create the feature importance Vset
feat_imp_set = build_vset('feat_imp', permutation_importance,
n_repeats=3, is_async=True)
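Each evaluation in feat_imp_set reduces to a plain call to sklearn.inspection.permutation_importance on one fitted model. As a point of reference, here is a minimal standalone sketch of that underlying call (the rf_demo / X_demo / y_demo names are illustrative only and are not part of the pipeline):
# standalone sketch of the underlying sklearn call (illustration only)
from sklearn.ensemble import RandomForestRegressor
X_demo, y_demo = make_regression(n_samples=100, n_features=3, n_informative=1, random_state=0)
rf_demo = RandomForestRegressor(n_estimators=10, random_state=0).fit(X_demo, y_demo)
res = permutation_importance(rf_demo, X_demo, y_demo, n_repeats=3, random_state=0)
# res is a Bunch with importances_mean, importances_std, and the raw importances array
print(res.importances_mean)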
Define model hyperparameter perturbations¶
We can also specify modeling perturbations, both within a single class of models (hyperparameter perturbations) and across different classes. Here we'll use the helper build_vset to create hyperparameter perturbations for random forest.
from sklearn.ensemble import RandomForestRegressor as RF
from vflow import build_vset
# hyperparameters to try
RF_params = {
'n_estimators': [100, 300],
'min_samples_split': [2, 10]
}
# we could instead pass a list of distinct models and corresponding param dicts
RF_set = build_vset('RF', RF, RF_params, is_async=True)
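The param dict above expands into its cross-product, i.e. 2 × 2 = 4 random forest configurations (the "4 RF models" fitted below). To see that grid explicitly, sklearn's ParameterGrid enumerates the same combinations; this is purely an illustration, not part of the vflow pipeline:
# enumerate the 2 x 2 hyperparameter grid implied by RF_params (illustration only)
from sklearn.model_selection import ParameterGrid
for params in ParameterGrid(RF_params):
    print(params)
# {'min_samples_split': 2, 'n_estimators': 100}
# {'min_samples_split': 2, 'n_estimators': 300}
# {'min_samples_split': 10, 'n_estimators': 100}
# {'min_samples_split': 10, 'n_estimators': 300}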
Define data perturbations¶
For stability analysis, it is often useful to add data perturbations such as the bootstrap in order to assess stability over resampling variability in the data. We can lazily compute the bootstrap, so that data will not be resampled until needed, by setting lazy=True when constructing a Vset:
from sklearn.utils import resample
# create a Vset for bootstrapping from data 10 times
# we use lazy=True so that the data will not be resampled until needed
boot_set = build_vset('boot', resample, reps=10, lazy=True)
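Each replicate that boot_set will eventually materialize is simply a call to sklearn.utils.resample, which draws rows with replacement. A self-contained sketch of a single bootstrap draw on toy arrays (the X_toy / y_toy names are illustrative only):
# a single bootstrap replicate with plain sklearn (illustration only)
X_toy = np.arange(10).reshape(5, 2)
y_toy = np.arange(5)
X_boot, y_boot = resample(X_toy, y_toy, replace=True, random_state=0)
# the outputs keep the original shapes, but rows are drawn with replacement
print(X_boot)
print(y_boot)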
Fit all models for all combinations of resampling and preprocessing¶
Now we can load in our data and fit each of the four random forest models to the 30 combinations of resampled training data and preprocessing functions (10 bootstrap replicates × 3 imputation methods).
X, y = make_regression(n_samples=1000, n_features=3, n_informative=1)
# 20% of X entries missing
i = np.random.randint(X.shape[0], size=round(X.shape[0]*X.shape[1] * 0.2))
j = np.random.randint(X.shape[1], size=i.size)
X[i, j] = np.nan
X_trainval, X_test, y_trainval, y_test = train_test_split(X, y)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval)
X_train, y_train = init_args([X_train, y_train], names=['X_train', 'y_train'])
X_val, y_val = init_args([X_val, y_val], names=['X_val', 'y_val'])
# bootstrap from training data by calling boot_fun
X_trains, y_trains = boot_set(X_train, y_train)
# apply three preprocessing methods to each bootstrap sample
X_trains = preproc_set.fit_transform(X_trains)
%%time
# this results in fitting all 4 RF models to each of the 30 boot/preproc combos
RF_set.fit(X_trains, y_trains)
CPU times: user 1.19 s, sys: 230 ms, total: 1.42 s
Wall time: 8.57 s
<vflow.vset.Vset at 0x7f2cd8267040>
We can examine the pipeline graph to see what has happened so far using the utility function build_graph:
from vflow import build_graph
# examine the pipeline graph
build_graph(RF_set)
<networkx.classes.digraph.DiGraph at 0x7f2d223866b0>
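Since build_graph returns a standard networkx.DiGraph, it can be rendered with any networkx-compatible tool. For example, a quick matplotlib rendering (assuming matplotlib is installed; the layout and labels will depend on your pipeline) might look like:
# quick visualization of the pipeline graph (assumes matplotlib is available)
import matplotlib.pyplot as plt
import networkx as nx
G = build_graph(RF_set)
nx.draw(G, with_labels=True, node_size=1500, font_size=8)
plt.show()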
Calculate feature importances and perturbation statistics¶
Finally, we calculate the importance metric and examine its mean and standard deviation across bootstrap perturbations for each combination of data preprocessing and modeling hyperparameters. This allows us to assess the stability of the feature importances conditioned on different pipeline paths:
%%time
from vflow import dict_to_df, perturbation_stats
# calculate importances
importances = feat_imp_set.evaluate(RF_set.fitted_vfuncs, preproc_set.fit_transform(X_val), y_val)
CPU times: user 1.58 s, sys: 0 ns, total: 1.58 s
Wall time: 4.52 s
Multiple outputs can be split into separate columns with dict_to_df using param_key='out'. We use it below to split the feature importances into their mean and standard deviation:
# the helper dict_to_df converts the output to a pandas.DataFrame
# using param_key='out' separates the importance dict into multiple cols
importances_df = dict_to_df(importances, param_key='out')
importances_df
  | init-boot | boot | init-preproc | preproc | init-RF | RF | init-feat_imp | init-feat_imp | init-feat_imp | feat_imp | out | out-importances_mean | out-importances_std | out-importances
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | X_train | (rep=0, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [-0.009932776917497721, -... | [-0.009932776917497721, -0.04070469572987978, ... | [0.02807632315272223, 0.0207214174976573, 0.17... | [[-0.043482796313648375, -0.011548445063171542... |
1 | X_train | (rep=1, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.00160611171388166, -0.... | [0.00160611171388166, -0.008456642182130492, 1... | [0.006726329191110503, 0.0022083045870299653, ... | [[0.000407542364921909, -0.0059669857746053045... |
2 | X_train | (rep=2, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.055381280590333426, 0.... | [0.055381280590333426, 0.02134881478117941, 1.... | [0.040643441849098585, 0.012224976089989572, 0... | [[0.04922666330000902, 0.10795025078083897, 0.... |
3 | X_train | (rep=3, func=resample) | X_train | mean | y_train | (func=RandomForestRegressor, n_estimators=100,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.011021345773666757, 0.... | [0.011021345773666757, 0.00853453502359014, 1.... | [0.016820928151262623, 0.021516585395633524, 0... | [[-0.012280288405732698, 0.02681834414642581, ... |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
116 | X_train | (rep=6, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.007412483189494686, -0... | [0.007412483189494686, -0.007532937920816736, ... | [0.01412085232161868, 0.023437001839893418, 0.... | [[-0.002858726987412852, 0.027379621798113818,... |
117 | X_train | (rep=7, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [-0.011173182297858797, 0... | [-0.011173182297858797, 0.0028564120945974527,... | [0.02007232767747927, 0.0018915591706521987, 0... | [[-0.014608678746148351, 0.014947343036760374,... |
118 | X_train | (rep=8, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.004123171923337147, -0... | [0.004123171923337147, -0.03592749164760323, 1... | [0.008854347365617389, 0.0032036541951960813, ... | [[0.015223634066613423, 0.0035913077235798507,... |
119 | X_train | (rep=9, func=resample) | X_train | knn | y_train | (func=RandomForestRegressor, n_estimators=300,... | X_val | X_val | y_val | feat_imp_0 | {'importances_mean': [0.0185297391983702, -0.0... | [0.0185297391983702, -0.015749514121496316, 1.... | [0.012357649757708106, 0.010695031046462512, 0... | [[0.0265756397455037, 0.027942371330881066, 0.... |
120 rows × 14 columns
We can compute statistics on a single iterable item of the output by passing that column's name via wrt (here, wrt='out-importances_mean') and setting split=True:
# get count, mean, and std of importances
perturbation_stats(importances_df, 'preproc', 'RF', wrt='out-importances_mean', prefix='X', split=True)
  | preproc | RF | X-count | X0-mean | X0-std | X1-mean | X1-std | X2-mean | X2-std
---|---|---|---|---|---|---|---|---|---
0 | knn | (func=RandomForestRegressor, n_estimators=100,... | 10 | 0.004863 | 0.010259 | -0.018839 | 0.026793 | 1.354068 | 0.080347 |
1 | knn | (func=RandomForestRegressor, n_estimators=100,... | 10 | 0.007366 | 0.012449 | -0.022356 | 0.031232 | 1.382703 | 0.125563 |
2 | knn | (func=RandomForestRegressor, n_estimators=300,... | 10 | 0.002487 | 0.013445 | -0.017375 | 0.021144 | 1.366897 | 0.054302 |
3 | knn | (func=RandomForestRegressor, n_estimators=300,... | 10 | 0.002941 | 0.009406 | -0.023427 | 0.026766 | 1.408808 | 0.098899 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
8 | med | (func=RandomForestRegressor, n_estimators=100,... | 10 | -0.000721 | 0.018412 | -0.008172 | 0.025700 | 1.483120 | 0.074220 |
9 | med | (func=RandomForestRegressor, n_estimators=100,... | 10 | 0.001824 | 0.017672 | -0.008696 | 0.028163 | 1.496330 | 0.081837 |
10 | med | (func=RandomForestRegressor, n_estimators=300,... | 10 | -0.000831 | 0.018208 | -0.006548 | 0.018171 | 1.493522 | 0.100362 |
11 | med | (func=RandomForestRegressor, n_estimators=300,... | 10 | 0.005192 | 0.011888 | -0.008944 | 0.017066 | 1.492519 | 0.075081 |
12 rows × 9 columns
From here, we can (optionally) filter over the data preprocessing and modeling perturbations via the helper filter_vset_by_metric to select the top combinations in terms of stability (or another metric of interest), and continue our analysis on a held-out test set.
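As a rough stand-in for that filtering step, the perturbation_stats table above can also be ranked directly with pandas. The hedged sketch below sorts the preproc/RF combinations by the standard deviation of the informative feature's importance (the X2 columns in the table above) and keeps the most stable ones; it uses only the column names shown above, not the filter_vset_by_metric API:
# hedged sketch: rank preproc/RF combinations by stability of the informative feature's importance
stats_df = perturbation_stats(importances_df, 'preproc', 'RF',
                              wrt='out-importances_mean', prefix='X', split=True)
# a smaller X2-std means the importance estimate is more stable across bootstrap replicates
most_stable = stats_df.sort_values('X2-std').head(3)
print(most_stable[['preproc', 'RF', 'X2-mean', 'X2-std']])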
ray.shutdown()