diff options
Diffstat (limited to 'R_Tree/main.r')
| -rwxr-xr-x | R_Tree/main.r | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/R_Tree/main.r b/R_Tree/main.r new file mode 100755 index 0000000..ddac2fb --- /dev/null +++ b/R_Tree/main.r @@ -0,0 +1,110 @@ +#!/usr/bin/env Rscript + +data <- read.csv("survey.csv") +train <- head(data, 600) +test <- tail(data, 150) + +if (!require(rpart)) install.packages("rpart", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint +library(rpart) +library(rpart.plot) + + +dt <- rpart(MYDEPV ~ Price + Income + Age, + data = train, + method = "class", + parms = list(split = "information"), # information gain splitting index + control = rpart.control(xval = 3) # three-fold cross-validation +) + +printcp(dt) +rpart.plot(dt, extra = 106) +# summary(dt) + + +tree_stats <- function(tree) { + frm <- tree$frame + internal_nodes <- sum(as.character(frm$var) != "<leaf>") + + node_indexes <- as.integer(row.names(frm)) + depth <- floor(log2(node_indexes)) + 1L + + list(internal_nodes = as.integer(internal_nodes), height = as.integer(max(depth))) +} + +stats <- tree_stats(dt) +stats$internal_nodes +stats$height + + +pred <- predict(dt, train, type = "class") +conf_matrix <- table(Predicted = pred, Actual = train$MYDEPV) +conf_matrix + +misclass <- function(tt) { + # total_wrong / total_records + overall_misclass <- (sum(tt) - sum(diag(tt))) / sum(tt) + cat("Overall misclassification rate:", round(overall_misclass, 4), "\n") + + classes <- rownames(tt) + misclass_per_class <- numeric(length(classes)) + names(misclass_per_class) <- classes + + for (cls in classes) { + correct <- tt[cls, cls] + total_in_class <- sum(tt[, cls]) + misclass_per_class[cls] <- (total_in_class - correct) / total_in_class + } + + cat("Misclassification rate per income class:\n") + print(round(misclass_per_class, 4)) +} + +misclass(conf_matrix) + +if (!require(ROCR)) install.packages("ROCR", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint +library(ROCR) + +rocr_pred <- prediction(predict(dt, type = "prob")[, 2], train$MYDEPV) +roc <- performance(rocr_pred, "tpr", "fpr") +auc <- performance(rocr_pred, "auc") + +plot(roc, col = "blue", main = "ROC") +auc@y.values + + +print("score with test data") +pred <- predict(dt, test, type = "class") +conf_matrix <- table(Predicted = pred, Actual = test$MYDEPV) +conf_matrix +misclass(conf_matrix) + + + +dt_gini <- rpart(MYDEPV ~ Price + Income + Age, + data = train, + method = "class", + parms = list(split = "gini"), # information gain splitting index + control = rpart.control(xval = 3) # three-fold cross-validation +) + +printcp(dt_gini) +rpart.plot(dt_gini, extra = 106) +# summary(dt_gini) + +cp_table <- dt_gini$cptable +optimal_cp <- cp_table[which.min(cp_table[, "xerror"]), "CP"] +cat("Optimal CP value for pruning:", optimal_cp, "\n") +dt_gini_pruned <- prune(dt_gini, cp = optimal_cp) + +printcp(dt_gini_pruned) +rpart.plot(dt_gini_pruned, extra = 106) + +pred_gini <- predict(dt_gini_pruned, train, type = "class") +conf_matrix_gini <- table(Predicted = pred_gini, Actual = train$MYDEPV) +conf_matrix_gini + +misclass(conf_matrix_gini) + +stats <- tree_stats(dt_gini_pruned) +stats$internal_nodes +stats$height |
