aboutsummaryrefslogtreecommitdiff
path: root/R_Tree/main.r
diff options
context:
space:
mode:
authorleshe4ka46 <alex9102naid1@ya.ru>2025-12-09 11:42:25 +0300
committerleshe4ka46 <alex9102naid1@ya.ru>2025-12-09 11:42:25 +0300
commit72b4edeadeafc9c54b3db9b0961a45da3d07b77c (patch)
treefd472d1be885fc9856a3426e8aa794d9aee968c5 /R_Tree/main.r
parentf60863aecfdfb2a7a35c9d2d4233142ca17c9152 (diff)
ar, tree
Diffstat (limited to 'R_Tree/main.r')
-rwxr-xr-xR_Tree/main.r110
1 files changed, 110 insertions, 0 deletions
diff --git a/R_Tree/main.r b/R_Tree/main.r
new file mode 100755
index 0000000..ddac2fb
--- /dev/null
+++ b/R_Tree/main.r
@@ -0,0 +1,110 @@
+#!/usr/bin/env Rscript
+
+data <- read.csv("survey.csv")
+train <- head(data, 600)
+test <- tail(data, 150)
+
+if (!require(rpart)) install.packages("rpart", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
+library(rpart)
+library(rpart.plot)
+
+
+dt <- rpart(MYDEPV ~ Price + Income + Age,
+ data = train,
+ method = "class",
+ parms = list(split = "information"), # information gain splitting index
+ control = rpart.control(xval = 3) # three-fold cross-validation
+)
+
+printcp(dt)
+rpart.plot(dt, extra = 106)
+# summary(dt)
+
+
+tree_stats <- function(tree) {
+ frm <- tree$frame
+ internal_nodes <- sum(as.character(frm$var) != "<leaf>")
+
+ node_indexes <- as.integer(row.names(frm))
+ depth <- floor(log2(node_indexes)) + 1L
+
+ list(internal_nodes = as.integer(internal_nodes), height = as.integer(max(depth)))
+}
+
+stats <- tree_stats(dt)
+stats$internal_nodes
+stats$height
+
+
+pred <- predict(dt, train, type = "class")
+conf_matrix <- table(Predicted = pred, Actual = train$MYDEPV)
+conf_matrix
+
+misclass <- function(tt) {
+ # total_wrong / total_records
+ overall_misclass <- (sum(tt) - sum(diag(tt))) / sum(tt)
+ cat("Overall misclassification rate:", round(overall_misclass, 4), "\n")
+
+ classes <- rownames(tt)
+ misclass_per_class <- numeric(length(classes))
+ names(misclass_per_class) <- classes
+
+ for (cls in classes) {
+ correct <- tt[cls, cls]
+ total_in_class <- sum(tt[, cls])
+ misclass_per_class[cls] <- (total_in_class - correct) / total_in_class
+ }
+
+ cat("Misclassification rate per income class:\n")
+ print(round(misclass_per_class, 4))
+}
+
+misclass(conf_matrix)
+
+if (!require(ROCR)) install.packages("ROCR", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
+library(ROCR)
+
+rocr_pred <- prediction(predict(dt, type = "prob")[, 2], train$MYDEPV)
+roc <- performance(rocr_pred, "tpr", "fpr")
+auc <- performance(rocr_pred, "auc")
+
+plot(roc, col = "blue", main = "ROC")
+auc@y.values
+
+
+print("score with test data")
+pred <- predict(dt, test, type = "class")
+conf_matrix <- table(Predicted = pred, Actual = test$MYDEPV)
+conf_matrix
+misclass(conf_matrix)
+
+
+
+dt_gini <- rpart(MYDEPV ~ Price + Income + Age,
+ data = train,
+ method = "class",
+ parms = list(split = "gini"), # information gain splitting index
+ control = rpart.control(xval = 3) # three-fold cross-validation
+)
+
+printcp(dt_gini)
+rpart.plot(dt_gini, extra = 106)
+# summary(dt_gini)
+
+cp_table <- dt_gini$cptable
+optimal_cp <- cp_table[which.min(cp_table[, "xerror"]), "CP"]
+cat("Optimal CP value for pruning:", optimal_cp, "\n")
+dt_gini_pruned <- prune(dt_gini, cp = optimal_cp)
+
+printcp(dt_gini_pruned)
+rpart.plot(dt_gini_pruned, extra = 106)
+
+pred_gini <- predict(dt_gini_pruned, train, type = "class")
+conf_matrix_gini <- table(Predicted = pred_gini, Actual = train$MYDEPV)
+conf_matrix_gini
+
+misclass(conf_matrix_gini)
+
+stats <- tree_stats(dt_gini_pruned)
+stats$internal_nodes
+stats$height