aboutsummaryrefslogtreecommitdiff
path: root/R_Tree/main.r
blob: ddac2fba2a8d7ed15ae8820b92d4b3f9606967c2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env Rscript

# Fix the RNG state: rpart's cross-validation (xval) assigns rows to folds
# at random, so without a seed the cross-validated error (xerror) and the CP
# value chosen for pruning further down change from run to run.
set.seed(42)

data <- read.csv("survey.csv")

# First 600 rows are the training set, last 150 rows the test set. The two
# slices are only disjoint when the file has at least 600 + 150 rows.
n_train <- 600L
n_test <- 150L
if (nrow(data) < n_train + n_test) {
  warning(
    "survey.csv has only ", nrow(data), " rows; train and test sets overlap",
    call. = FALSE
  )
}
train <- head(data, n_train)
test <- tail(data, n_test)

# Install the modelling packages on first use, then attach them.
# requireNamespace() only checks availability; library() errors loudly if
# the package is still missing after the install attempt.
for (pkg in c("rpart", "rpart.plot")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg, repos = "https://cran.r-project.org/", Ncpus = 16)
  }
}
library(rpart)
library(rpart.plot)


# Classification tree for MYDEPV from Price, Income and Age, choosing splits
# by information gain and estimating error with three-fold cross-validation.
dt <- rpart(
  formula = MYDEPV ~ Price + Income + Age,
  data = train,
  method = "class",
  parms = list(split = "information"),
  control = rpart.control(xval = 3)
)

# Complexity-parameter table, then a drawing of the fitted tree
# (extra = 106 selects the node labels; see rpart.plot's `extra` argument).
printcp(x = dt)
rpart.plot(x = dt, extra = 106)
# summary(dt)


# Summarise the size of a fitted rpart tree.
#
# tree: an rpart fit; only tree$frame (one row per node, row names are the
#   node numbers, `var` is "<leaf>" for terminal nodes) is read.
# Returns a list of two integers:
#   internal_nodes - number of non-leaf (split) nodes;
#   height         - number of levels, counting the root as level 1.
# rpart numbers nodes heap-style (children of node n are 2n and 2n + 1), so
# node n sits on level floor(log2(n)) + 1.
tree_stats <- function(tree) {
  node_frame <- tree$frame
  node_ids <- as.integer(rownames(node_frame))
  is_split <- as.character(node_frame$var) != "<leaf>"
  node_level <- floor(log2(node_ids)) + 1L

  list(
    internal_nodes = as.integer(sum(is_split)),
    height = as.integer(max(node_level))
  )
}

# Size of the information-gain tree: number of splits and number of levels.
stats <- tree_stats(dt)
print(stats$internal_nodes)
print(stats$height)

# Confusion matrix of the tree's class predictions on the training rows
# (rows = predicted label, columns = actual MYDEPV).
pred <- predict(dt, newdata = train, type = "class")
conf_matrix <- table(Predicted = pred, Actual = train$MYDEPV)
print(conf_matrix)

# Print and return misclassification rates for a confusion matrix.
#
# tt: a two-way table with predicted labels on the rows and actual labels on
#   the columns, e.g. table(Predicted = pred, Actual = y).
# Prints the overall rate (wrong predictions / all records) and, for each
# actual class, the share of its records that were predicted as some other
# class. Returns those per-class rates (rounded to 4 digits) invisibly; a
# class with no actual records gets NA instead of NaN.
misclass <- function(tt) {
  stopifnot(length(dim(tt)) == 2L)

  predicted_classes <- rownames(tt)
  actual_classes <- colnames(tt)

  # Correct predictions are the cells whose predicted label equals the actual
  # label. Match the cells by name rather than with diag(): table() keeps all
  # levels of a factor, so rows and columns can differ (e.g. a class missing
  # from the actual labels), and positional pairing would then read the wrong
  # cells or index a column that does not exist.
  correct <- vapply(actual_classes, function(cls) {
    if (cls %in% predicted_classes) as.numeric(tt[cls, cls]) else 0
  }, numeric(1))
  total_in_class <- colSums(tt)
  total <- sum(tt)

  overall_misclass <- (total - sum(correct)) / total
  cat("Overall misclassification rate:", round(overall_misclass, 4), "\n")

  misclass_per_class <- (total_in_class - correct) / total_in_class
  # An empty actual class has no defined error rate.
  misclass_per_class[total_in_class == 0] <- NA_real_

  cat("Misclassification rate per class:\n")
  rounded <- round(misclass_per_class, 4)
  print(rounded)
  invisible(rounded)
}

misclass(conf_matrix)

# ROCR is needed only for the ROC/AUC section; install it on first use.
# requireNamespace() checks availability without attaching, and library()
# then fails loudly if the package is still unavailable.
if (!requireNamespace("ROCR", quietly = TRUE)) {
  install.packages("ROCR", repos = "https://cran.r-project.org/", Ncpus = 16)
}
library(ROCR)

# ROC curve and AUC on the training rows. newdata is passed explicitly, as in
# the class predictions above, so the scores line up with the train$MYDEPV
# labels handed to prediction(). Column 2 holds the predicted probability of
# the second response level; ROCR by default treats the larger label as the
# positive class, so for 0/1 labels these agree (confirm for other labels).
train_prob <- predict(dt, newdata = train, type = "prob")
rocr_pred <- prediction(train_prob[, 2], train$MYDEPV)
roc <- performance(rocr_pred, "tpr", "fpr")
auc <- performance(rocr_pred, "auc")

plot(roc, col = "blue", main = "ROC")
auc@y.values


# Hold-out evaluation: confusion matrix and error rates of the
# information-gain tree on the test rows.
print("score with test data")
pred <- predict(dt, newdata = test, type = "class")
conf_matrix <- table(Predicted = pred, Actual = test$MYDEPV)
print(conf_matrix)
misclass(conf_matrix)



# Same model as dt but split on the Gini index instead of information gain;
# error is again estimated with three-fold cross-validation.
dt_gini <- rpart(MYDEPV ~ Price + Income + Age,
  data = train,
  method = "class",
  parms = list(split = "gini"), # Gini impurity splitting index
  control = rpart.control(xval = 3) # three-fold cross-validation
)

printcp(dt_gini)
rpart.plot(dt_gini, extra = 106)
# summary(dt_gini)

# Prune the Gini tree back to the subtree whose row in the CP table has the
# lowest cross-validated error (xerror).
cp_table <- dt_gini$cptable
best_row <- which.min(cp_table[, "xerror"])
optimal_cp <- cp_table[best_row, "CP"]
cat("Optimal CP value for pruning:", optimal_cp, "\n")
dt_gini_pruned <- prune(dt_gini, cp = optimal_cp)

printcp(x = dt_gini_pruned)
rpart.plot(x = dt_gini_pruned, extra = 106)

# Training-set confusion matrix and error rates of the pruned Gini tree.
pred_gini <- predict(dt_gini_pruned, newdata = train, type = "class")
conf_matrix_gini <- table(Predicted = pred_gini, Actual = train$MYDEPV)
print(conf_matrix_gini)

misclass(conf_matrix_gini)

# Size of the pruned tree: number of splits and number of levels.
stats <- tree_stats(dt_gini_pruned)
print(stats$internal_nodes)
print(stats$height)