Time-to-event prediction with PyTorch
pycox is a python package for survival analysis and time-to-event prediction with PyTorch, built on the torchtuples package for training PyTorch models. An R version of this package is available at survivalmodels.
The package contains implementations of various survival models, some useful evaluation metrics, and a collection of event-time datasets.
In addition, some useful preprocessing tools are available in the pycox.preprocessing module.
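Discrete-time methods such as LogisticHazard and PMF need the event times mapped onto a grid of cut points, which is one of the preprocessing steps such a pipeline performs. The following is a minimal NumPy sketch of that idea, not pycox's own label transform. The helper names `equidistant_cuts`, `quantile_cuts` and `discretize` are hypothetical. The quantile grid here uses empirical quantiles of the durations, and the "first cut point >= duration" rule is an assumed convention; both may differ from what pycox implements.

```python
import numpy as np

def equidistant_cuts(durations, num_cuts):
    """Evenly spaced cut points from the smallest to the largest duration."""
    durations = np.asarray(durations, dtype=float)
    return np.linspace(durations.min(), durations.max(), num_cuts)

def quantile_cuts(durations, num_cuts):
    """Cut points at empirical quantiles of the durations (duplicates dropped)."""
    qs = np.linspace(0.0, 1.0, num_cuts)
    return np.unique(np.quantile(np.asarray(durations, dtype=float), qs))

def discretize(durations, cuts):
    """Index of the first cut point >= each duration (assumed convention)."""
    idx = np.searchsorted(cuts, np.asarray(durations, dtype=float), side="left")
    return np.clip(idx, 0, len(cuts) - 1)

durations = np.array([1.0, 1.0, 1.0, 2.0, 10.0])
eq = equidistant_cuts(durations, 5)   # values 1.0, 3.25, 5.5, 7.75, 10.0
qc = quantile_cuts(durations, 5)      # values 1.0, 2.0, 10.0 (duplicates removed)
idx = discretize(durations, qc)       # indices 0, 0, 0, 1, 2
```

With skewed durations like these, the quantile grid concentrates cut points where observations are dense. The equidistant grid spreads them evenly over the observed range, leaving several intervals with no data.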
To get started you first need to install PyTorch. You can then install pycox via pip:
pip install pycox
Or, via conda:
conda install -c conda-forge pycox
We recommend starting with 01_introduction.ipynb, which explains the general usage of the package in terms of preprocessing, creation of neural networks, model training, and evaluation procedure.
The notebook uses the LogisticHazard method for illustration, but most of the principles generalize to the other methods.
Alternatively, there are many examples listed in the examples folder, or you can follow the tutorial based on the LogisticHazard method:
01_introduction.ipynb: General usage of the package in terms of preprocessing, creation of neural networks, model training, and evaluation procedure.
02_introduction.ipynb: Quantile-based discretization scheme, nested tuples with tt.tuplefy, entity embedding of categorical variables, and cyclical learning rates.
03_network_architectures.ipynb: Extending the framework with custom networks and custom loss functions. The example combines an autoencoder with a survival network, and considers a loss that combines the autoencoder loss with the loss of the LogisticHazard.
04_mnist_dataloaders_cnn.ipynb: Using dataloaders and convolutional networks for the MNIST data set. We repeat the simulations of [8] where each digit defines the scale parameter of an exponential distribution.
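The simulation idea behind the MNIST notebook can be sketched without any images. Each individual's digit sets the scale (mean) of an exponential event-time distribution, and the observed data are the minimum of the event and censoring times. In this NumPy sketch, the mapping `scale = 1 + digit`, the exponential censoring distribution, and the function name `simulate_from_digits` are all hypothetical choices made for illustration. They are not taken from [8] or from the notebook.

```python
import numpy as np

def simulate_from_digits(digits, censor_scale=5.0, seed=0):
    """Right-censored exponential event times whose scale depends on a digit label.

    Hypothetical choices: scale = 1 + digit, and independent exponential
    censoring with mean `censor_scale`.
    """
    rng = np.random.default_rng(seed)
    digits = np.asarray(digits)
    scale = 1.0 + digits                        # larger digit -> longer expected survival
    event_times = rng.exponential(scale)        # one draw per individual
    censor_times = rng.exponential(censor_scale, size=digits.shape)
    durations = np.minimum(event_times, censor_times)
    events = (event_times <= censor_times).astype(int)  # 1 = event observed, 0 = censored
    return durations, events

digits = np.repeat(np.arange(10), 200)
durations, events = simulate_from_digits(digits)
```

Averaging `durations` per digit should show longer observed times for larger digits. That dependence is the signal a convolutional network has to recover from the image alone.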
The following methods are available in the pycox.models module.
Continuous-time models:
Method | Description | Example |
---|---|---|
CoxTime | Cox-Time is a relative risk model that extends Cox regression beyond proportional hazards [1]. | notebook |
CoxCC | Cox-CC is a proportional hazards version of the Cox-Time model [1]. | notebook |
CoxPH (DeepSurv) | CoxPH is a Cox proportional hazards model also referred to as DeepSurv [2]. | notebook |
PCHazard | The Piecewise Constant Hazard (PC-Hazard) model [12] assumes that the continuous-time hazard function is constant in predefined intervals. It is similar to the Piecewise Exponential Models [11] and PEANN [14], but with a softplus activation instead of the exponential function. | notebook |
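CoxPH (DeepSurv) treats the network output as a log relative risk and minimizes the negative Cox partial log-likelihood. Cox-CC and Cox-Time instead approximate that likelihood by sampling controls from the risk set [1]. The NumPy sketch below computes the full negative partial log-likelihood for given scores. The function name is hypothetical, and as a simplifying assumption tied durations get no special treatment, so with ties the risk sets depend on the sort order.

```python
import numpy as np

def cox_neg_partial_loglik(log_risk, durations, events):
    """Mean negative Cox partial log-likelihood over the observed events.

    log_risk:  network outputs g(x_i), interpreted as log relative risks.
    durations: observed times T_i.  events: 1 if T_i is an event, 0 if censored.
    Ties are not handled specially (simplifying assumption).
    """
    log_risk = np.asarray(log_risk, dtype=float)
    events = np.asarray(events, dtype=bool)
    order = np.argsort(-np.asarray(durations, dtype=float), kind="stable")
    g, d = log_risk[order], events[order]
    # After sorting by descending time, the risk set of position i is positions 0..i,
    # so log(sum_{j in R_i} exp(g_j)) is a running log-sum-exp.
    log_denominator = np.logaddexp.accumulate(g)
    return -np.sum((g - log_denominator)[d]) / max(d.sum(), 1)

loss = cox_neg_partial_loglik([0.0, 0.0], durations=[1.0, 2.0], events=[1, 1])
# loss equals log(2) / 2: the earlier event competes with both individuals,
# the later one only with itself.
```

Because only differences between scores enter the likelihood, adding a constant to every output leaves the loss unchanged. This is why these models have no intercept in the risk function.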
Discrete-time models:
Method | Description | Example |
---|---|---|
LogisticHazard (Nnet-survival) | The Logistic-Hazard method parametrizes the discrete hazards and optimizes the survival likelihood [12] [7]. It is also called Partial Logistic Regression [13] and Nnet-survival [8]. | notebook |
PMF | The PMF method parametrizes the probability mass function (PMF) and optimizes the survival likelihood [12]. It is the foundation of methods such as DeepHit and MTLR. | notebook |
DeepHit, DeepHitSingle | DeepHit is a PMF method with a loss for improved ranking that can handle competing risks [3]. | single competing |
MTLR (N-MTLR) | The (Neural) Multi-Task Logistic Regression is a PMF method proposed by [9] and [10]. | notebook |
BCESurv | A method representing a set of binary classifiers that remove individuals as they are censored [15]. The loss is the binary cross entropy of the survival estimates at a set of discrete times, with targets that are indicators of surviving each time. | bs_example |
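The discrete-time parametrizations above share one grid of times `t_1 < ... < t_m`. With hazards `h_j`, survival is `S(t_j) = prod_{k<=j} (1 - h_k)`. With a PMF `f_j`, survival is `S(t_j) = 1 - sum_{k<=j} f_k`. The NumPy sketch below shows both conversions and a Logistic-Hazard negative log-likelihood with `h_j = sigmoid(z_j)`. An event in interval κ contributes `h_κ` times the probability of surviving the earlier intervals. A censoring in interval κ contributes `1 - h_κ` instead, under the assumed convention that a censored individual survives through its own interval. The function names are hypothetical and this is not pycox's implementation.

```python
import numpy as np

def hazard_to_survival(hazard):
    """S(t_j) = prod_{k<=j} (1 - h_k), computed row-wise for an (n, m) array."""
    return np.cumprod(1.0 - np.asarray(hazard, dtype=float), axis=1)

def pmf_to_survival(pmf):
    """S(t_j) = 1 - sum_{k<=j} f_k, computed row-wise for an (n, m) array."""
    return 1.0 - np.cumsum(np.asarray(pmf, dtype=float), axis=1)

def logistic_hazard_nll(logits, idx_durations, events):
    """Mean negative log-likelihood with discrete hazards h = sigmoid(logits).

    logits: (n, m) network outputs; idx_durations: interval index of each T_i;
    events: 1 if the event happened in that interval, 0 if censored there.
    """
    logits = np.asarray(logits, dtype=float)
    idx = np.asarray(idx_durations, dtype=int)
    events = np.asarray(events, dtype=float)
    n, m = logits.shape
    log_h = -np.logaddexp(0.0, -logits)      # log(sigmoid(z)), numerically stable
    log_1m_h = -np.logaddexp(0.0, logits)    # log(1 - sigmoid(z))
    # Survive every interval before the observed one ...
    before = np.arange(m)[None, :] < idx[:, None]
    loglik = np.where(before, log_1m_h, 0.0).sum(axis=1)
    # ... then experience the event (h) or survive the censoring interval (1 - h).
    rows = np.arange(n)
    loglik += events * log_h[rows, idx] + (1.0 - events) * log_1m_h[rows, idx]
    return -loglik.mean()

surv_lh = hazard_to_survival([[0.5, 0.5, 0.5]])         # 0.5, 0.25, 0.125
surv_pmf = pmf_to_survival([[0.2, 0.3, 0.1]])           # 0.8, 0.5, 0.4
nll = logistic_hazard_nll(np.zeros((1, 3)), [2], [1])   # h = 0.5 everywhere: 3 * log(2)
```

For an event, the likelihood term equals the PMF of that interval, `f_κ = h_κ * prod_{k<κ} (1 - h_k)`. The hazard and PMF parametrizations therefore describe the same distribution and differ only in what the network outputs.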
The following evaluation metrics are available with pycox.evaluation.EvalSurv.
Metric | Description |
---|---|
concordance_td | The time-dependent concordance index evaluated at the event times [4]. |
brier_score | The IPCW Brier score (inverse probability of censoring weighted Brier score) [5][6][15]. See Section 3.1.2 of [15] for details. |
nbll | The IPCW (negative) binomial log-likelihood [5][1]. I.e., this is minus the binomial log-likelihood and should not be confused with the negative binomial distribution. The weighting is performed as in Section 3.1.2 of [15]. |
integrated_brier_score | The integrated IPCW Brier score. Numerical integration of the `brier_score` [5][6]. |
integrated_nbll | The integrated IPCW (negative) binomial log-likelihood. Numerical integration of the `nbll` [5][1]. |
brier_score_admin, integrated_brier_score_admin | The administrative Brier score [15]. Works well for data with administrative censoring, meaning all censoring times are observed. See this example notebook. |
nbll_admin, integrated_nbll_admin | The administrative (negative) binomial log-likelihood [15]. Works well for data with administrative censoring, meaning all censoring times are observed. See this example notebook. |
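At a single time t, the IPCW Brier score compares the predicted survival `S(t | x_i)` with the indicator of surviving past t. Individuals with an observed event before t are weighted by `1/G(T_i)`, individuals still at risk after t by `1/G(t)`, and individuals censored before t get weight zero, where G is the survival function of the censoring distribution [5][6]. The binomial log-likelihood uses the same weights. The NumPy sketch below is a simplified illustration, not `EvalSurv`. The function name is hypothetical, and G is passed in as a callable, in practice often a Kaplan-Meier estimate of the censoring distribution. The sketch evaluates G at `T_i` rather than at its left limit, so results may differ slightly from pycox.

```python
import numpy as np

def ipcw_brier_and_nbll(surv_t, t, durations, events, censor_surv, eps=1e-7):
    """IPCW Brier score and minus the binomial log-likelihood at one time t.

    surv_t:      predicted S(t | x_i) for each individual.
    censor_surv: callable G(s), the censoring survival function (supplied by the caller).
    """
    surv_t = np.asarray(surv_t, dtype=float)
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    died = events & (durations <= t)   # event observed at or before t
    alive = durations > t              # known to survive past t
    # Censored before t: both masks are False, so the weight is zero.
    w_died = np.where(died, 1.0 / censor_surv(durations), 0.0)
    w_alive = np.where(alive, 1.0 / censor_surv(t), 0.0)
    brier = np.mean(w_died * surv_t**2 + w_alive * (1.0 - surv_t) ** 2)
    s = np.clip(surv_t, eps, 1.0 - eps)  # avoid log(0)
    nbll = -np.mean(w_died * np.log(1.0 - s) + w_alive * np.log(s))
    return brier, nbll

def no_censoring(s):
    """G(s) = 1 for every s, i.e. assume no censoring."""
    return np.ones_like(np.asarray(s, dtype=float))

brier, nbll = ipcw_brier_and_nbll(
    surv_t=[0.2, 0.4, 0.7, 0.9], t=2.5,
    durations=[1.0, 2.0, 3.0, 4.0], events=[1, 1, 0, 1],
    censor_surv=no_censoring,
)
# With G = 1 the Brier score is the plain mean of squared errors:
# (0.04 + 0.16 + 0.09 + 0.01) / 4 = 0.075
```

The third individual is censored at 3.0, after t = 2.5, so it still counts as a known survivor at t. Only censoring before t removes an individual, and the weights `1/G` compensate for those lost observations.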
A collection of datasets is available through the pycox.datasets module.
For example, the following code will download the metabric dataset and load it in the form of a pandas DataFrame:
from pycox import datasets
df = datasets.metabric.read_df()
The datasets module will store datasets under the installation directory by default. You can specify a different directory by setting the PYCOX_DATA_DIR environment variable.
Real datasets:
Dataset | Size | Description | Data source |
---|---|---|---|
flchain | 6,524 | The Assay of Serum Free Light Chain (FLCHAIN) dataset. See [1] for preprocessing. | source |
gbsg | 2,232 | The Rotterdam & German Breast Cancer Study Group. See [2] for details. | source |
kkbox | 2,814,735 | A survival dataset created from the WSDM - KKBox's Churn Prediction Challenge 2017 with administrative censoring. See [1] and [15] for details. Compared to kkbox_v1, this data set has more covariates and censoring times. Note: You need Kaggle credentials to access the dataset. | source |
kkbox_v1 | 2,646,746 | A survival dataset created from the WSDM - KKBox's Churn Prediction Challenge 2017. See [1] for details. This is not the preferred version of this data set. Use kkbox instead. Note: You need Kaggle credentials to access the dataset. | source |
metabric | 1,904 | The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC). See [2] for details. | source |
nwtco | 4,028 | Data from the National Wilms' Tumor Study (NWTCO). | source |
support | 8,873 | Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT). See [2] for details. | source |
Simulated datasets:
Dataset | Size | Description | Data source |
---|---|---|---|
rr_nl_nph | 25,000 | Dataset from simulation study in [1]. This is a continuous-time simulation study with event times drawn from a relative risk non-linear non-proportional hazards model (RRNLNPH). | SimStudyNonLinearNonPH |
sac3 | 100,000 | Dataset from simulation study in [12]. This is a discrete-time dataset with 1000 possible event-times. | SimStudySACCensorConst |
sac_admin5 | 50,000 | Dataset from simulation study in [15]. This is a discrete-time dataset with 1000 possible event-times. Very similar to `sac3`, but with fewer survival covariates and administrative censoring determined by 5 covariates. | SimStudySACAdmin |
Note: This package is still in its early stages of development, so please don't hesitate to report any problems you may experience.
The package only works for Python 3.6+.
Before installing pycox, please install PyTorch (version >= 1.1). You can then install the package with:
pip install pycox
For the bleeding-edge version, you can instead install directly from GitHub (consider adding --force-reinstall):
pip install git+https://github.com/havakv/pycox.git
Installation from source depends on PyTorch, so make sure it is installed. Next, clone and install with:
git clone https://github.com/havakv/pycox.git
cd pycox
pip install .
[1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel. Time-to-event prediction with neural networks and Cox regression. Journal of Machine Learning Research, 20(129):1–30, 2019. [paper]
[2] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology, 18(1), 2018. [paper]
[3] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. DeepHit: A deep learning approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018. [paper]
[4] Laura Antolini, Patrizia Boracchi, and Elia Biganzoli. A time-dependent discrimination index for survival data. Statistics in Medicine, 24(24):3927–3944, 2005. [paper]
[5] Erika Graf, Claudia Schmoor, Willi Sauerbrei, and Martin Schumacher. Assessment and comparison of prognostic classification schemes for survival data. Statistics in Medicine, 18(17-18):2529–2545, 1999. [paper]
[6] Thomas A. Gerds and Martin Schumacher. Consistent estimation of the expected Brier score in general survival models with right-censored event times. Biometrical Journal, 48(6):1029–1040, 2006. [paper]
[7] Charles C. Brown. On the use of indicator variables for studying the time-dependence of parameters in a response-time model. Biometrics, 31(4):863–872, 1975. [paper]
[8] Michael F. Gensheimer and Balasubramanian Narasimhan. A scalable discrete-time survival model for neural networks. PeerJ, 7:e6257, 2019. [paper]
[9] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin, and Vickie Baracos. Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In Advances in Neural Information Processing Systems 24, pages 1845–1853. Curran Associates, Inc., 2011. [paper]
[10] Stephane Fotso. Deep neural networks for survival analysis based on a multi-task framework. arXiv preprint arXiv:1801.05512, 2018. [paper]
[11] Michael Friedman. Piecewise exponential models for survival data with covariates. The Annals of Statistics, 10(1):101–113, 1982. [paper]
[12] Håvard Kvamme and Ørnulf Borgan. Continuous and discrete-time survival prediction with neural networks. arXiv preprint arXiv:1910.06724, 2019. [paper]
[13] Elia Biganzoli, Patrizia Boracchi, Luigi Mariani, and Ettore Marubini. Feed forward neural networks for the analysis of censored survival data: a partial logistic regression approach. Statistics in Medicine, 17(10):1169–1186, 1998. [paper]
[14] Marco Fornili, Federico Ambrogi, Patrizia Boracchi, and Elia Biganzoli. Piecewise exponential artificial neural networks (PEANN) for modeling hazard function with right censored data. Computational Intelligence Methods for Bioinformatics and Biostatistics, pages 125–136, 2014. [paper]
[15] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems and Solutions. arXiv preprint arXiv:1912.08581, 2019. [paper]