This repository contains a pure Python implementation of a mixed effects random forest (MERF) algorithm. It can be used, out of the box, to fit a MERF model and predict with it.
The MERF model is:
y_i = f(X_i) + Z_i * b_i + e_i
b_i ~ N(0, D)
e_i ~ N(0, R_i)
for each cluster i out of n total clusters.
In the above:

* `y_i` is the vector of responses for cluster i
* `X_i` is the matrix of fixed-effects covariates for cluster i
* `Z_i` is the matrix of random-effects covariates for cluster i
* `b_i` is the vector of random-effect coefficients for cluster i
* `e_i` is the vector of residual errors for cluster i
* `D` is the covariance matrix of the random effects
* `R_i` is the covariance matrix of the residuals, assumed to be `sigma^2 * I`
The learned parameters in MERF are:

* the trained fixed-effects model `f()`
* the random-effects covariance matrix `D`
* the residual variance `sigma^2`
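To make the model concrete, here is a minimal NumPy sketch (not part of this repo) that simulates data from the model above, using a random intercept (each `Z_i` is a column of ones) and a hand-picked nonlinear `f`:

```python
import numpy as np

rng = np.random.default_rng(0)
n_clusters, per_cluster = 5, 20
D = np.array([[1.0]])   # random-effect covariance (random intercept only)
sigma2 = 0.25           # residual variance, so R_i = sigma2 * I

ys = []
for i in range(n_clusters):
    X_i = rng.normal(size=(per_cluster, 3))      # fixed-effects covariates
    Z_i = np.ones((per_cluster, 1))              # random-intercept design
    b_i = rng.multivariate_normal(np.zeros(1), D)            # b_i ~ N(0, D)
    e_i = rng.normal(scale=np.sqrt(sigma2), size=per_cluster)  # e_i ~ N(0, R_i)
    f_X = X_i[:, 0] ** 2 + np.sin(X_i[:, 1])     # some nonlinear fixed effect
    ys.append(f_X + Z_i @ b_i + e_i)             # y_i = f(X_i) + Z_i * b_i + e_i

y = np.concatenate(ys)
print(y.shape)  # (100,)
```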
Note that one key assumption of the MERF model is that the random effect is linear. Though this is limiting in some regards, it is still broadly useful for many problems, and it is better than not modelling the random effect at all.
The algorithm implemented in this repo was developed by Ahlem Hajjem, Francois Bellavance, and Denis Larocque and published in a paper here. Many thanks to Ahlem and Denis for providing an R reference implementation and for helping to debug this code. One quick note: the published paper has a small typo in the update equation for sigma^2, which is corrected in the source code here.
The MERF code is modelled after scikit-learn estimators. To use it, you instantiate a MERF object. As of 1.0, you can pass any non-linear estimator for the fixed effect. By default this is a scikit-learn random forest, but you can pass any model you wish that conforms to the scikit-learn estimator API, e.g. LightGBM, XGBoost, or a properly wrapped PyTorch neural network.
Then you fit the model using training data. As of 1.0, you can also pass a validation set to monitor validation performance during training. This is meant to feel similar to PyTorch, where you can view the validation loss after each epoch. After fitting, you can predict responses from data, either for known clusters (clusters present in the training set) or new clusters (clusters not present in the training set).
```python
from merf import MERF

merf = MERF()
merf.fit(X_train, Z_train, clusters_train, y_train)
y_hat = merf.predict(X_test, Z_test, clusters_test)
```
```python
from lightgbm import LGBMRegressor

lgbm = LGBMRegressor()
mrf_lgbm = MERF(lgbm, max_iterations=15)
mrf_lgbm.fit(X_train, Z_train, clusters_train, y_train, X_val, Z_val, clusters_val, y_val)
y_hat = mrf_lgbm.predict(X_test, Z_test, clusters_test)
```
Note that training is slow because the underlying expectation-maximization (EM) algorithm requires many calls to the non-linear fixed effects model, e.g. random forest. That said, this implementation includes early stopping, which aborts the EM algorithm if the generalized log-likelihood (GLL) stops improving significantly.
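The exact stopping rule lives in the source, but the idea can be sketched as follows: stop once the relative change in GLL between consecutive EM iterations falls below a threshold. The function and parameter names below are illustrative, not the repo's API:

```python
def gll_early_stop(gll_history, threshold=0.01):
    """Illustrative sketch of a GLL-based early-stopping check, not the
    repo's actual implementation: stop the EM loop once the relative
    change in GLL between consecutive iterations drops below `threshold`."""
    if len(gll_history) < 2:
        return False  # need at least two iterations to measure improvement
    prev, curr = gll_history[-2], gll_history[-1]
    return abs(curr - prev) / abs(prev) < threshold
```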
The `merf` directory contains all the source code:

* `merf.py` is the key module that contains the MERF class. It is imported at the package level.
* `merf_test.py` contains some simple unit tests.
* `utils.py` contains a class for generating synthetic data that can be used to test the accuracy of MERF. The process implemented is the same as that in this paper.
* `viz.py` contains a plotting function that takes in a trained MERF object and plots various metrics of interest.
The `notebooks` directory contains some useful notebooks that show you how to use the code and evaluate MERF performance. Most of the techniques implemented are the same as those in this paper.