aboutsummaryrefslogtreecommitdiff
path: root/R_NB/main.r
diff options
context:
space:
mode:
Diffstat (limited to 'R_NB/main.r')
-rwxr-xr-xR_NB/main.r77
1 files changed, 77 insertions, 0 deletions
diff --git a/R_NB/main.r b/R_NB/main.r
new file mode 100755
index 0000000..75979ad
--- /dev/null
+++ b/R_NB/main.r
@@ -0,0 +1,77 @@
+#!/usr/bin/env Rscript
+
+if (!require(e1071)) install.packages("e1071", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
+library(e1071)
+
+if (!require(dplyr)) install.packages("dplyr", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
+library(dplyr)
+
+# P(y/x) = P(y) * P(x/y) / P(x)
+
+data <- read.csv("./nbtrain.csv")
+train <- head(data, 9010)
+test <- tail(data, 1000)
+
+# a
+model <- naiveBayes(income ~ age + sex + educ, data = train)
+model
+
+
+pred_income <- predict(model, newdata = test)
+
+tt <- table(Predicted = pred_income, Actual = test$income)
+print(tt)
+
+misclass <- function(tt) {
+ # total_wrong / total_records
+ overall_misclass <- (sum(tt) - sum(diag(tt))) / sum(tt)
+ cat("Overall misclassification rate:", round(overall_misclass, 4), "\n")
+
+ classes <- rownames(tt)
+ misclass_per_class <- numeric(length(classes))
+ names(misclass_per_class) <- classes
+
+ for (cls in classes) {
+ correct <- tt[cls, cls]
+ total_in_class <- sum(tt[, cls])
+ misclass_per_class[cls] <- (total_in_class - correct) / total_in_class
+ }
+
+ cat("Misclassification rate per income class:\n")
+ print(round(misclass_per_class, 4))
+}
+
+misclass(tt)
+
+
+model_sex <- naiveBayes(sex ~ age + educ + income, data = train)
+model_sex
+pred_sex <- predict(model_sex, newdata = test)
+tt <- table(Predicted = pred_sex, Actual = test$sex)
+print(tt)
+
+misclass(tt)
+
+
+test_random <- function() {
+ data_female <- subset(train, sex == "F")
+ data_male <- subset(train, sex == "M")
+
+ data_female <- sample_n(data_female, 3500)
+ data_male <- sample_n(data_male, 3500)
+
+ random_sample <- rbind(data_male, data_female)
+ model_random <- naiveBayes(sex ~ age + income + educ, data = random_sample)
+ print(model_random)
+ pred_random <- predict(model_random, test, type = "class")
+ tt <- table(Predicted = pred_random, Actual = test$sex)
+ print(tt)
+
+ misclass(tt)
+}
+
+
+test_random()
+# test_random()
+# test_random()
+# test_random()