aboutsummaryrefslogtreecommitdiff
path: root/R_NB/main.r
blob: 75979ad7753034bf2a0e0c546dd361afd040c019 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env Rscript

if (!require(e1071)) install.packages("e1071", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
library(e1071)

if (!require(dplyr)) install.packages("dplyr", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
library(dplyr)

# Bayes' theorem: P(y | x) = P(y) * P(x | y) / P(x)

# Load the labelled records. Strings are read as factors so that every
# categorical column carries its full level set on any R version: confusion
# tables then keep zero-count classes, and the training and test slices share
# identical levels.
data <- read.csv("./nbtrain.csv", stringsAsFactors = TRUE)

# Positional hold-out split: the leading rows train the model, the trailing
# rows evaluate it.
n_train <- 9010
n_test <- 1000
if (nrow(data) < n_train + n_test) {
  # head()/tail() do not fail on short input; they silently overlap or
  # truncate, which would leak training rows into the evaluation set.
  warning(
    "nbtrain.csv has only ", nrow(data), " rows; the training and test ",
    "slices overlap or are truncated.",
    call. = FALSE
  )
}
train <- head(data, n_train)
test <- tail(data, n_test)

# a: predict income from age, sex and education.
model <- naiveBayes(income ~ age + sex + educ, data = train)
# Explicit print() so the model is shown under source() as well as Rscript.
print(model)

pred_income <- predict(model, newdata = test)

# Confusion table: rows are predictions, columns are the recorded labels.
tt <- table(Predicted = pred_income, Actual = test$income)
print(tt)

# Print the overall and per-class misclassification rates of a confusion
# table.
#
# Args:
#   tt: two-way contingency table with predictions in rows and actual labels
#       in columns, e.g. table(Predicted = ..., Actual = ...). Rows and
#       columns need not contain the same classes or be in the same order.
#
# Returns:
#   Invisibly, the per-class rates rounded to 4 decimals, named by actual
#   class. A class with no actual records has rate NA.
misclass <- function(tt) {
  stopifnot(length(dim(tt)) == 2L)

  actual_classes <- colnames(tt)
  predicted_classes <- rownames(tt)

  # Correct predictions per actual class. Cells are matched by label rather
  # than by position, so a table whose rows and columns differ (for example
  # when a class was never predicted) is still scored correctly.
  correct <- vapply(
    actual_classes,
    function(cls) {
      if (cls %in% predicted_classes) as.numeric(tt[cls, cls]) else 0
    },
    numeric(1)
  )

  # total_wrong / total_records
  total <- sum(tt)
  overall_misclass <- if (total > 0) {
    (total - sum(correct)) / total
  } else {
    NA_real_
  }
  cat("Overall misclassification rate:", round(overall_misclass, 4), "\n")

  # Per class: share of the records that truly belong to the class but were
  # predicted as something else. Empty classes give NA instead of 0/0 = NaN.
  class_totals <- colSums(tt)
  misclass_per_class <- ifelse(
    class_totals > 0,
    (class_totals - correct) / class_totals,
    NA_real_
  )
  names(misclass_per_class) <- actual_classes

  # The header is class-agnostic because the function is used for both the
  # income and the sex confusion tables.
  cat("Misclassification rate per actual class:\n")
  print(round(misclass_per_class, 4))
}

# Error rates for the income confusion table computed above.
misclass(tt)


# Second model: classify sex from age, education and income.
model_sex <- naiveBayes(sex ~ age + educ + income, data = train)
model_sex

# Score the hold-out rows and cross-tabulate predictions (rows) against the
# recorded sex (columns). `tt` is reused here, replacing the income table.
pred_sex <- predict(model_sex, newdata = test)
tt <- table(Predicted = pred_sex, Actual = test[["sex"]])
print(tt)

misclass(tt)


# Fit and evaluate a sex classifier on a class-balanced random subsample.
#
# Draws the same number of "F" and "M" rows from the global `train` data
# frame, fits a naive Bayes model for sex on that sample, then prints the
# model, its confusion table against the global `test` data frame, and the
# misclassification rates. No RNG seed is set here, so each call draws a
# different sample unless the caller sets one.
#
# Args:
#   n_per_class: number of rows to draw for each sex. Defaults to 3500, the
#                size this script originally hard-coded. If either class has
#                fewer rows, the size is capped at the smaller class (with a
#                warning) so the sample stays balanced.
#
# Returns:
#   The value of misclass() for the resulting confusion table.
test_random <- function(n_per_class = 3500) {
  data_female <- subset(train, sex == "F")
  data_male <- subset(train, sex == "M")

  available <- min(nrow(data_female), nrow(data_male))
  if (available < 1) {
    stop(
      "`train` needs at least one row for each of sex \"F\" and \"M\".",
      call. = FALSE
    )
  }
  if (n_per_class > available) {
    warning(
      "Requested ", n_per_class, " rows per class but only ", available,
      " are available; sampling ", available, " per class instead.",
      call. = FALSE
    )
    n_per_class <- available
  }

  # Sampling order (female, then male) is kept so a caller-supplied seed
  # reproduces the same draws as before.
  data_female <- sample_n(data_female, n_per_class)
  data_male <- sample_n(data_male, n_per_class)

  random_sample <- rbind(data_male, data_female)
  model_random <- naiveBayes(sex ~ age + income + educ, data = random_sample)
  print(model_random)

  # Confusion table: rows are predictions, columns are the recorded sex.
  pred_random <- predict(model_random, test, type = "class")
  tt <- table(Predicted = pred_random, Actual = test$sex)
  print(tt)

  misclass(tt)
}


# Run the balanced-sample experiment once. test_random() sets no RNG seed, so
# each run draws a different subsample. The commented-out calls below are
# presumably kept for repeating the experiment to see how much the results
# vary between draws.
test_random()
# test_random()
# test_random()
# test_random()