See the Module 5 recommended exercise solutions for the bootstrap result.
See the solutions to exercises Q1 and Q4 here: (https://rstudio-pubs-static.s3.amazonaws.com/65564_925dfde884e14ef9b5735eddd16c263e.html)
# kernlab supplies the spam data set; tree fits classification trees.
library(kernlab)
help("spam")
library(tree)

# Seed the RNG so the train/test split below is reproducible.
set.seed(1)
data(spam)

# Draw 3000 training rows; every remaining row (in increasing order) is test.
N <- nrow(spam)
train <- sample(seq_len(N), 3000)
test <- setdiff(seq_len(N), train)

# Grow a classification tree for `type` on the training rows only, then draw it.
tree.spam <- tree(type ~ ., data = spam, subset = train)
plot(tree.spam)
text(tree.spam, pretty = 1)
For this training set the tree has 13 terminal nodes (see the summary below).
# Variables used, number of leaves, deviance and training error of the full tree.
summary(object = tree.spam)
##
## Classification tree:
## tree(formula = type ~ ., data = spam, subset = train)
## Variables actually used in tree construction:
## [1] "charDollar" "charExclamation" "remove" "hp"
## [5] "capitalLong" "our" "capitalAve" "free"
## [9] "edu"
## Number of terminal nodes: 13
## Residual mean deviance: 0.4983 = 1488 / 2987
## Misclassification error rate: 0.08033 = 241 / 3000
# Predicted classes for the held-out rows, plus the true labels to compare with.
yhat <- predict(tree.spam, newdata = spam[test, ], type = "class")
response.test <- spam[["type"]][test]
We then make a confusion table:
# Confusion table: rows are predicted classes, columns are the true classes.
misclass <- table(yhat = yhat, response.test = response.test)
print(misclass)
## response.test
## yhat nonspam spam
## nonspam 904 75
## spam 72 550
The misclassification rate is given by:
# Test error: one minus the share of cases on the diagonal (correctly classified).
1 - sum(diag(misclass)) / sum(misclass)
## [1] 0.09181761
# Cross-validate over subtree sizes, scoring each size by misclassifications;
# seeded so the fold assignment is reproducible.
set.seed(1)
cv.spam <- cv.tree(object = tree.spam, FUN = prune.misclass)
plot(x = cv.spam$size, y = cv.spam$dev, type = "b")
According to the plot, the optimal number of terminal nodes is 7 (or larger). We choose 7, as this gives the simplest tree, and prune the tree to this size.
# Prune back to the 7-leaf subtree selected from the CV plot and draw it.
prune.spam <- prune.misclass(tree.spam, best = 7)
plot(prune.spam)
text(prune.spam, pretty = 1)
We predict the response for the test data:
# Test-set predictions of the pruned tree and their confusion table.
yhat.prune <- predict(prune.spam, newdata = spam[test, ], type = "class")
misclass.prune <- table(yhat.prune = yhat.prune, response.test = response.test)
print(misclass.prune)
## response.test
## yhat.prune nonspam spam
## nonspam 892 94
## spam 84 531
The misclassification rate is
# Test error of the pruned tree: one minus the diagonal share of the table.
1 - sum(diag(misclass.prune)) / sum(misclass.prune)
## [1] 0.1111805
library(randomForest)

# Bagging: a random forest that tries every predictor at each split
# (mtry = number of columns minus the response column).
# Seed first, as is done before the other randomized fits in this analysis,
# so the bagged model no longer depends on leftover RNG state from cv.tree.
set.seed(1)
bag.spam <- randomForest(
  type ~ .,
  data = spam,
  subset = train,
  mtry = ncol(spam) - 1,
  ntree = 500,
  importance = TRUE
)
We predict the response for the test data as before:
# Bagged-forest predictions for the test rows and their confusion table.
yhat.bag <- predict(bag.spam, newdata = spam[test, ])
misclass.bag <- table(yhat.bag = yhat.bag, response.test = response.test)
print(misclass.bag)
## response.test
## yhat.bag nonspam spam
## nonspam 931 50
## spam 45 575
The misclassification rate is
# Test error of the bagged forest: one minus the diagonal share of the table.
1 - sum(diag(misclass.bag)) / sum(misclass.bag)
## [1] 0.05933791
# Random forest: at each split, sample round(sqrt(p)) of the p predictors.
set.seed(1)
rf.spam <- randomForest(
  type ~ .,
  data = spam,
  subset = train,
  mtry = round(sqrt(ncol(spam) - 1)),
  ntree = 500,
  importance = TRUE
)
We study the importance of each variable:
# Per-class and overall importance measures for every predictor.
importance(x = rf.spam)
## nonspam spam MeanDecreaseAccuracy
## make 3.3286508 7.0789361 7.881511
## address 6.9456739 6.6017228 9.230802
## all 3.7838031 11.6089412 11.045453
## num3d 5.6317271 -0.4005859 4.816208
## our 20.3481526 18.2994589 25.161252
## over 11.9763732 7.8880771 13.449061
## remove 37.6992051 29.1526914 40.821426
## internet 17.4841880 11.1798743 18.176840
## order 9.5504663 6.8347954 11.096554
## mail 7.0328910 8.6655513 10.720964
## receive 12.4903939 6.1395860 13.454074
## will 5.6038105 17.0474590 17.999635
## people 4.1327512 5.4704301 6.483632
## report 4.0542313 5.1114214 6.380721
## addresses 7.1395641 3.8447086 7.990526
## free 32.3624493 26.8187760 35.900421
## business 19.3066415 11.4915974 19.938649
## email 12.6029393 7.8927205 13.299838
## you 13.8265495 17.8253373 20.725064
## credit 13.5546885 3.2207635 13.689773
## your 19.9696217 24.7322894 29.052763
## font 10.7897727 3.8486375 11.158576
## num000 17.0004151 10.6880946 18.253620
## money 16.1264992 14.2336887 18.288839
## hp 25.0943963 33.7878174 36.456178
## hpl 11.7578982 21.6303665 23.082185
## george 13.6058637 23.9623175 25.709246
## num650 10.3679627 13.7472018 16.323370
## lab 0.3133154 9.2430027 9.434675
## labs 5.4408129 10.7223801 11.860677
## telnet 4.5877395 8.3926111 8.772876
## num857 1.9379366 6.1490354 6.407622
## data 4.2606581 6.8010137 7.770119
## num415 2.2429043 5.6267855 6.076930
## num85 5.1583360 11.1537680 11.604092
## technology 8.8812437 8.2667293 11.891463
## num1999 14.4469226 23.2347517 25.032492
## parts 0.1094174 5.1704751 3.977237
## pm 5.5609781 12.0312351 12.058829
## direct 5.7137641 1.3441715 5.968042
## cs 2.4533282 5.2032679 5.511520
## meeting 7.6123864 16.3370269 17.346401
## original 3.0671712 11.7132619 12.185920
## project 3.4386891 9.7637893 10.076058
## re 12.0271312 18.0264226 20.720029
## edu 17.9664055 28.5134606 29.953656
## table 0.4942779 3.2756976 3.040355
## conference 4.4105194 6.2259611 7.127297
## charSemicolon 9.4429702 8.5365298 13.201454
## charRoundbracket 6.2702812 17.5509548 16.203931
## charSquarebracket 7.2737650 7.8437246 9.511108
## charExclamation 36.7847670 39.1124985 48.165503
## charDollar 33.2478237 27.0283095 38.118219
## charHash 6.9571543 6.5520469 9.864698
## capitalAve 28.2298909 27.0213048 37.143622
## capitalLong 26.6282237 23.6391011 34.570490
## capitalTotal 19.4301840 19.3615757 26.484016
## MeanDecreaseGini
## make 5.6981330
## address 6.4923495
## all 14.9033216
## num3d 1.3940859
## our 36.8185976
## over 9.0506097
## remove 106.2083191
## internet 20.5055623
## order 6.0201409
## mail 10.4454996
## receive 12.3149092
## will 15.9134800
## people 4.1708053
## report 2.4653695
## addresses 1.9129947
## free 103.6290501
## business 19.8860882
## email 12.6402552
## you 41.2049962
## credit 7.8855323
## your 100.8884216
## font 2.8163249
## num000 31.3586097
## money 45.5466170
## hp 60.6352550
## hpl 25.4301457
## george 21.4660651
## num650 8.8804572
## lab 3.0106289
## labs 5.8715549
## telnet 2.5021686
## num857 0.9480281
## data 2.7804938
## num415 0.9364028
## num85 4.3617622
## technology 4.4563586
## num1999 20.8874537
## parts 0.8235699
## pm 4.1717657
## direct 1.6672541
## cs 1.1962003
## meeting 5.6242406
## original 2.6279996
## project 2.4955861
## re 11.3850449
## edu 23.3806509
## table 0.3051245
## conference 1.6687055
## charSemicolon 7.0022222
## charRoundbracket 15.5874489
## charSquarebracket 3.8276785
## charExclamation 176.9073084
## charDollar 142.6341536
## charHash 4.4513333
## capitalAve 88.1114809
## capitalLong 79.5597096
## capitalTotal 59.7588218
Large values of MeanDecreaseAccuracy and MeanDecreaseGini indicate that the corresponding covariate is important.
# Dot charts of the two importance measures, sorted by importance.
varImpPlot(x = rf.spam)
In this plot we see that charExclamation is the most important covariate, followed by remove and charDollar. This is as expected, since these variables are used in the top splits of the classification trees we have seen so far.
We now predict the response for the test data.
# Random-forest predictions for the test rows, confusion table and test error.
yhat.rf <- predict(rf.spam, newdata = spam[test, ])
misclass.rf <- table(yhat.rf = yhat.rf, response.test = response.test)
1 - sum(diag(misclass.rf)) / sum(misclass.rf)
## [1] 0.05121799
The corresponding confusion table is:
# Show the random-forest confusion table.
print(x = misclass.rf)
## response.test
## yhat.rf nonspam spam
## nonspam 943 49
## spam 33 576
library(gbm)
set.seed(1)

# gbm's "bernoulli" loss expects a numeric 0/1 response, so recode the factor:
# 1 = spam, 0 = nonspam. Overwriting the column directly replaces the
# original approach of deleting it (type = c()) and rebuilding it by indexed
# assignment into NULL.
spamboost <- spam
spamboost$type <- as.numeric(spam$type == "spam")

# Boosted trees: 5000 depth-4 trees with a small learning rate.
boost.spam <- gbm(
  type ~ .,
  data = spamboost[train, ],
  distribution = "bernoulli",
  n.trees = 5000,
  interaction.depth = 4,
  shrinkage = 0.001
)
We predict the response for the test data:
# Predicted spam probabilities for the test rows, using every fitted tree.
# The test rows are selected with `test`, as in the other models.
# predict.gbm takes no `distribution` argument: the original passed one, and
# it was silently absorbed by `...`, so it is dropped here.
yhat.boost <- predict(
  boost.spam,
  newdata = spamboost[test, ],
  n.trees = boost.spam$n.trees,
  type = "response"
)

# Threshold at 0.5 to get class labels: 1 = spam, 0 = nonspam.
yhat.boost <- as.numeric(yhat.boost > 0.5)
misclass.boost <- table(yhat.boost, spamboost$type[test])
print(misclass.boost)
##
## yhat.boost 0 1
## 0 941 58
## 1 35 567
The misclassification rate is
# Test error of the boosted model: one minus the diagonal share of the table.
1 - sum(diag(misclass.boost)) / sum(misclass.boost)
## [1] 0.05808869