Solutions

1. Theoretical

See the Module 5 recommended exercise solutions for the bootstrap result.

2. Understanding

See the solutions to exercises Q1 and Q4 here: (https://rstudio-pubs-static.s3.amazonaws.com/65564_925dfde884e14ef9b5735eddd16c263e.html)

3. Implementation

  1. The spam data set has 4601 observations and 58 variables. The response is type, and the remaining 57 variables are used as covariates.
library(kernlab)
?spam
  2. We use approximately \(2/3\) of the observations (3000 of 4601) as the training set and the remaining \(1/3\) (1601) as the test set.
library(tree)

set.seed(1)

data(spam)
N=dim(spam)[1]

train=sample(1:N,3000)
test=(1:N)[-train]
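A quick base-R check of the split sizes. This sketch only assumes N = 4601 as stated above; the spam data itself is not needed to verify the sizes of the index sets.

```r
# Reproduce an index split of the same shape; only N = 4601 is needed here
N <- 4601
set.seed(1)
train <- sample(1:N, 3000)
test <- (1:N)[-train]

split_sizes <- c(train = length(train), test = length(test))
print(split_sizes)                  # 3000 training and 1601 test observations
print(round(length(train) / N, 3))  # fraction used for training, roughly 2/3
```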
  3. We fit a classification tree to the training data.
tree.spam=tree(type~.,spam,subset=train)

plot(tree.spam)

text(tree.spam,pretty=1)

For this training set the tree has 13 terminal nodes, as the summary shows:

summary(tree.spam)
## 
## Classification tree:
## tree(formula = type ~ ., data = spam, subset = train)
## Variables actually used in tree construction:
## [1] "charDollar"      "charExclamation" "remove"          "hp"             
## [5] "capitalLong"     "our"             "capitalAve"      "free"           
## [9] "edu"            
## Number of terminal nodes:  13 
## Residual mean deviance:  0.4983 = 1488 / 2987 
## Misclassification error rate: 0.08033 = 241 / 3000
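For reference, the deviance reported by summary() for a classification tree is, to our understanding of the tree package, summed over the terminal nodes \(m\) and classes \(k\), with \(n_{mk}\) the number of training observations of class \(k\) in node \(m\) and \(\hat p_{mk}=n_{mk}/n_m\). The residual mean deviance divides it by \(n-|T|\), where \(|T|\) is the number of terminal nodes, which reproduces the numbers above:

\[
D=-2\sum_{m=1}^{|T|}\sum_{k} n_{mk}\log\hat p_{mk},\qquad
\frac{D}{n-|T|}=\frac{1488}{3000-13}=\frac{1488}{2987}\approx 0.498,\qquad
\text{training error}=\frac{241}{3000}\approx 0.0803.
\]

The printed 0.4983 differs in the fourth decimal only because D is displayed rounded to 1488.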
  4. We predict the response for the test data.
yhat=predict(tree.spam,spam[test,],type="class")
response.test=spam$type[test]

We then make a confusion table:

misclass=table(yhat,response.test)
print(misclass)
##          response.test
## yhat      nonspam spam
##   nonspam     904   75
##   spam         72  550

The misclassification rate is given by:

1-sum(diag(misclass))/sum(misclass)
## [1] 0.09181761
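Written out with the counts from the confusion table, the misclassification rate is the off-diagonal total divided by the number of test observations:

\[
1-\frac{904+550}{904+75+72+550}=\frac{75+72}{1601}=\frac{147}{1601}\approx 0.0918.
\]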
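  5. We use cv.tree() with FUN=prune.misclass to find the optimal tree size by cross-validation (10-fold by default), with the number of misclassifications as the criterion.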
set.seed(1)

cv.spam=cv.tree(tree.spam,FUN=prune.misclass)

plot(cv.spam$size,cv.spam$dev,type="b")

According to the plot, the cross-validated number of misclassifications is lowest for trees with 7 or more terminal nodes. We choose 7, since this gives the simplest tree among these, and prune the tree accordingly.

prune.spam=prune.misclass(tree.spam,best=7)

plot(prune.spam)
text(prune.spam,pretty=1)
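The choice of 7 can also be made programmatically: take the smallest tree size whose cross-validated error equals the minimum. The helper smallest_best_size() below is hypothetical (it is not part of the tree package), and the toy vectors only illustrate its behaviour; they are not the spam results. Applied to this exercise, it would be called as smallest_best_size(cv.spam$size, cv.spam$dev).

```r
# Hypothetical helper: smallest tree size that attains the minimum CV error.
# 'size' and 'dev' correspond to the components of a cv.tree() result.
smallest_best_size <- function(size, dev) {
  min(size[dev == min(dev)])
}

# Toy values for illustration only (not the spam results)
toy_size <- c(13, 10, 7, 5, 3, 1)
toy_dev  <- c(300, 300, 300, 340, 420, 980)
smallest_best_size(toy_size, toy_dev)  # returns 7
```

With prune.misclass the CV errors are integer counts, so the exact equality test is safe; for deviance-based CV a small tolerance would be more robust.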

We predict the response for the test data:

yhat.prune=predict(prune.spam,spam[test,],type="class")

misclass.prune=table(yhat.prune,response.test)
print(misclass.prune)
##           response.test
## yhat.prune nonspam spam
##    nonspam     892   94
##    spam         84  531

The misclassification rate is

1-sum(diag(misclass.prune))/sum(misclass.prune)
## [1] 0.1111805
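  6. We now use bagging: many trees are grown on bootstrap samples of the training set and their predictions are combined by majority vote. With randomForest(), bagging corresponds to setting mtry equal to the number of predictors, here dim(spam)[2]-1 = 57, so that all predictors are considered at each split.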
library(randomForest)
bag.spam=randomForest(type~.,data=spam,subset=train,mtry=dim(spam)[2]-1,ntree=500,importance=TRUE)

We predict the response for the test data as before:

yhat.bag=predict(bag.spam,newdata=spam[test,])

misclass.bag=table(yhat.bag,response.test)
print(misclass.bag)
##          response.test
## yhat.bag  nonspam spam
##   nonspam     931   50
##   spam         45  575

The misclassification rate is

1-sum(diag(misclass.bag))/sum(misclass.bag)
## [1] 0.05933791
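The two ingredients of bagging can be illustrated in base R. This is a sketch with hypothetical inputs, not how randomForest is implemented internally: it shows the fraction of distinct training observations in one bootstrap sample of size 3000 (close to \(1-e^{-1}\approx 0.632\); the rest are "out-of-bag" for that tree), and a hypothetical majority_vote() helper applied to a toy matrix of tree predictions.

```r
set.seed(1)
n <- 3000

# One bootstrap sample: draw n indices with replacement
boot <- sample(1:n, n, replace = TRUE)
frac_unique <- length(unique(boot)) / n
print(round(frac_unique, 3))  # close to 1 - exp(-1), about 0.632

# Hypothetical helper: majority vote across trees (columns) for each observation (rows)
majority_vote <- function(pred) {
  apply(pred, 1, function(votes) names(which.max(table(votes))))
}

# Toy predictions from 3 trees for 2 observations
toy_pred <- matrix(c("spam", "spam", "nonspam",
                     "nonspam", "nonspam", "spam"),
                   nrow = 2, byrow = TRUE)
majority_vote(toy_pred)  # "spam" "nonspam"
```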
  7. We now use the random forest algorithm, where only a random subset of \(\sqrt{57}\approx 8\) of the predictors is considered at each split. This is specified by the mtry argument.
set.seed(1)

rf.spam=randomForest(type~.,data=spam,subset=train,mtry=round(sqrt(dim(spam)[2]-1)),ntree=500,importance=TRUE)
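The mtry values used in this solution can be checked directly. For bagging, mtry equals the number of predictors, \(p=57\); the random forest above uses round(sqrt(57)) = 8. Note that randomForest's own default for classification is floor(sqrt(p)), which would give 7 here.

```r
p <- 57  # number of predictors in spam (58 columns minus the response 'type')

mtry_values <- c(
  bagging    = p,               # all predictors are split candidates
  rf_used    = round(sqrt(p)),  # value passed to randomForest() above
  rf_default = floor(sqrt(p))   # randomForest's default for classification
)
print(mtry_values)  # bagging 57, rf_used 8, rf_default 7
```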

We study the importance of each variable:

importance(rf.spam)
##                      nonspam       spam MeanDecreaseAccuracy
## make               3.3286508  7.0789361             7.881511
## address            6.9456739  6.6017228             9.230802
## all                3.7838031 11.6089412            11.045453
## num3d              5.6317271 -0.4005859             4.816208
## our               20.3481526 18.2994589            25.161252
## over              11.9763732  7.8880771            13.449061
## remove            37.6992051 29.1526914            40.821426
## internet          17.4841880 11.1798743            18.176840
## order              9.5504663  6.8347954            11.096554
## mail               7.0328910  8.6655513            10.720964
## receive           12.4903939  6.1395860            13.454074
## will               5.6038105 17.0474590            17.999635
## people             4.1327512  5.4704301             6.483632
## report             4.0542313  5.1114214             6.380721
## addresses          7.1395641  3.8447086             7.990526
## free              32.3624493 26.8187760            35.900421
## business          19.3066415 11.4915974            19.938649
## email             12.6029393  7.8927205            13.299838
## you               13.8265495 17.8253373            20.725064
## credit            13.5546885  3.2207635            13.689773
## your              19.9696217 24.7322894            29.052763
## font              10.7897727  3.8486375            11.158576
## num000            17.0004151 10.6880946            18.253620
## money             16.1264992 14.2336887            18.288839
## hp                25.0943963 33.7878174            36.456178
## hpl               11.7578982 21.6303665            23.082185
## george            13.6058637 23.9623175            25.709246
## num650            10.3679627 13.7472018            16.323370
## lab                0.3133154  9.2430027             9.434675
## labs               5.4408129 10.7223801            11.860677
## telnet             4.5877395  8.3926111             8.772876
## num857             1.9379366  6.1490354             6.407622
## data               4.2606581  6.8010137             7.770119
## num415             2.2429043  5.6267855             6.076930
## num85              5.1583360 11.1537680            11.604092
## technology         8.8812437  8.2667293            11.891463
## num1999           14.4469226 23.2347517            25.032492
## parts              0.1094174  5.1704751             3.977237
## pm                 5.5609781 12.0312351            12.058829
## direct             5.7137641  1.3441715             5.968042
## cs                 2.4533282  5.2032679             5.511520
## meeting            7.6123864 16.3370269            17.346401
## original           3.0671712 11.7132619            12.185920
## project            3.4386891  9.7637893            10.076058
## re                12.0271312 18.0264226            20.720029
## edu               17.9664055 28.5134606            29.953656
## table              0.4942779  3.2756976             3.040355
## conference         4.4105194  6.2259611             7.127297
## charSemicolon      9.4429702  8.5365298            13.201454
## charRoundbracket   6.2702812 17.5509548            16.203931
## charSquarebracket  7.2737650  7.8437246             9.511108
## charExclamation   36.7847670 39.1124985            48.165503
## charDollar        33.2478237 27.0283095            38.118219
## charHash           6.9571543  6.5520469             9.864698
## capitalAve        28.2298909 27.0213048            37.143622
## capitalLong       26.6282237 23.6391011            34.570490
## capitalTotal      19.4301840 19.3615757            26.484016
##                   MeanDecreaseGini
## make                     5.6981330
## address                  6.4923495
## all                     14.9033216
## num3d                    1.3940859
## our                     36.8185976
## over                     9.0506097
## remove                 106.2083191
## internet                20.5055623
## order                    6.0201409
## mail                    10.4454996
## receive                 12.3149092
## will                    15.9134800
## people                   4.1708053
## report                   2.4653695
## addresses                1.9129947
## free                   103.6290501
## business                19.8860882
## email                   12.6402552
## you                     41.2049962
## credit                   7.8855323
## your                   100.8884216
## font                     2.8163249
## num000                  31.3586097
## money                   45.5466170
## hp                      60.6352550
## hpl                     25.4301457
## george                  21.4660651
## num650                   8.8804572
## lab                      3.0106289
## labs                     5.8715549
## telnet                   2.5021686
## num857                   0.9480281
## data                     2.7804938
## num415                   0.9364028
## num85                    4.3617622
## technology               4.4563586
## num1999                 20.8874537
## parts                    0.8235699
## pm                       4.1717657
## direct                   1.6672541
## cs                       1.1962003
## meeting                  5.6242406
## original                 2.6279996
## project                  2.4955861
## re                      11.3850449
## edu                     23.3806509
## table                    0.3051245
## conference               1.6687055
## charSemicolon            7.0022222
## charRoundbracket        15.5874489
## charSquarebracket        3.8276785
## charExclamation        176.9073084
## charDollar             142.6341536
## charHash                 4.4513333
## capitalAve              88.1114809
## capitalLong             79.5597096
## capitalTotal            59.7588218

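Large values of MeanDecreaseAccuracy (the decrease in out-of-bag accuracy when the values of the covariate are randomly permuted) and MeanDecreaseGini (the total decrease in Gini node impurity from splits on the covariate, averaged over all trees) indicate that the corresponding covariate is important.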
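MeanDecreaseAccuracy is a permutation importance. The base-R sketch below illustrates the idea on simulated data, with hypothetical covariates x1 and x2 and a logistic regression standing in for the forest: shuffle one covariate in held-out data and measure how much the accuracy drops. randomForest instead does this per tree on the out-of-bag observations, and importance() by default scales the drop by its standard error, so the numbers are not directly comparable.

```r
set.seed(1)
n <- 2000
# Simulated data: x1 drives the class, x2 is pure noise
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- factor(ifelse(x1 + rnorm(n, sd = 0.5) > 0, "spam", "nonspam"))
dat <- data.frame(y, x1, x2)

train_idx <- sample(1:n, n / 2)
fit <- glm(y ~ x1 + x2, data = dat[train_idx, ], family = binomial)
holdout <- dat[-train_idx, ]

# Accuracy of the fitted model on a data set d
accuracy <- function(d) {
  p <- predict(fit, newdata = d, type = "response")  # P(y = "spam")
  pred <- ifelse(p > 0.5, "spam", "nonspam")
  mean(pred == d$y)
}

base_acc <- accuracy(holdout)
perm_importance <- sapply(c("x1", "x2"), function(v) {
  shuffled <- holdout
  shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and y
  base_acc - accuracy(shuffled)
})
print(round(perm_importance, 3))  # x1 shows a large drop, x2 roughly zero
```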

varImpPlot(rf.spam)

In this plot we see that charExclamation is the most important covariate, followed by remove and charDollar (for MeanDecreaseGini, charDollar comes before remove). This is as expected, since these variables are used in the top splits of the classification trees we have seen so far.

We now predict the response for the test data.

yhat.rf=predict(rf.spam,newdata=spam[test,])

misclass.rf=table(yhat.rf,response.test)
print(misclass.rf)
##          response.test
## yhat.rf   nonspam spam
##   nonspam     943   49
##   spam         33  576

The misclassification rate is

1-sum(diag(misclass.rf))/sum(misclass.rf)
## [1] 0.05121799
  8. Finally, we fit a boosted ensemble of classification trees with gbm(). With distribution="bernoulli", gbm() does not accept a factor response, so we recode the response as 1 ("spam") and 0 ("nonspam"):
library(gbm)
set.seed(1)

spamboost=spam
spamboost$type=as.numeric(spam$type=="spam") # 1 = spam, 0 = nonspam

boost.spam=gbm(type~.,data=spamboost[train,],distribution="bernoulli",n.trees=5000,interaction.depth=4,shrinkage=0.001)
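To see what gbm() does with distribution="bernoulli", n.trees and shrinkage, here is a deliberately small base-R sketch of gradient boosting on simulated data. Assumptions: one hypothetical predictor x, depth-one trees (stumps, as with interaction.depth=1) whose leaf values are means of the residuals, and no subsampling. gbm itself uses deeper trees, a Newton-type update for the leaf values and, by default, subsampling of the training data, so this only illustrates the principle: start from the log-odds of the base rate and repeatedly add a shrunken tree fitted to the residuals \(y-\hat p\).

```r
set.seed(1)
n <- 500
x <- runif(n)                                  # hypothetical single predictor
y <- as.numeric(x + rnorm(n, sd = 0.2) > 0.5)  # 0/1 response, as gbm requires

# Fit a regression stump to r: choose the split on x that minimises squared error
fit_stump <- function(x, r) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  best <- NULL
  best_sse <- Inf
  for (s in cuts) {
    left <- x <= s
    m_left <- mean(r[left])
    m_right <- mean(r[!left])
    sse <- sum((r[left] - m_left)^2) + sum((r[!left] - m_right)^2)
    if (sse < best_sse) {
      best_sse <- sse
      best <- list(cut = s, left = m_left, right = m_right)
    }
  }
  best
}

predict_stump <- function(stump, x) ifelse(x <= stump$cut, stump$left, stump$right)

n_trees <- 100
shrinkage <- 0.1
f <- rep(qlogis(mean(y)), n)   # initial fit: log-odds of the base rate
for (m in seq_len(n_trees)) {
  r <- y - plogis(f)           # proportional to the negative gradient of the Bernoulli deviance
  stump <- fit_stump(x, r)
  f <- f + shrinkage * predict_stump(stump, x)
}

train_acc <- mean(as.numeric(plogis(f) > 0.5) == y)
print(round(train_acc, 3))
```

The small shrinkage is why gbm needs many trees (here n.trees=5000 with shrinkage=0.001): each tree moves the fit only a little.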

We predict the response for the test data:

yhat.boost=predict(boost.spam,newdata=spamboost[test,],n.trees=5000,type="response")

yhat.boost=ifelse(yhat.boost>0.5,1,0) # Classify as spam (1) if the predicted probability exceeds 0.5, else nonspam (0)
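With distribution="bernoulli", predict() with type="response" returns probabilities \(p=1/(1+e^{-f})\), where \(f\) is the fitted log-odds (what type="link" returns). Thresholding \(p\) at 0.5 is therefore the same as classifying by the sign of \(f\), as this base-R check with hypothetical log-odds values shows:

```r
f <- c(-2.3, -0.1, 0, 0.4, 3.1)   # hypothetical log-odds values
p <- plogis(f)                    # logistic transform, as type = "response" applies
by_prob <- ifelse(p > 0.5, 1, 0)  # rule used above
by_sign <- ifelse(f > 0, 1, 0)    # equivalent rule on the link scale
print(rbind(p = round(p, 3), by_prob, by_sign))
```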

misclass.boost=table(yhat.boost,spamboost$type[test])

print(misclass.boost)
##           
## yhat.boost   0   1
##          0 941  58
##          1  35 567

The misclassification rate is

1-sum(diag(misclass.boost))/sum(misclass.boost)
## [1] 0.05808869
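  9. As expected, bagging, random forest and boosting all give lower test misclassification rates (0.059, 0.051 and 0.058) than the single classification tree (0.092) and the pruned tree (0.111). The random forest performs best on this split.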
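For an at-a-glance comparison, the test misclassification rates can be recomputed from the off-diagonal counts of the confusion tables printed above (1601 test observations in each case):

```r
n_test <- 1601
# Off-diagonal counts (misclassified test emails) from the confusion tables above
errors <- c(tree = 75 + 72, pruned_tree = 94 + 84, bagging = 50 + 45,
            random_forest = 49 + 33, boosting = 58 + 35)
rates <- sort(errors / n_test)
print(round(rates, 4))
# random_forest 0.0512, boosting 0.0581, bagging 0.0593, tree 0.0918, pruned_tree 0.1112
```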