See the Module 5 recommended exercise solutions for the bootstrap result.
See solutions to exercises Q1 and Q4 here: https://rstudio-pubs-static.s3.amazonaws.com/65564_925dfde884e14ef9b5735eddd16c263e.html
See solutions to 2c-f: https://www.math.ntnu.no/emner/TMA4268/2019v/8Trees/TMA4268M8RecEx2ctof.pdf
We have 4601 observations and 58 variables; 57 of them will be used as covariates.
library(kernlab)
?spam
We use approximately \(2/3\) of the observations as the train set and \(1/3\) as the test set.
library(tree)
set.seed(1)
data(spam)
N = dim(spam)[1]
train = sample(1:N, 3000)
test = (1:N)[-train]
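As a quick sanity check (not part of the original solution), we can confirm that the split is roughly 2/3 versus 1/3:
# 3000 of the 4601 observations are used for training, the remaining 1601 for testing
length(train)/N  # approximately 2/3 (0.652)
length(test)/N  # approximately 1/3 (0.348)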
We fit a classification tree to the training data.
tree.spam = tree(type ~ ., spam, subset = train)
plot(tree.spam)
text(tree.spam, pretty = 1)
For this training set we have 13 terminal nodes.
summary(tree.spam)
##
## Classification tree:
## tree(formula = type ~ ., data = spam, subset = train)
## Variables actually used in tree construction:
## [1] "charDollar" "charExclamation" "remove" "hp"
## [5] "capitalLong" "our" "capitalAve" "free"
## [9] "edu"
## Number of terminal nodes: 13
## Residual mean deviance: 0.4983 = 1488 / 2987
## Misclassification error rate: 0.08033 = 241 / 3000
We predict the response for the test data.
yhat = predict(tree.spam, spam[test, ], type = "class")
response.test = spam$type[test]
We then make a confusion table:
misclass = table(yhat, response.test)
print(misclass)
## response.test
## yhat nonspam spam
## nonspam 904 75
## spam 72 550
The misclassification rate is given by:
1 - sum(diag(misclass))/sum(misclass)
## [1] 0.09181761
We use cv.tree() to find the optimal tree size.
set.seed(1)
cv.spam = cv.tree(tree.spam, FUN = prune.misclass)
plot(cv.spam$size, cv.spam$dev, type = "b")
According to the plot, the optimal number of terminal nodes is 7 (or larger). We choose 7, as this gives the simplest tree, and prune the tree accordingly.
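Instead of reading the optimal size off the plot, it can also be extracted programmatically. A minimal sketch (not part of the original solution):
# cv.spam$dev contains the CV misclassification counts; take the smallest
# tree size among those attaining the minimum
min(cv.spam$size[cv.spam$dev == min(cv.spam$dev)])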
prune.spam = prune.misclass(tree.spam, best = 7)
plot(prune.spam)
text(prune.spam, pretty = 1)
We predict the response for the test data:
yhat.prune = predict(prune.spam, spam[test, ], type = "class")
misclass.prune = table(yhat.prune, response.test)
print(misclass.prune)
## response.test
## yhat.prune nonspam spam
## nonspam 892 94
## spam 84 531
The misclassification rate is
1 - sum(diag(misclass.prune))/sum(misclass.prune)
## [1] 0.1111805
We create an ensemble of trees by bagging. Bagging is a random forest in which all covariates are considered at each split, so we set mtry to the number of covariates, 57.
library(randomForest)
bag.spam = randomForest(type ~ ., data = spam, subset = train,
    mtry = dim(spam)[2] - 1, ntree = 500, importance = TRUE)
We predict the response for the test data as before:
yhat.bag = predict(bag.spam, newdata = spam[test, ])
misclass.bag = table(yhat.bag, response.test)
print(misclass.bag)
## response.test
## yhat.bag nonspam spam
## nonspam 932 49
## spam 44 576
The misclassification rate is
1 - sum(diag(misclass.bag))/sum(misclass.bag)
## [1] 0.05808869
We now use the random forest algorithm and consider only \(\sqrt{57}\approx 8\) of the predictors at each split. This is specified via the mtry argument.
set.seed(1)
rf.spam = randomForest(type ~ ., data = spam, subset = train,
    mtry = round(sqrt(dim(spam)[2] - 1)), ntree = 500, importance = TRUE)
We study the importance of each variable:
importance(rf.spam)
## nonspam spam MeanDecreaseAccuracy
## make 3.7106926 6.420583 6.875656
## address 6.9560490 6.007242 8.491042
## all 5.4977666 12.325039 12.232882
## num3d 5.4895675 4.538283 6.569200
## our 19.0702594 20.728272 24.117637
## over 12.1413796 9.795084 14.258429
## remove 35.0075329 29.197688 37.905961
## internet 17.4097498 10.702028 18.181981
## order 9.5480959 5.650150 10.472488
## mail 8.9561248 9.207912 11.760278
## receive 12.3069450 5.844350 13.394928
## will 7.7310375 16.248204 17.546926
## people 3.1117965 6.800867 6.922568
## report 5.7533164 6.769473 7.833354
## addresses 6.6664005 4.573558 7.733526
## free 31.6622555 27.112404 35.791600
## business 19.0766919 11.764762 20.243486
## email 10.3535985 7.826456 12.028503
## you 14.0188662 18.642762 21.698724
## credit 12.2629571 4.306761 12.725296
## your 19.9909130 25.601900 29.211840
## font 8.7669734 4.720813 9.363262
## num000 18.0528648 12.205178 19.363161
## money 15.2171668 14.735097 17.635124
## hp 24.5580004 34.288678 36.872253
## hpl 14.8464704 21.665639 24.122479
## george 14.0824656 23.516176 25.324531
## num650 11.8830753 14.350035 18.109587
## lab 0.6018690 8.988654 9.391188
## labs 6.0469190 10.178154 11.735195
## telnet 4.4734144 7.112928 7.479557
## num857 0.6437629 5.593084 5.692570
## data 2.2307640 7.009368 7.212093
## num415 2.9177565 5.627777 5.886071
## num85 5.2151469 10.287728 10.095952
## technology 8.6331089 7.479078 11.168371
## num1999 15.7322669 22.746573 24.837069
## parts -0.4804300 5.911246 4.371374
## pm 3.7499730 12.265635 11.606110
## direct 4.9604423 2.434232 5.890448
## cs 0.9163541 4.891726 5.135119
## meeting 7.8844445 17.857885 18.958137
## original 1.2158141 10.391645 10.552116
## project 3.9333618 8.562863 9.463547
## re 13.2754715 16.646305 20.791126
## edu 20.3180372 26.993562 30.045121
## table -0.3981665 1.966841 1.619565
## conference 3.2026967 6.642276 7.377644
## charSemicolon 7.9446983 10.917926 13.436729
## charRoundbracket 6.3424632 17.052765 16.814249
## charSquarebracket 7.1312429 5.718408 8.784499
## charExclamation 37.7930688 39.126286 49.908915
## charDollar 33.6329177 28.010149 38.166858
## charHash 7.5307501 5.446310 9.192637
## capitalAve 30.9637266 29.090111 41.028351
## capitalLong 25.7204430 21.918536 31.896693
## capitalTotal 21.5167798 20.363919 28.448450
## MeanDecreaseGini
## make 5.8945417
## address 6.0484787
## all 17.6430495
## num3d 1.4442140
## our 39.3109685
## over 10.5074779
## remove 96.1191026
## internet 22.2350905
## order 6.7413230
## mail 12.1389219
## receive 12.8310793
## will 15.9908299
## people 4.1434561
## report 2.7434463
## addresses 2.0555616
## free 98.7746981
## business 20.4326715
## email 12.2363666
## you 40.9670624
## credit 7.1021820
## your 89.6925790
## font 2.4547703
## num000 31.8994178
## money 49.0904913
## hp 64.5469175
## hpl 25.4851785
## george 20.1254989
## num650 9.6499747
## lab 2.4523419
## labs 6.4688614
## telnet 2.2025785
## num857 0.9449562
## data 2.7562055
## num415 0.9492738
## num85 3.7265835
## technology 4.7351351
## num1999 20.5135878
## parts 0.8035020
## pm 4.3135475
## direct 1.5860245
## cs 1.0452749
## meeting 6.2419602
## original 2.2716718
## project 2.2491866
## re 11.8499184
## edu 21.7211105
## table 0.4296191
## conference 1.5614413
## charSemicolon 7.1880288
## charRoundbracket 15.8344882
## charSquarebracket 3.3167999
## charExclamation 180.0080435
## charDollar 149.2319049
## charHash 4.4687720
## capitalAve 91.2654492
## capitalLong 77.3266819
## capitalTotal 60.1267373
MeanDecreaseAccuracy measures how much the prediction accuracy drops when the values of a covariate are randomly permuted, while MeanDecreaseGini measures the total decrease in node impurity over all splits on that covariate. If either measure is large, the corresponding covariate is important.
varImpPlot(rf.spam)
In this plot we see that charExclamation is the most important covariate, followed by remove and charDollar. This is as expected, since these variables are used in the top splits of the classification trees we have seen so far.
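To read off the top variables numerically instead of from the plot, the importance matrix can be sorted. A small sketch (not in the original solution):
imp = importance(rf.spam)
# The five most important variables by mean decrease in the Gini index
head(imp[order(imp[, "MeanDecreaseGini"], decreasing = TRUE), ], 5)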
We now predict the response for the test data and make a confusion table:
yhat.rf = predict(rf.spam, newdata = spam[test, ])
misclass.rf = table(yhat.rf, response.test)
print(misclass.rf)
## response.test
## yhat.rf nonspam spam
## nonspam 943 49
## spam 33 576
The misclassification rate is given by
1 - sum(diag(misclass.rf))/sum(misclass.rf)
## [1] 0.05121799
Finally, we create trees using the boosting algorithm. With distribution = "bernoulli", the gbm() function requires a numeric 0/1 response rather than a factor, so we recode "spam" and "nonspam" as 1 and 0:
library(gbm)
set.seed(1)
spamboost = spam
# Recode the response as numeric 0/1, since gbm() with distribution = "bernoulli"
# requires a numeric response rather than a factor
spamboost$type = ifelse(spam$type == "spam", 1, 0)
boost.spam = gbm(type ~ ., data = spamboost[train, ], distribution = "bernoulli",
n.trees = 5000, interaction.depth = 4, shrinkage = 0.001)
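As a supplementary check (not run in the original solution), summary() on a gbm fit reports the relative influence of each predictor, playing a role similar to the random forest importance measures:
# Relative influence of the predictors; plotit = FALSE suppresses the bar plot
head(summary(boost.spam, plotit = FALSE), 5)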
We predict the response for the test data:
yhat.boost = predict(boost.spam, newdata = spamboost[-train, ], n.trees = 5000,
    type = "response")
yhat.boost = ifelse(yhat.boost > 0.5, 1, 0) #Transform to 0 and 1 (nonspam and spam).
misclass.boost = table(yhat.boost, spamboost$type[test])
print(misclass.boost)
##
## yhat.boost 0 1
## 0 941 58
## 1 35 567
The misclassification rate is
1 - sum(diag(misclass.boost))/sum(misclass.boost)
## [1] 0.05808869
As expected, the ensemble methods give lower misclassification rates than the single tree: bagging and boosting both achieve about 0.058 and the random forest about 0.051, compared to 0.092 for the unpruned tree and 0.111 for the pruned tree.
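For reference, the test misclassification rates reported above can be collected in one place (values copied from the outputs in this document):
# Test misclassification rates from the fits above
c(tree = 0.0918, pruned = 0.1112, bagging = 0.0581, rf = 0.0512, boosting = 0.0581)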
Problem 1 - Classification with trees: https://www.math.ntnu.no/emner/TMA4268/2018v/CompEx/Compulsory3solutions.html
Classification of diabetes cases, part c), with Q20, Q21 and Q22.