Here we give a brief introduction to using Multi-task Logistic Regression (MTLR) for survival prediction. MTLR was specifically designed to give survival probabilities across a range of times for individual observations. This differs from models that produce risk scores (such as Cox proportional hazards), models that give a probability for a single time point (such as the Gail model), and population-wide models (e.g. Kaplan-Meier curves). Producing survival probabilities over a range of times gives patients and physicians a more holistic view of survival, which may be critical when making healthcare decisions.
MTLR was first introduced at NIPS in 2011 in the paper “Learning Patient-Specific Cancer Survival Distributions as a Sequence of Dependent Regressors”. Since then, much work has been done, including a website that can be used to build MTLR models on uploaded data. While this is an extremely useful resource, we have brought MTLR into the R environment so that it can be compared with other survival methods and used alongside tools from other R packages, such as survival and randomForestSRC.
MTLR can be used for survival data containing right, left, interval, or no censoring. In addition, these types of censoring can be mixed in the same dataset. Documentation on using these different types of censoring can be found in help(mtlr). In this vignette we will consider an example that includes right censoring only. Namely, we will be using the lung dataset from the survival package.
One can access the lung dataset by loading the survival package.
library(survival)
#Looking at the top 6 rows...
head(lung)
#> inst time status age sex ph.ecog ph.karno pat.karno meal.cal wt.loss
#> 1 3 306 2 74 1 1 90 100 1175 NA
#> 2 3 455 2 68 1 0 90 90 1225 15
#> 3 3 1010 1 56 1 0 90 90 NA 15
#> 4 5 210 2 57 1 1 90 60 1150 11
#> 5 1 883 2 60 1 0 100 90 NA 0
#> 6 12 1022 1 74 1 1 50 80 513 0
#help(lung) #See the basic information of lung.
If you look at the help file for lung (help(lung)), you will see a definition of each feature.
Most importantly, you will notice the two features that every survival dataset needs in order to use MTLR: an event time (here time) and an indicator identifying whether an observation is censored or uncensored (here status). In this example, status == 1 indicates a right-censored individual and status == 2 indicates an uncensored individual. Later on we will use the Surv function to structure our survival data for MTLR; other formats for the indicator feature (status) are also acceptable, so see help(Surv) for more information.
We will remove inst for this example since it is a categorical feature with 19 unique values and we would like to keep the number of features relatively small.
lung <- lung[,-1]
Before going any further, we will split our data into training and testing sets. Note that we could stratify the split by censor status, but for simplicity we skip that for now.
#Use 90% of the observations (205 of 228) for training
numberTrain <- floor(nrow(lung)*0.9)
set.seed(42)
trInd <- sample(1:nrow(lung), numberTrain)
training <- lung[trInd,]
testing <- lung[-trInd,]
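If you did want to stratify the split by censor status, a minimal base-R sketch follows. The helper stratified_split() is hypothetical (it is not part of MTLR or survival): it samples the same proportion of training rows within each level of the status indicator, so censored and uncensored observations appear in roughly the same proportions in both sets.

```r
# Hypothetical helper, not part of MTLR: stratified train/test split.
# Samples floor(prop * n_s) training rows within each status level s.
stratified_split <- function(status, prop) {
  train_idx <- integer(0)
  for (s in unique(status)) {
    idx <- which(status == s)
    n_train <- floor(length(idx) * prop)
    # Index into idx rather than calling sample(idx, ...), which
    # behaves differently when idx has length one.
    train_idx <- c(train_idx, idx[sample.int(length(idx), n_train)])
  }
  sort(train_idx)
}

# Toy example: 10 censored (1) and 30 uncensored (2) observations
set.seed(42)
status <- rep(c(1, 2), times = c(10, 30))
trIndStrat <- stratified_split(status, prop = 0.8)
table(status[trIndStrat])
```

With the lung data, stratified_split(lung$status, prop) with your chosen training proportion would take the place of the sample() call above.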
You may also notice that there are some missing values in the data, mostly in meal.cal and wt.loss (although ph.ecog, ph.karno, and pat.karno also have missing values). The MTLR package does not handle missing values, so they must be dealt with before modeling. If data containing missing values are passed in anyway, all rows with missing values will be removed before model training or prediction. To remedy this, we perform a very basic mean imputation on the dataset. Note that we use the means from the training set to impute the test set.
#Perform mean imputation, using the training means for both sets
trMeans <- colMeans(training, na.rm = TRUE)
for(i in seq_len(ncol(training))){
  training[is.na(training[,i]), i] <- trMeans[i]
  testing[is.na(testing[,i]), i] <- trMeans[i]
}
Once the dataset has been prepared, we can begin to explore some of the functions in the MTLR package. Most importantly, we will use the mtlr function to train our model. mtlr accepts a number of arguments, though only a select few are discussed here. Only two arguments are required to train an mtlr model: formula and data. For formula, we must combine our event time feature and censor indicator feature using the Surv function. Since time and status are these two features, we can create our formula object:
formula <- Surv(time,status)~.
The above says we will train a model on the survival object created from time and status, using all the other features in our dataset as predictors. If we wanted to use only a few features we could do this as well, for example, with age and sex.
formulaSmall <- Surv(time, status)~age+sex
Next, we just need the data argument, which in our case is training. We can finally make our first model!
library(MTLR)
fullMod <- mtlr(formula = formula, data = training)
smallMod <- mtlr(formula = formulaSmall, data = training)
#We will print the small model so the output is more compact.
smallMod
#>
#> Call: mtlr(formula = formulaSmall, data = training)
#>
#> Time points:
#> [1] 60.6 101.2 155.4 177.0 192.7 210.6 235.4 269.5 291.8 310.6 353.0
#> [12] 386.2 455.1 553.6 688.4
#>
#>
#> Weights:
#> Bias age sex
#> 60.62 0.0714 0.05518 -0.0148
#> 101.25 0.0933 0.03607 -0.0275
#> 155.44 0.0863 0.02890 -0.0272
#> 177 -0.0761 0.03339 -0.0465
#> 192.69 0.1011 0.01274 -0.0566
#> 210.62 0.4183 0.01650 -0.0318
#> 235.38 -0.0482 0.00554 -0.0410
#> 269.5 -0.3877 -0.01542 -0.0284
#> 291.81 0.0233 0.01393 -0.0471
#> 310.62 -0.2513 0.02312 -0.0414
#> 353 0.0146 0.01510 -0.0233
#> 386.25 -0.4602 -0.00296 -0.0199
#> 455.12 0.4586 0.01033 -0.0248
#> 553.62 -0.5599 0.01773 -0.0263
#> 688.38 -0.2246 0.02091 -0.0308
There is a lot to take in at first from the output of the mtlr model. The first item is simply the call that was used to build the model. Next are the time points that mtlr used to train the model. If these time points are not specified when constructing the model, mtlr will choose time points based on the quantiles of the event time feature. Additionally, the number of time points is chosen to be sqrt(N), rounded up, where N is the number of observations. Since we had 205 training instances and sqrt(205) ≈ 14.32, mtlr rounded up to 15 time points.
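The rule for the number of time points can be sketched in base R. choose_time_points() is a hypothetical helper that mirrors the description above (ceiling(sqrt(N)) points placed at quantiles of the event times); the exact quantile levels mtlr uses internally are an assumption here, so its time points may differ slightly from those printed above.

```r
# Hypothetical sketch of quantile-based time point selection (not MTLR's code).
choose_time_points <- function(times) {
  k <- ceiling(sqrt(length(times)))      # number of time points, rounded up
  probs <- seq_len(k) / (k + 1)          # k evenly spaced interior quantile levels (assumed)
  unname(quantile(times, probs = probs))
}

ceiling(sqrt(205))   # 15, as in the printed model
```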
Last, mtlr outputs the weight matrix for the model: the weights corresponding to each feature at each time point (notice that the bias weights are also included). The row names give the time point to which each row of weights belongs. If you would like to access these weights, they are saved in the model object as weight_matrix, so you can access them using smallMod$weight_matrix.
We can also plot the weights of an mtlr model. Above we printed the small model, but here we will look at the weights of the full model.
plot(fullMod)
By default, plot will only show the 5 features whose weights have the largest sum of absolute values across time (i.e. the most influential features). You can change this behavior through the arguments of plot.
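The ranking behind this default can be reproduced by hand. The sketch below rebuilds the age and sex columns of the small model's printed weight matrix (in practice you would use smallMod$weight_matrix) and ranks features by the sum of their absolute weights across time points. Excluding the bias column from the ranking is an assumption about how plot chooses features.

```r
# Age and sex weights copied from the printed smallMod output above
W <- cbind(
  age = c(0.05518, 0.03607, 0.02890, 0.03339, 0.01274, 0.01650, 0.00554,
          -0.01542, 0.01393, 0.02312, 0.01510, -0.00296, 0.01033, 0.01773, 0.02091),
  sex = c(-0.0148, -0.0275, -0.0272, -0.0465, -0.0566, -0.0318, -0.0410,
          -0.0284, -0.0471, -0.0414, -0.0233, -0.0199, -0.0248, -0.0263, -0.0308)
)

# Influence of each feature: sum of absolute weights over all time points
influence <- colSums(abs(W))

# Keep at most top_n features, most influential first
top_n <- 5
names(sort(influence, decreasing = TRUE))[seq_len(min(top_n, length(influence)))]
```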
Now that we have trained an MTLR model, we should make some predictions! This is where our testing set and the predict function come into play. There are a number of predictions we may be interested in. First, we may want to view the survival curves of our test observations.
survCurves <- predict(fullMod, testing, type = "survivalcurve")
#survCurves is pretty large so we will look at the first 5 rows/columns.
survCurves[1:5,1:5]
#> time 1 2 3 4
#> 1 0.0000 1.0000000 1.0000000 1.0000000 1.0000000
#> 2 60.6250 0.9197721 0.9515804 0.8685207 0.9201473
#> 3 101.2500 0.8427539 0.9092633 0.7588228 0.8429653
#> 4 155.4375 0.7720628 0.8716582 0.6670640 0.7726787
#> 5 177.0000 0.7080682 0.8373624 0.5913061 0.7086728
When we use the predict function for survival curves, we are returned a matrix whose first column (time) lists the time points at which the model evaluated the survival probability of each observation (these are the time points used by mtlr, plus an additional 0 point). Each following column corresponds to a row of the data passed in, e.g. column 2 (named 1) corresponds to row 1 of testing. Each row of this matrix gives the survival probabilities at the corresponding time point (given by the time column). For example, testing observation 1 has a survival probability of about 0.920 at time 60.625.
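To make the layout concrete, the sketch below rebuilds the printed 5 by 5 corner of survCurves as a data frame (treating the returned object as a data frame is an assumption based on the printed output) and looks up the value quoted above.

```r
# First 5 rows/columns of survCurves, copied from the printed output above
survCurves5 <- data.frame(
  time = c(0, 60.625, 101.25, 155.4375, 177),
  `1` = c(1, 0.9197721, 0.8427539, 0.7720628, 0.7080682),
  `2` = c(1, 0.9515804, 0.9092633, 0.8716582, 0.8373624),
  `3` = c(1, 0.8685207, 0.7588228, 0.6670640, 0.5913061),
  `4` = c(1, 0.9201473, 0.8429653, 0.7726787, 0.7086728),
  check.names = FALSE   # keep the column names "1", "2", ... as printed
)

# Column "1" holds testing observation 1; pick the row for time 60.625
survCurves5[survCurves5$time == 60.625, "1"]

# The curve of observation 3 as (time, probability) pairs
obs3 <- survCurves5[, c("time", "3")]
```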
Since a matrix of survival probabilities can be hard to digest, we can also choose to plot the curves.
plotcurves(survCurves, 1:10)
Here we have specified that we want to see the survival curves for the first 10 observations (corresponding to the first 10 rows of testing). You will notice that these curves are smooth, whereas before we only had probabilities at certain time points: a monotonic spline has been fit to those survival probabilities to produce the curves you see here.
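The exact spline used by plotcurves is not shown here. As an illustration only, base R's splinefun() with method = "hyman" fits a monotone cubic spline through decreasing points; the sketch smooths the five printed probabilities for testing observation 1.

```r
# Time points and survival probabilities for testing observation 1 (printed above)
tp <- c(0, 60.625, 101.25, 155.4375, 177)
sp <- c(1, 0.9197721, 0.8427539, 0.7720628, 0.7080682)

# Monotone cubic spline (Hyman filtering) through the points
curve_fn <- splinefun(tp, sp, method = "hyman")

# Evaluate on a fine grid, e.g. for plot(grid, smooth, type = "l")
grid <- seq(0, 177, length.out = 200)
smooth <- curve_fn(grid)
```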
You may also want to customize the plot. plotcurves simply returns a ggplot2 object, so it can be modified like any other ggplot2 graphic. For example, plotcurves(survCurves, 1:10) + ggplot2::xlab("Days") would change the x-axis label to “Days” instead of “Time”.
In addition to the entire survival curve, one may also be interested in the mean or median survival time. Both are also available from the predict function.
#Mean
meanSurv <- predict(fullMod, testing, type = "mean_time")
head(meanSurv)
#> [1] 319.0312 407.2113 262.0245 317.3220 327.1613 336.3179
#Median
medianSurv <- predict(fullMod, testing, type = "median_time")
head(medianSurv)
#> [1] 276.5360 378.9466 197.3210 274.5491 279.7247 296.5696
Here the mean survival time corresponds to the area under the survival curve of each observation. One subtlety is that many survival curves never reach zero probability, which leaves this area undefined. When this occurs, a line is drawn from the point (time = 0, survival probability = 1) through the last point of the curve and extended until it reaches a survival probability of 0. For example, the figure below shows this linear extension drawn on the curves used to calculate the mean survival time.
This is also performed when calculating the median survival time if the last survival probability is above 0.5.
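The sketch below illustrates the linear-extension idea in base R. mean_surv_time() and median_surv_time() are hypothetical helpers, not MTLR functions. They use the trapezoidal rule and linear interpolation between time points, whereas mtlr works from a spline fit, so results on real curves will differ somewhat. They assume the curve starts at (0, 1), as in the output of predict(..., type = "survivalcurve").

```r
# Time at which the line through (0, 1) and the last point (t_last, s_last)
# reaches survival probability p
extend_to <- function(t_last, s_last, p) {
  (1 - p) * t_last / (1 - s_last)
}

# Hypothetical: mean survival time as the area under a piecewise-linear curve,
# adding the triangle under the linear extension when the curve never reaches 0
mean_surv_time <- function(times, surv) {
  n <- length(times)
  area <- sum(diff(times) * (head(surv, -1) + tail(surv, -1)) / 2)
  if (surv[n] > 0) {
    t_zero <- extend_to(times[n], surv[n], 0)
    area <- area + surv[n] * (t_zero - times[n]) / 2
  }
  area
}

# Hypothetical: median survival time by linear interpolation, using the
# linear extension when the last probability is still above 0.5
median_surv_time <- function(times, surv) {
  n <- length(times)
  if (surv[n] > 0.5) return(extend_to(times[n], surv[n], 0.5))
  i <- which(surv <= 0.5)[1]   # first point at or below 0.5
  times[i - 1] + (surv[i - 1] - 0.5) / (surv[i - 1] - surv[i]) * (times[i] - times[i - 1])
}

mean_surv_time(c(0, 10, 20), c(1, 0.8, 0.6))    # 25
median_surv_time(c(0, 10, 20), c(1, 0.6, 0.2))  # 12.5
```

In the first call the curve ends at 0.6, so the extension reaches 0 at time 20 / 0.4 = 50 and adds a triangle of area 9 to the trapezoidal area of 16.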
The last supported prediction type gives each observation's survival probability at its own event time. To use this prediction, the event time (whether censored or uncensored) must be included in the features passed to the predict function.
survivalProbs <- predict(fullMod, testing, type = "prob_event")
head(survivalProbs)
#> [1] 0.58742054 0.15956481 0.06579507 0.56822951 0.19219566 0.80750612
#To see what times these probabilities correspond to:
head(testing$time)
#> [1] 210 883 1022 218 567 144
You will notice that some of these survival probabilities are at or close to 0 (usually those with very large event times). As before, if an event time falls beyond the last time point of a survival curve, the linear extension is used to map it onto the curve.
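A minimal sketch of that lookup follows. prob_at_time() is a hypothetical helper; it uses linear interpolation between time points (an assumption, since the package works from its spline fit) and, beyond the last time point, the same linear extension from (0, 1), floored at 0.

```r
# Hypothetical: survival probability at time t on a curve starting at (0, 1)
prob_at_time <- function(times, surv, t) {
  n <- length(times)
  if (t <= times[n]) {
    return(approx(times, surv, xout = t)$y)   # interpolate within the curve
  }
  max(0, 1 - (1 - surv[n]) * t / times[n])    # linear extension, floored at 0
}

tp <- c(0, 10, 20)
sp <- c(1, 0.8, 0.6)
prob_at_time(tp, sp, 15)   # 0.7, inside the curve
prob_at_time(tp, sp, 40)   # 0.2, on the extension
prob_at_time(tp, sp, 60)   # 0, the extension has already reached 0
```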
mtlr_cv
Previously we just used the default settings of mtlr. However, a number of things can be adjusted, including the number of time points, the exact time points used, the initialization of the feature weights, and the regularization parameter (C1), which corresponds to the C1 given in the NIPS paper. The mtlr_cv function helps to select a value of C1. Given a vector of candidate values for C1, mtlr_cv performs internal cross-validation to select the best C1 according to some criterion. Currently the only criterion available is the log-likelihood loss (see the “Details” section of help(mtlr_cv)). For example, here we run this command with 5 values of C1 (the default is c(0.001, 0.01, 0.1, 1, 10, 100, 1000)).
mtlr_cv(formula,training, C1_vec = c(0.01,0.1,1,10,100))
#> $best_C1
#> [1] 1
#>
#> $avg_loss
#> 0.01 0.1 1 10 100
#> 2.334196 2.173089 2.127982 2.152218 2.183193
The output gives us the best value of C1 and the average loss for each value tested. Once we have the best value of C1, we can pass it to the mtlr function to train our final model.
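Picking the best value amounts to taking the C1 with the smallest average loss, which mtlr_cv already reports as $best_C1. The sketch below repeats that choice by hand with the losses printed above; the commented refit assumes mtlr's regularization argument is named C1, so check help(mtlr) before relying on it.

```r
# Average losses copied from the mtlr_cv output above
avg_loss <- c(`0.01` = 2.334196, `0.1` = 2.173089, `1` = 2.127982,
              `10` = 2.152218, `100` = 2.183193)

# The best C1 is the candidate with the smallest average loss
best_C1 <- as.numeric(names(which.min(avg_loss)))
best_C1   # 1

# Refit on the full training set (argument name C1 assumed):
# finalMod <- mtlr(formula = formula, data = training, C1 = best_C1)
```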
create_folds
As we mentioned, mtlr_cv uses internal k-fold cross-validation to evaluate the loss. We also export the function used to create these cross-validation folds (create_folds), since it creates folds in an unusual way.
These folds can be deterministic, semi-deterministic, or totally random. The deterministic folds arise by stratifying folds by censor status and attempting to create equal ranges in the event times within each fold. This is done by first stratifying the survival dataset into a censored and uncensored portion and then sorting each portion by the event time. These portions are then numbered off into k different folds (see figure below). This option corresponds to “fullstrat”.
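The deterministic “fullstrat” scheme can be sketched in a few lines of base R. fullstrat_folds() is a hypothetical helper, not the exported create_folds(), and it assumes an indicator where 1 marks an uncensored observation and 0 a censored one. Numbering off the sorted observations in turn sends neighboring event times to different folds, so each fold spans roughly the same range of times.

```r
# Hypothetical sketch of "fullstrat"-style fold assignment (not create_folds itself)
fullstrat_folds <- function(time, delta, k = 5) {
  folds <- integer(length(time))
  for (d in unique(delta)) {
    idx <- which(delta == d)                         # one censoring stratum
    idx <- idx[order(time[idx])]                     # sort the stratum by event time
    folds[idx] <- rep_len(seq_len(k), length(idx))   # number off 1, 2, ..., k, 1, 2, ...
  }
  folds
}

time  <- c(5, 3, 8, 1, 9, 2, 7, 4, 6, 10)
delta <- c(1, 0, 1, 0, 1, 0, 1, 0, 1, 0)
folds <- fullstrat_folds(time, delta, k = 2)
table(folds, delta)   # each fold gets a similar mix of censored and uncensored rows
```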