Letting accuracy and interpretability go hand in hand
Munir Hiabu
Motivation \(\cdots\) Random Planted forest \(\cdots\) Interpretability
General Conception
Conjecture: Interpretability can boost accuracy
1 The curse of dimensionality
We first generate 100 observations of the \(x_1\) variable, which is uniformly distributed on \([0,10]\).
A total of 42 observations are within a distance of 2 units from the middle point 5.
We generate \(x_2\) independent of \(x_1\) and uniformly distributed on \([0,10]\).
We find 19 observations within a distance of 2 units from the middle point (5,5).
The \(x_3\) variable is independent of \((x_1,x_2)\) and uniformly distributed on \([0,10]\).
Only a total of 9 observations fall within a distance of 2 units from the middle point (5,5,5).
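The shrinking neighbourhood above is easy to reproduce. A minimal sketch (our own simulation, not from the slides, using a larger sample so the fractions are stable):

```python
import numpy as np

rng = np.random.default_rng(0)
n, radius, center = 100_000, 2.0, 5.0

fracs = {}
for dim in (1, 2, 3):
    x = rng.uniform(0, 10, size=(n, dim))      # each feature uniform on [0, 10]
    dist = np.linalg.norm(x - center, axis=1)  # distance to the midpoint (5, ..., 5)
    fracs[dim] = np.mean(dist <= radius)

print(fracs)  # roughly 0.40, 0.13, 0.03: the covered fraction shrinks fast
```

With only 100 observations, as on the slides, the same geometry yields the counts 42, 19 and 9.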
The curse of dimensionality \(\cdots\) Random Planted forest \(\cdots\) Interpretability
There are two ways to tackle the curse of dimensionality
Assume that the intrinsic dimension is lower
Interactions are limited and structure can be exploited
2 Random Planted Forest
This is joint work with
Preprint is available on https://arxiv.org/abs/2012.14563
Assume a data set with \(d\) features. Also assume that we can approximate the regression function \(m\) by a \(q\)-th order functional decomposition:
\[m(x) \approx m_0+\sum_{k=1}^d m_k(x_{k}) + \sum_{k_1<k_2} m_{k_1k_2}(x_{k_1},x_{k_2}) + \cdots +\sum_{k_1<\cdots <k_q} m_{k_1,\dots,k_q} (x_{k_1},\dots,x_{k_q}).\]
Model | general \(d\) | \(d = 6\) | Comparable sample sizes for \(d=6\) |
---|---|---|---|
Full model | \(O_p(n^{-2/(d+4)})\) | \(O_p(n^{-1/5})\) | 1 000 000 |
Interaction (q) | \(O_p(n^{-2/(q+4)})\) | \(O_p(n^{-2/(q+4)})\) | 1 000 – 1 000 000 |
Interaction (2) | \(O_p(n^{-1/3})\) | \(O_p(n^{-1/3})\) | 4 000 |
Additive | \(O_p(n^{-2/5})\) | \(O_p(n^{-2/5})\) | 1 000 |
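As a quick sanity check of the last column (our arithmetic, not from the original table), equate the error magnitudes at the additive model's \(n=1000\):
\[1000^{-2/5}\approx 0.063, \qquad 4000^{-1/3}\approx 0.063, \qquad (10^6)^{-1/5}\approx 0.063,\]
so the full model with \(d=6\) needs roughly a million observations to match an additive fit on a thousand.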
\[m(x)=m_0+\sum_{k=1}^d m_k(x_{k}) + \sum_{k<l} m_{kl}(x_{k},x_{l}) + \sum_{j<k<l} m_{jkl}(x_{j},x_{k},x_{l})+ \cdots.\]
Split 1:
3 possible combinations:
\(t_{try}:\) \(0.6\times3=1.8 \rightarrow\) 2 viable combinations are randomly picked (rounding up), say:
\(split_{try}:\) For each viable split option we consider 5 randomly picked split points \(\rightarrow\) \(2\times5=10\) split options.
Compare the 10 split options via the least squares loss \(\sum_i (\widehat m(X_i) -Y_i)^2\).
\((m_2\text{ tree}\rightarrow x_2,c_1)\) produces the minimal least squares loss.
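The candidate-selection step above can be sketched in code. This is a minimal illustration of the \(t_{try}\)/\(split_{try}\) idea under our own simplified interface, not the authors' implementation; `best_split` and its arguments are hypothetical.

```python
import numpy as np

def best_split(y, x_candidates, t_try=0.6, split_try=5, rng=None):
    """Sketch of one rpf split: subsample candidates, then split points.

    x_candidates maps a candidate name (tree/coordinate combination) to the
    observed values of the corresponding feature.  A fraction t_try of the
    candidates is kept (rounded up, e.g. 0.6 * 3 = 1.8 -> 2); for each kept
    candidate, split_try split points are drawn, and the (candidate, point)
    pair minimising the least squares loss of a two-leaf mean fit is returned.
    """
    rng = rng or np.random.default_rng()
    names = list(x_candidates)
    n_keep = int(np.ceil(t_try * len(names)))
    kept = rng.choice(names, size=n_keep, replace=False)

    best_loss, best_name, best_point = np.inf, None, None
    for name in kept:
        x = x_candidates[name]
        for c in rng.choice(x, size=split_try, replace=False):
            left, right = y[x <= c], y[x > c]
            if len(left) == 0 or len(right) == 0:
                continue  # degenerate split, skip
            # least squares loss when each leaf predicts its mean
            loss = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
            if loss < best_loss:
                best_loss, best_name, best_point = loss, name, c
    return best_loss, best_name, best_point
```

In the slide's example, 2 of 3 candidates survive and \(2\times5=10\) (candidate, point) options are compared by their least squares loss.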
Split 2:
Now: 5 possible split combinations:
Viable split combinations: \(5\times0.6=3\) \(\rightarrow\) 3 options randomly picked, say
Hence \(5\times(1+2+1)=20\) split options.
\((m_0 \rightarrow m_3: x_3,c_2)\) produces minimal least squares loss.
Split 3:
Split 4:
True function: black solid line
Grey lines: 40 Monte Carlo simulations
xgboost (additive: depth = 1)
planted forest (additive: max interaction = 1)
Simulations are run with optimal parameters
Method | Assumption | dim=4 | dim=10 | dim=30 |
---|---|---|---|---|
xgboost | additive | 0.119 (0.021) | 0.142 (0.021) | 0.176 (0.027) |
xgboost | - | 0.141 (0.024) | 0.166 (0.028) | 0.193 (0.033) |
xgboost-CV | - | 0.139 (0.028) | 0.152 (0.029) | 0.194 (0.035) |
rpf | additive | 0.087 (0.018) | 0.086 (0.017) | 0.097 (0.019) |
rpf | interaction(2) | 0.107 (0.015) | 0.121 (0.025) | 0.142 (0.026) |
rpf | - | 0.112 (0.017) | 0.134 (0.026) | 0.162 (0.028) |
rpf-CV | - | 0.103 (0.02) | 0.102 (0.035) | 0.105 (0.022) |
rf | - | 0.209 (0.021) | 0.252 (0.027) | 0.3 (0.029) |
sbf | additive+smooth | 0.071 (0.026) | 0.134 (0.013) | 0.388 (0.073) |
gam | additive+smooth | 0.033 (0.012) | 0.035 (0.013) | 0.058 (0.021) |
BART | - | 0.085 (0.019) | 0.076 (0.017) | 0.091 (0.023) |
BART-CV | - | 0.09 (0.019) | 0.081 (0.014) | 0.09 (0.02) |
MARS | smooth | 0.054 (0.014) | 0.061 (0.025) | 0.076 (0.031) |
1-NN | no noise | 1.509 (0.1) | 3.228 (0.182) | 5.534 (0.313) |
average | no covariates | 3.811 (0.217) | 3.689 (0.183) | 3.748 (0.202) |
Method | Assumption | dim=4 | dim=10 | dim=30 |
---|---|---|---|---|
xgboost | additive | 0.19 (0.029) | 0.282 (0.044) | 0.401 (0.045) |
xgboost | - | 0.198 (0.031) | 0.265 (0.053) | 0.286 (0.034) |
xgboost-CV | - | 0.209 (0.028) | 0.281 (0.052) | 0.313 (0.058) |
rpf | additive | 0.159 (0.033) | 0.198 (0.075) | 0.179 (0.041) |
rpf | interaction(2) | 0.185 (0.028) | 0.24 (0.066) | 0.259 (0.043) |
rpf | - | 0.192 (0.026) | 0.251 (0.065) | 0.282 (0.043) |
rpf-CV | - | 0.169 (0.033) | 0.207 (0.072) | 0.183 (0.042) |
rf | - | 0.274 (0.035) | 0.322 (0.05) | 0.375 (0.037) |
sbf | additive+smooth | 0.342 (0.049) | 0.603 (0.053) | 1.112 (0.138) |
gam | additive+smooth | 0.41 (0.047) | 0.406 (0.027) | 0.431 (0.06) |
BART | - | 0.177 (0.047) | 0.162 (0.038) | 0.157 (0.034) |
BART-CV | - | 0.179 (0.051) | 0.163 (0.041) | 0.159 (0.036) |
MARS | smooth | 0.751 (0.136) | 0.74 (0.104) | 0.687 (0.123) |
1-NN | no noise | 2.393 (0.229) | 3.029 (0.308) | 3.512 (0.333) |
average | no covariates | 1.276 (0.075) | 1.25 (0.063) | 1.213 (0.054) |
Method | Assumption | dim=4 | dim=10 | dim=30 |
---|---|---|---|---|
xgboost | - | 0.374 (0.035) | 0.481 (0.064) | 0.557 (0.089) |
xgboost-CV | - | 0.393 (0.051) | 0.499 (0.058) | 0.563 (0.089) |
rpf | interaction(2) | 0.248 (0.038) | 0.327 (0.045) | 0.408 (0.07) |
rpf | - | 0.263 (0.034) | 0.357 (0.044) | 0.452 (0.076) |
rpf-CV | - | 0.277 (0.039) | 0.366 (0.051) | 0.463 (0.083) |
rf | - | 0.432 (0.039) | 0.575 (0.061) | 0.671 (0.08) |
BART | - | 0.214 (0.03) | 0.223 (0.04) | 0.252 (0.037) |
BART-CV | - | 0.242 (0.043) | 0.276 (0.053) | 0.315 (0.047) |
MARS | smooth | 0.355 (0.089) | 0.282 (0.038) | 0.414 (0.126) |
1-NN | no noise | 2.068 (0.156) | 5.988 (0.624) | 11.059 (0.676) |
average | no covariates | 8.366 (0.43) | 8.086 (0.246) | 8.207 (0.496) |
Method | Assumption | dim=4 | dim=10 | dim=30 |
---|---|---|---|---|
xgboost | - | 0.417 (0.082) | 0.797 (0.16) | 1.381 (0.234) |
xgboost-CV | - | 0.443 (0.078) | 0.872 (0.136) | 1.497 (0.326) |
rpf | interaction(2) | 0.416 (0.082) | 1.289 (0.224) | 1.822 (0.208) |
rpf | - | 0.219 (0.035) | 0.556 (0.143) | 1.186 (0.236) |
rpf-CV | - | 0.233 (0.033) | 0.603 (0.163) | 1.313 (0.253) |
rf | - | 0.304 (0.047) | 0.744 (0.305) | 1.295 (0.317) |
BART | - | 0.168 (0.022) | 0.172 (0.032) | 0.202 (0.021) |
BART-CV | - | 0.192 (0.03) | 0.199 (0.039) | 0.223 (0.025) |
MARS | smooth | 0.245 (0.088) | 0.831 (0.728) | 0.429 (0.403) |
1-NN | no noise | 1.323 (0.117) | 2.642 (0.317) | 4.173 (0.413) |
average | no covariates | 2.187 (0.125) | 2.226 (0.174) | 2.177 (0.146) |
This is joint work with
Code is available on https://github.com/PlantedML/randomPlantedForest
3 Interpretability
This is joint work with
Preprint is available on https://arxiv.org/abs/2208.06151
Code is available on https://github.com/PlantedML/glex
A Random Planted Forest with max-interaction = 2 is interpretable in the sense that we can plot the one- and two-dimensional components.
Random Planted Forest is not (yet) interpretable in a strong sense: what is the meaning of the components?
First: The components \(m_0, m_k(x_{k}), m_{k_1k_2}(x_{k_1},x_{k_2}), \dots\) are not yet identified …
Two possibilities are
Common choices for \(w\) are
For the rest of this talk we will consider the marginal identification.
There are three reasons why the marginal identification is particularly interesting
A partial dependence plot, \(\xi_S\), is defined as \[\xi_S(x_S)= \int \hat m(x) p_{-S}(x_{-S})\mathrm dx_{-S}.\]
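The integral defining \(\xi_S\) can be estimated by averaging model predictions over the observed \(x_{-S}\). A minimal sketch (the function name and interface are our own):

```python
import numpy as np

def pdp(model, X, s_cols, grid):
    """Monte Carlo estimate of the partial dependence xi_S on a grid.

    For each grid value, the columns in s_cols are set to that value while
    all other columns keep their observed values; predictions are then
    averaged -- an empirical version of
    xi_S(x_S) = integral of m(x) p_{-S}(x_{-S}) dx_{-S}.
    """
    out = []
    for x_s in grid:
        X_mod = X.copy()
        X_mod[:, s_cols] = x_s          # fix the S-coordinates, keep the rest
        out.append(model(X_mod).mean())
    return np.array(out)
```

For a linear model \(m(x)=x_1+x_2\), for instance, the estimated \(\xi_{\{1\}}(x_1)\) is exactly \(x_1\) plus the sample mean of the second feature.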
Assume \(U\) is a set of protected features.
Let \(U \cup V=\{1,\dots, d\}, U\cap V=\emptyset\).
The do-operator, \(do(X_V=x_V)\), removes all edges going into \(X_V\), ensuring counterfactual fairness (Kusner et al. 2017); see also Lindholm et al. (2022).
\(E[m(X) \mid do(X_V=x_V)]\) does not use information contained in \(X_U\); neither directly nor indirectly.
Under the assumed causal structure we have \[E[m(X) |\ do(X_V=x_V)]= \int m(x) p_U(x_U) dx_U.\]
Under marginal identification: \[\int \hat m^\ast(x) \hat p_U(x_U) dx_U= \sum_{S: S\cap U =\emptyset}\int \hat m^\ast_S(x_S) \hat p_U(x_U) dx_U + \sum_{S: S\cap U \neq \emptyset}\int \hat m^\ast_S(x_S) \hat p_U(x_U) dx_U= \sum_{S \subseteq V} \hat m^\ast_S(x_S),\] i.e., a “fair” estimator can be extracted from \(\hat m\) by dropping all components that include features in \(U\).
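The identity above suggests a direct recipe: given an estimated functional decomposition, drop every component whose index set meets \(U\). A minimal sketch with a hypothetical component interface (subsets mapped to functions of the corresponding sub-vector; not the glex API):

```python
def fair_predict(components, protected, x):
    """Evaluate the sum over S ⊆ V of m_S(x_S), dropping components touching U.

    components: dict mapping a frozenset of feature names to a function m_S
    that takes the sub-dictionary {k: x[k] for k in S} (hypothetical interface).
    """
    protected = set(protected)
    total = 0.0
    for S, m_S in components.items():
        if S & protected:
            continue  # component uses a protected feature, directly or via interaction
        total += m_S({k: x[k] for k in S})
    return total
```

Note that interaction components such as \(m_{\text{age,gender}}\) are dropped as soon as any of their features is protected, which is exactly what removes the indirect use of \(X_U\).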
We have made available a fast implementation for extracting the marginal identification from xgboost on GitHub: https://github.com/PlantedML/glex.
A version for Random Planted Forest will be available soon.
We simulate 10,000 noisy observations from
\[m(x_1,x_2)=x_1+x_2+x_1x_2.\] If \(X_1\) and \(X_2\) each have mean zero and variance one, then
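A hedged numerical sketch of this example: assuming independent standard normal features for concreteness (mean zero, variance one, as on the slide) and ignoring the observation noise, the main effect recovered by the marginal identification at a point \(t\) is approximately \(t\) itself:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000
x1 = rng.normal(size=n)   # mean zero, variance one (standard normal assumed)
x2 = rng.normal(size=n)

def m(a, b):
    return a + b + a * b

# marginal identification: m0 = E[m(X1, X2)], main effect m1(t) = E[m(t, X2)] - m0
m0 = m(x1, x2).mean()
t = 1.5
m1_t = m(t, x2).mean() - m0
print(m0, m1_t)   # m0 close to 0, m1_t close to t = 1.5
```

Since \(E[X_2]=0\), the interaction term contributes nothing to the main effect, so \(m_1(t)=t\), \(m_2(t)=t\) and \(m_{12}(x_1,x_2)=x_1x_2\) up to Monte Carlo error.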
Thank You!