In this lab we will go through model building, validation, and interpretation of tree models. The focus will be on the rpart package. Recall that when the response variable \(Y\) is continuous, we fit a regression tree; when the response variable \(Y\) is categorical, we fit a classification tree. We build tree models for our familiar datasets, the Boston Housing data and the Credit Card Default data, for the regression and classification tree respectively.
Load the data and randomly split it into training and testing samples.
library(tidyverse)
## -- Attaching packages ------------------------------------------------------------------------------------------- tidyverse 1.2.1 --
## v ggplot2 3.1.1 v purrr 0.3.2
## v tibble 2.1.1 v dplyr 0.8.0.1
## v tidyr 0.8.3 v stringr 1.4.0
## v readr 1.3.1 v forcats 0.4.0
## -- Conflicts ---------------------------------------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(MASS)
##
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
##
## select
data(Boston)
index <- sample(nrow(Boston),nrow(Boston)*0.90)
boston.train <- Boston[index,]
boston.test <- Boston[-index,]
We will use the ‘rpart’ library for model building and ‘rpart.plot’ for plotting.
install.packages('rpart')
install.packages('rpart.plot')
library(rpart)
library(rpart.plot)
The simple form of the rpart function is similar to lm and glm. It takes a formula argument in which you specify the response and predictor variables, and a data argument in which you specify the data frame.
boston.rpart <- rpart(formula = medv ~ ., data = boston.train)
boston.rpart
## n= 455
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 455 38536.4300 22.62022
## 2) rm< 6.797 370 13542.1300 19.59865
## 4) lstat>=14.4 155 2908.5710 14.93097
## 8) crim>=6.99237 63 761.1032 11.77937 *
## 9) crim< 6.99237 92 1093.2090 17.08913 *
## 5) lstat< 14.4 215 4821.9370 22.96372
## 10) lstat>=4.52 207 3001.1550 22.50483
## 20) lstat>=9.715 100 635.1251 20.74300 *
## 21) lstat< 9.715 107 1765.5270 24.15140 *
## 11) lstat< 4.52 8 649.2987 34.83750 *
## 3) rm>=6.797 85 6911.7480 35.77294
## 6) rm< 7.437 57 2291.8200 31.24386
## 12) lstat>=11.315 8 420.9288 20.91250 *
## 13) lstat< 11.315 49 877.5841 32.93061 *
## 7) rm>=7.437 28 1070.5190 44.99286 *
prp(boston.rpart, digits = 4, extra = 1)
Make sure you know how to interpret this tree model!
Exercise: What is the predicted median housing price (in thousands of dollars) given the following information? (A predict() sketch follows the table.)
crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0.05 | 0 | 3.41 | 0 | 0.49 | 7.08 | 63.1 | 3.41 | 2 | 270 | 17.8 | 396.06 | 5.7 |
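One way to check your answer is to build a one-row data frame with these values and feed it to predict(). This is just a sketch, not part of the original lab, and the exact prediction depends on your random training split (no seed was set).
# hypothetical check: a one-row data frame with the values from the table above
new.obs <- data.frame(crim = 0.05, zn = 0, indus = 3.41, chas = 0, nox = 0.49,
                      rm = 7.08, age = 63.1, dis = 3.41, rad = 2, tax = 270,
                      ptratio = 17.8, black = 396.06, lstat = 5.7)
predict(boston.rpart, newdata = new.obs)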
In-sample and out-of-sample prediction for a regression tree is similar to that for lm() and glm() models.
In-sample prediction
boston.train.pred.tree = predict(boston.rpart)
Out-of-sample prediction
boston.test.pred.tree = predict(boston.rpart, newdata=boston.test)
Exercise: Calculate the mean squared error (MSE) and mean squared prediction error (MSPE) for this tree model
MSE.tree<-
MSPE.tree<-
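A sketch of one possible solution (exact values depend on your random split):
# average squared error on the training data (MSE) and testing data (MSPE)
MSE.tree  <- mean((boston.train$medv - boston.train.pred.tree)^2)
MSPE.tree <- mean((boston.test$medv - boston.test.pred.tree)^2)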
Calculate the mean squared error (MSE) and mean squared prediction error (MSPE) for a linear regression model using all variables, then compare the results. What is your conclusion? Further, try to compare the regression tree with the best linear regression model obtained from a variable selection procedure.
boston.lm<-
boston.train.pred.lm<-
boston.test.pred.lm<-
MSE.lm<-
MSPE.lm<-
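A sketch of one possible approach, using the full linear model (your variable-selection model may differ):
# fit the full linear model, then compute in-sample MSE and out-of-sample MSPE
boston.lm <- lm(medv ~ ., data = boston.train)
boston.train.pred.lm <- predict(boston.lm)
boston.test.pred.lm <- predict(boston.lm, newdata = boston.test)
MSE.lm  <- mean((boston.train$medv - boston.train.pred.lm)^2)
MSPE.lm <- mean((boston.test$medv - boston.test.pred.lm)^2)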
Load the data, rename the response variable (because its original name is too long), convert the categorical variables to factors, and randomly split into training and testing samples.
credit.data <- read.csv("http://homepages.uc.edu/~lis6/DataMining/Data/credit_default.csv", header=T)
# rename
library(dplyr)
credit.data<- rename(credit.data, default=default.payment.next.month)
# convert categorical data to factor
credit.data$SEX<- as.factor(credit.data$SEX)
credit.data$EDUCATION<- as.factor(credit.data$EDUCATION)
credit.data$MARRIAGE<- as.factor(credit.data$MARRIAGE)
# random splitting
index <- sample(nrow(credit.data),nrow(credit.data)*0.80)
credit.train = credit.data[index,]
credit.test = credit.data[-index,]
You need to tell R that you want a classification tree by specifying method="class", since the default for a numeric response is to fit a regression tree.
credit.rpart0 <- rpart(formula = default ~ ., data = credit.train, method = "class")
However, this tree minimizes the symmetric cost, that is, the misclassification rate. We can take a look at the confusion matrix.
pred0<- predict(credit.rpart0, type="class")
table(credit.train$default, pred0, dnn = c("True", "Pred"))
## Pred
## True 0 1
## 0 7209 308
## 1 1369 714
Note that in the predict() function, we need type="class" in order to get binary class predictions.
Look at the confusion matrix: is it what we expected? Think about why it looks this way.
Therefore, for most applications (especially with very unbalanced data), we often have asymmetric costs. Recall the example in logistic regression: in the credit scoring case, false negatives (predicting 0 when the truth is 1, i.e., giving out loans that end up in default) cost more than false positives (predicting 1 when the truth is 0, i.e., rejecting loans that would have been repaid).
Here we assume that a false negative costs 5 times as much as a false positive. In real life the cost structure should be carefully researched.
credit.rpart <- rpart(formula = default ~ . , data = credit.train, method = "class", parms = list(loss=matrix(c(0,5,1,0), nrow = 2)))
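A note on the loss matrix: in rpart, the rows of the loss matrix index the true class and the columns the predicted class (with the classes ordered 0, 1 here; see the parms argument in ?rpart), so the matrix above charges 5 for a false negative (true 1 predicted as 0) and 1 for a false positive. A quick sketch to display it:
matrix(c(0, 5, 1, 0), nrow = 2)
##      [,1] [,2]
## [1,]    0    1
## [2,]    5    0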
For more advanced controls, you should carefully read the help document for the rpart function.
credit.rpart
## n= 9600
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 9600 7517 1 (0.78302083 0.21697917)
## 2) PAY_0< 0.5 7468 4890 0 (0.86904124 0.13095876)
## 4) PAY_AMT2>=1640.5 4724 2415 0 (0.89775614 0.10224386)
## 8) PAY_4< 1 4484 2105 0 (0.90611062 0.09388938) *
## 9) PAY_4>=1 240 178 1 (0.74166667 0.25833333) *
## 5) PAY_AMT2< 1640.5 2744 2249 1 (0.81960641 0.18039359)
## 10) LIMIT_BAL>=65000 1745 1330 0 (0.84756447 0.15243553)
## 20) BILL_AMT1>=8583 617 300 0 (0.90275527 0.09724473) *
## 21) BILL_AMT1< 8583 1128 922 1 (0.81737589 0.18262411)
## 42) PAY_AMT4>=962 349 190 0 (0.89111748 0.10888252) *
## 43) PAY_AMT4< 962 779 611 1 (0.78433890 0.21566110) *
## 11) LIMIT_BAL< 65000 999 770 1 (0.77077077 0.22922923) *
## 3) PAY_0>=0.5 2132 1027 1 (0.48170732 0.51829268) *
prp(credit.rpart, extra = 1)
For a binary classification problem, there are two types of predictions. One is the predicted class of the response (0 or 1), and the other is the predicted probability that the response equals 1. We use the additional argument type="class" or type="prob" to get these:
In-sample prediction
credit.train.pred.tree1<- predict(credit.rpart, credit.train, type="class")
table(credit.train$default, credit.train.pred.tree1, dnn=c("Truth","Predicted"))
## Predicted
## Truth 0 1
## 0 4931 2586
## 1 519 1564
Exercise: Out-of-sample prediction
#Predicted Class
credit.test.pred.tree1<-
table()
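A sketch of one possible solution:
# out-of-sample class prediction and its confusion matrix
credit.test.pred.tree1 <- predict(credit.rpart, credit.test, type = "class")
table(credit.test$default, credit.test.pred.tree1, dnn = c("Truth", "Predicted"))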
Exercise: Try type="prob"
in prediction, what can you say about these predicted probabilities?
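A quick sketch for inspecting them (hint: observations falling in the same terminal node share the same predicted probability, so only a handful of distinct values appear):
head(predict(credit.rpart, credit.test, type = "prob"))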
We can get the expected loss for this tree model by defining a cost function that has the correct weights:
cost <- function(r, pi){
weight1 = 5
weight0 = 1
c1 = (r==1)&(pi==0) #logical vector - true if actual 1 but predict 0
c0 = (r==0)&(pi==1) #logical vector - true if actual 0 but predict 1
return(mean(weight1*c1+weight0*c0))
}
Calculate the cost for the training sample using the above cost function:
cost(credit.train$default,credit.train.pred.tree1)
## [1] 0.5396875
Exercise: Calculate the cost for the testing sample.
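A sketch of one possible solution, where credit.test.pred.tree1 is the out-of-sample class prediction from the previous exercise:
credit.test.pred.tree1 <- predict(credit.rpart, credit.test, type = "class")
cost(credit.test$default, credit.test.pred.tree1)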
We can compare this model’s out-of-sample performance with that of the logistic regression model with all variables in it. Recall that when we searched for the optimal cut-off using the same cost function, we got an optimal cut-off of about 0.21.
#Fit logistic regression model
credit.glm<- glm(default~., data = credit.train, family=binomial)
#Get binary prediction
credit.test.pred.glm<- as.numeric(predict(credit.glm, credit.test, type="response")>0.21)
#Calculate cost using test set
cost(credit.test$default,credit.test.pred.glm)
## [1] 0.6854167
#Confusion matrix
table(credit.test$default, credit.test.pred.glm, dnn=c("Truth","Predicted"))
## Predicted
## Truth 0 1
## 0 1256 595
## 1 210 339
Exercise: Compare the in-sample performance of the two models. Which model do you think is better?
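A sketch of one way to compare the in-sample cost, using the same 0.21 cut-off for the logistic model (this cut-off is carried over from the earlier lab and is an assumption here):
# in-sample asymmetric cost for the tree and for logistic regression
credit.train.pred.glm <- as.numeric(predict(credit.glm, credit.train, type = "response") > 0.21)
cost(credit.train$default, credit.train.pred.tree1)
cost(credit.train$default, credit.train.pred.glm)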
To get the ROC curve, we need the predicted probability of Y being 1 from the fitted tree.
credit.test.prob.rpart<- predict(credit.rpart,credit.test, type="prob")
credit.test.prob.rpart has two columns: the first is P(Y = 0) and the second is P(Y = 1). We need the second column.
To get the ROC curve we use:
library(ROCR)
pred = prediction(credit.test.prob.rpart[,2], credit.test$default)
perf = performance(pred, "tpr", "fpr")
plot(perf)
The area under the curve (AUC) is given by (do not worry about the syntax here):
slot(performance(pred, "auc"), "y.values")[[1]]
## [1] 0.7290762
Exercise: Draw the ROC curve for the training sample.
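A sketch that follows the same steps as above, applied to the training sample:
credit.train.prob.rpart <- predict(credit.rpart, credit.train, type = "prob")
pred.train <- prediction(credit.train.prob.rpart[, 2], credit.train$default)
perf.train <- performance(pred.train, "tpr", "fpr")
plot(perf.train)
slot(performance(pred.train, "auc"), "y.values")[[1]]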
In rpart(), the cp (complexity parameter) argument is one of the parameters used to control the complexity of the tree. The help document for rpart says: "Any split that does not decrease the overall lack of fit by a factor of cp is not attempted". For a regression tree, the overall R-squared must increase by cp at each step. Basically, the smaller the cp value, the larger (more complex) a tree rpart will attempt to fit. The default value of cp is 0.01.
What happens when you fit a large tree? The following tree has 30 splits (see the cp table below).
boston.largetree <- rpart(formula = medv ~ ., data = boston.train, cp = 0.001)
Try plotting it yourself to see its structure.
prp(boston.largetree)
The plotcp() function shows the relationship between the 10-fold cross-validation error on the training set and the size of the tree.
plotcp(boston.largetree)
You can observe from the above graph that the cross-validation error (xerror) does not always go down as the tree becomes more complex. The analogy is that when you add more variables to a regression model, its ability to predict future observations does not necessarily improve. A good choice of cp for pruning is often the leftmost value for which the mean lies below the horizontal line. In the Boston Housing example, you may conclude that a tree model with more than 10 splits is not helpful.
To look at the error versus tree size more carefully, you can examine the following table:
printcp(boston.largetree)
##
## Regression tree:
## rpart(formula = medv ~ ., data = boston.train, cp = 0.001)
##
## Variables actually used in tree construction:
## [1] age black crim dis indus lstat nox ptratio
## [9] rm tax
##
## Root node error: 38536/455 = 84.695
##
## n= 455
##
## CP nsplit rel error xerror xstd
## 1 0.4692327 0 1.00000 1.00519 0.086612
## 2 0.1508085 1 0.53077 0.59011 0.058555
## 3 0.0921053 2 0.37996 0.42350 0.049502
## 4 0.0303994 3 0.28785 0.33628 0.044758
## 5 0.0273575 4 0.25745 0.31817 0.043050
## 6 0.0257758 5 0.23010 0.31559 0.043090
## 7 0.0155827 6 0.20432 0.30128 0.045387
## 8 0.0090332 7 0.18874 0.26985 0.042530
## 9 0.0062656 8 0.17970 0.26077 0.040694
## 10 0.0060826 10 0.16717 0.25459 0.038892
## 11 0.0051515 11 0.16109 0.24923 0.038662
## 12 0.0049469 13 0.15079 0.24594 0.038563
## 13 0.0043488 14 0.14584 0.24347 0.038575
## 14 0.0041788 15 0.14149 0.24246 0.038560
## 15 0.0036638 16 0.13731 0.23620 0.037596
## 16 0.0021362 17 0.13365 0.23170 0.037720
## 17 0.0020514 18 0.13151 0.22986 0.037657
## 18 0.0016736 19 0.12946 0.23243 0.037473
## 19 0.0016079 20 0.12779 0.23244 0.037477
## 20 0.0015478 22 0.12457 0.23335 0.037504
## 21 0.0015370 23 0.12302 0.23360 0.037505
## 22 0.0014781 25 0.11995 0.23342 0.037506
## 23 0.0013558 26 0.11847 0.23244 0.037510
## 24 0.0012113 27 0.11712 0.23378 0.037531
## 25 0.0011550 28 0.11591 0.23493 0.037557
## 26 0.0010831 29 0.11475 0.23416 0.037213
## 27 0.0010000 30 0.11367 0.23514 0.037246
The root node error is the error you get without making any splits at all; in the regression case, it is the mean squared error (MSE) when you use the average of medv as the prediction. Note that it is the same as:
sum((boston.train$medv - mean(boston.train$medv))^2)/nrow(boston.train)
## [1] 84.69546
The first two columns, CP and nsplit, tell you how large the tree is. rel error \(\times\) root node error gives you the in-sample error. For example, for the last row, \(0.11367 \times 84.695 \approx 9.63\), which is the same as the in-sample MSE you get from predict():
mean((predict(boston.largetree) - boston.train$medv)^2)
## [1] 9.627137
xerror gives you the cross-validation error (the default is 10-fold). You can see that the rel error (in-sample error) always decreases as the model becomes more complex, while the cross-validation error (a measure of performance on future observations) does not. That is why we prune the tree: to avoid overfitting the training data.
By default, rpart() uses some control parameters to avoid fitting a large tree; the main reason is to save computation time. For example, by default rpart sets cp = 0.01 and requires at least 20 observations in a node before attempting a split (minsplit = 20). Use ?rpart.control to view these parameters. Sometimes we wish to change these parameters to see how more complex trees perform, as we did above. If we end up with a larger tree than necessary, we can use the prune() function and specify a new cp:
tree.prune<- prune(boston.largetree, cp = 0.008)
tree.prune$cptable
## CP nsplit rel error xerror xstd
## 1 0.469232750 0 1.0000000 1.0051894 0.08661198
## 2 0.150808478 1 0.5307673 0.5901101 0.05855456
## 3 0.092105275 2 0.3799588 0.4234953 0.04950210
## 4 0.030399365 3 0.2878535 0.3362820 0.04475830
## 5 0.027357463 4 0.2574541 0.3181666 0.04305025
## 6 0.025775803 5 0.2300967 0.3155939 0.04308958
## 7 0.015582728 6 0.2043209 0.3012764 0.04538707
## 8 0.009033196 7 0.1887381 0.2698540 0.04252968
## 9 0.008000000 8 0.1797049 0.2607654 0.04069443
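Instead of reading the cp value off the plot, you can also pick it from the cp table programmatically. Below is a sketch that uses the cp value with the smallest cross-validated error; note that the plot-based rule above (leftmost value below the horizontal line) may favor an even smaller tree. The object names are illustrative.
# pick the cp with minimum xerror and prune with it
cp.table <- boston.largetree$cptable
best.cp <- cp.table[which.min(cp.table[, "xerror"]), "CP"]
tree.prune.auto <- prune(boston.largetree, cp = best.cp)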
Exercise: Prune a classification tree. Start with cp = 0.001, find a reasonable cp value, then obtain the pruned tree.
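A sketch of one possible approach (the cp value 0.002 is only an illustrative choice; pick yours from plotcp()/printcp()):
# grow a large classification tree with the same asymmetric loss, then prune
credit.largetree <- rpart(default ~ ., data = credit.train, method = "class",
                          parms = list(loss = matrix(c(0, 5, 1, 0), nrow = 2)), cp = 0.001)
plotcp(credit.largetree)
credit.prune <- prune(credit.largetree, cp = 0.002)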
digit<- data.matrix(read_csv("https://www.dropbox.com/s/ulujvi2a4ykfzju/train.csv?dl=1"))
dim(digit)
## [1] 42000 785
## visualize the data
plotTrain <- function(data, index){
op <- par(no.readonly=TRUE)
x <- ceiling(sqrt(length(index)))
par(mfrow=c(x, x), mar=c(.1, .1, .1, .1))
for (i in index){ #reverse and transpose each matrix to rotate images
m <- matrix(data[i,-1], nrow=28, byrow=TRUE)
m <- apply(m, 2, rev)
image(t(m), col=grey.colors(255), axes=FALSE)
text(0.05, 0.2, col="white", cex=1.2, data[i, 1])
}
par(op) #reset the original graphics parameters
}
plotTrain(data=digit, index=1:100)
index<- sample(1:nrow(digit), 0.6*nrow(digit))
train<- digit[index,]
test<- digit[-index,]
Currently, each cell uses 0-255 to represent the grey color scale. We rescale it to 0-1.
## standardize X
train.x <- train[,-1] #remove 'label' column
test.x<- test[,-1]
train.y <- train[,1] #label column
test.y<- test[,1]
train.x <- train.x/255
test.x <- test.x/255
Here we use a classification tree to train a classifier, and then compare it with the multinomial logit model (from the last lab).
Due to the large size, we only use the first 3,000 observations as the training sample.
fit.tree <- rpart(y ~., method = "class", data = data.frame(y=train.y[1:3000], x=train.x[1:3000,]), cp=0.00001)
plotcp(fit.tree)
printcp(fit.tree)
##
## Classification tree:
## rpart(formula = y ~ ., data = data.frame(y = train.y[1:3000],
## x = train.x[1:3000, ]), method = "class", cp = 1e-05)
##
## Variables actually used in tree construction:
## [1] x.pixel122 x.pixel126 x.pixel128 x.pixel153 x.pixel154 x.pixel155
## [7] x.pixel158 x.pixel177 x.pixel210 x.pixel211 x.pixel213 x.pixel235
## [13] x.pixel240 x.pixel241 x.pixel263 x.pixel265 x.pixel267 x.pixel269
## [19] x.pixel270 x.pixel271 x.pixel272 x.pixel290 x.pixel297 x.pixel301
## [25] x.pixel315 x.pixel317 x.pixel318 x.pixel324 x.pixel326 x.pixel344
## [31] x.pixel345 x.pixel346 x.pixel347 x.pixel350 x.pixel351 x.pixel352
## [37] x.pixel353 x.pixel354 x.pixel355 x.pixel358 x.pixel372 x.pixel374
## [43] x.pixel375 x.pixel378 x.pixel380 x.pixel382 x.pixel386 x.pixel404
## [49] x.pixel405 x.pixel426 x.pixel430 x.pixel437 x.pixel439 x.pixel455
## [55] x.pixel456 x.pixel457 x.pixel458 x.pixel460 x.pixel462 x.pixel464
## [61] x.pixel483 x.pixel486 x.pixel487 x.pixel488 x.pixel489 x.pixel490
## [67] x.pixel513 x.pixel514 x.pixel515 x.pixel517 x.pixel519 x.pixel544
## [73] x.pixel550 x.pixel569 x.pixel573 x.pixel575 x.pixel599 x.pixel606
## [79] x.pixel626 x.pixel651 x.pixel654 x.pixel655 x.pixel657 x.pixel658
## [85] x.pixel659 x.pixel660 x.pixel679
##
## Root node error: 2663/3000 = 0.88767
##
## n= 3000
##
## CP nsplit rel error xerror xstd
## 1 0.08974840 0 1.00000 1.00488 0.0063839
## 2 0.07998498 1 0.91025 0.92790 0.0078385
## 3 0.07810740 2 0.83027 0.84566 0.0088982
## 4 0.07209914 3 0.75216 0.75629 0.0096613
## 5 0.07172362 4 0.68006 0.72512 0.0098503
## 6 0.05520090 5 0.60834 0.62223 0.0102275
## 7 0.03567405 6 0.55314 0.55389 0.0102825
## 8 0.03229440 7 0.51746 0.52798 0.0102637
## 9 0.02065340 8 0.48517 0.50169 0.0102223
## 10 0.01764927 9 0.46451 0.48104 0.0101737
## 11 0.01426962 10 0.44686 0.46864 0.0101378
## 12 0.01145325 11 0.43259 0.45137 0.0100789
## 13 0.00901239 13 0.40969 0.43485 0.0100131
## 14 0.00826136 14 0.40068 0.42471 0.0099679
## 15 0.00788584 16 0.38415 0.42133 0.0099520
## 16 0.00675929 18 0.36838 0.40931 0.0098923
## 17 0.00638378 19 0.36162 0.40481 0.0098686
## 18 0.00600826 20 0.35524 0.39992 0.0098420
## 19 0.00544499 21 0.34923 0.39054 0.0097884
## 20 0.00450620 24 0.33271 0.38002 0.0097245
## 21 0.00413068 28 0.31468 0.36801 0.0096462
## 22 0.00375516 36 0.28014 0.35261 0.0095376
## 23 0.00337965 39 0.26887 0.34848 0.0095069
## 24 0.00300413 43 0.25535 0.33383 0.0093921
## 25 0.00262861 45 0.24934 0.32557 0.0093234
## 26 0.00225310 50 0.23620 0.31769 0.0092550
## 27 0.00187758 56 0.22268 0.30755 0.0091630
## 28 0.00150207 63 0.20954 0.30567 0.0091455
## 29 0.00131431 69 0.20053 0.30379 0.0091277
## 30 0.00112655 71 0.19790 0.30492 0.0091384
## 31 0.00075103 79 0.18888 0.30417 0.0091313
## 32 0.00037552 86 0.18363 0.30792 0.0091665
## 33 0.00025034 93 0.18100 0.31055 0.0091907
## 34 0.00018776 96 0.18025 0.31093 0.0091942
## 35 0.00001000 98 0.17987 0.31055 0.0091907
# make prediction
pred.y.tree<- predict(fit.tree, data.frame(y=test.y, x=test.x), type = "class")
# accuracy rate
mean(test.y==pred.y.tree)
## [1] 0.7339286
We got pretty good accuracy. As we learn more advanced ML algorithms, you will see that the accuracy rate can approach 99%.
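Beyond overall accuracy, you may also want to see which digits get confused with each other. A quick sketch:
# confusion matrix across the 10 digit classes
table(True = test.y, Pred = pred.y.tree)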
plotResults <- function(testdata, index, preds){
op <- par(no.readonly=TRUE)
x <- ceiling(sqrt(length(index)))
par(mfrow=c(x,x), mar=c(.1,.1,.1,.1))
for (i in index){
m <- matrix(testdata[i,], nrow=28, byrow=TRUE)
m <- apply(m, 2, rev)
image(t(m), col=grey.colors(255), axes=FALSE)
text(0.05,0.1,col="green", cex=1.2, preds[i])
}
par(op)
}
Here are the first 100 images in the test set and their predicted values:
plotResults(testdata=test.x, index=1:100, preds=pred.y.tree)
We see it did a very good job.
- Use rpart() to fit regression and classification trees.
- Know how to interpret a tree.
- Use predict() for prediction, and know how to assess model performance.
- Know how to use the cp plot/table to prune a large tree.