1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
#!/usr/bin/env Rscript
if (!require(e1071)) install.packages("e1071", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
library(e1071)
if (!require(dplyr)) install.packages("dplyr", repos = "https://cran.r-project.org/", Ncpus = 16) # nolint
library(dplyr)
# Bayes' theorem: P(y | x) = P(y) * P(x | y) / P(x)
# Load the data set. Text columns are read as factors so the training and
# test subsets share a single set of levels. e1071's predict() converts
# newdata with data.matrix(), i.e. categorical values are looked up by their
# integer level code; levels inferred from the test rows alone could be
# misaligned with the training tables if a category is absent from the test
# set.
data <- read.csv("./nbtrain.csv", stringsAsFactors = TRUE)

# Hold-out split: the first `n_train` rows fit the models, the last `n_test`
# rows are kept for evaluation.
n_train <- 9010
n_test <- 1000
if (nrow(data) < n_train + n_test) {
  warning(
    "Data has only ", nrow(data), " rows; the training and test sets overlap.",
    call. = FALSE
  )
}
train <- head(data, n_train)
test <- tail(data, n_test)

# a
# Naive Bayes classifier for income from age, sex and education.
model <- naiveBayes(income ~ age + sex + educ, data = train)
model
pred_income <- predict(model, newdata = test)
# Confusion matrix: predicted classes in rows, actual classes in columns.
tt <- table(Predicted = pred_income, Actual = test$income)
print(tt)
# Report misclassification rates for a confusion matrix.
#
# tt: two-dimensional table (or matrix) of counts with predicted classes in
#     rows and actual classes in columns, e.g. the result of
#     table(Predicted = ..., Actual = ...).
#
# Prints the overall misclassification rate and, for every class, the share
# of records actually in that class that were predicted as something else.
# Rows and columns are matched by label, so the table does not have to be
# square or identically ordered. A class with no actual records has an
# undefined rate and is reported as NaN.
#
# Returns the rounded per-class rates (named by class) invisibly.
misclass <- function(tt) {
  counts <- unclass(tt)
  if (length(dim(counts)) != 2L) {
    stop("`tt` must be a two-dimensional table of counts.", call. = FALSE)
  }

  # Fall back to positional labels when the input carries no dimnames.
  row_labels <- rownames(counts)
  col_labels <- colnames(counts)
  if (is.null(row_labels)) row_labels <- as.character(seq_len(nrow(counts)))
  if (is.null(col_labels)) col_labels <- as.character(seq_len(ncol(counts)))

  # Embed the counts in a square matrix indexed by the union of labels so
  # that the diagonal always holds the correctly classified records.
  classes <- union(row_labels, col_labels)
  aligned <- matrix(
    0,
    nrow = length(classes), ncol = length(classes),
    dimnames = list(classes, classes)
  )
  aligned[row_labels, col_labels] <- counts

  correct <- diag(aligned)

  # total_wrong / total_records
  total <- sum(aligned)
  overall_misclass <- (total - sum(correct)) / total
  cat("Overall misclassification rate:", round(overall_misclass, 4), "\n")

  # Column sums are the number of records actually belonging to each class.
  actual_totals <- colSums(aligned)
  misclass_per_class <- (actual_totals - correct) / actual_totals
  names(misclass_per_class) <- classes

  cat("Misclassification rate per class:\n")
  rates <- round(misclass_per_class, 4)
  print(rates)
  invisible(rates)
}
# Error rates of the income classifier on the held-out rows.
misclass(tt)

# Same exercise with sex as the target, predicted from age, education and
# income.
model_sex <- naiveBayes(sex ~ age + educ + income, data = train)
model_sex
pred_sex <- predict(model_sex, test)
# Confusion matrix: predicted classes in rows, actual classes in columns.
tt <- table(Predicted = pred_sex, Actual = test[["sex"]])
print(tt)
misclass(tt)
# Fit and evaluate a naive Bayes classifier for sex on a class-balanced
# random subsample of the training data.
#
# n_per_class: number of rows drawn, without replacement, for each sex.
# train_data:  data frame to sample from; defaults to the global `train`.
# test_data:   data frame to evaluate on; defaults to the global `test`.
#
# Prints the fitted model, the confusion matrix and the misclassification
# rates. Returns the value of misclass() for the confusion matrix.
# No seed is set here, so every call draws a different sample.
test_random <- function(n_per_class = 3500, train_data = train,
                        test_data = test) {
  data_female <- subset(train_data, sex == "F")
  data_male <- subset(train_data, sex == "M")

  # Fail early with a readable message instead of from inside the sampler.
  smallest_group <- min(nrow(data_female), nrow(data_male))
  if (n_per_class > smallest_group) {
    stop(
      "Cannot draw ", n_per_class, " rows per sex: the smaller group has only ",
      smallest_group, " rows.",
      call. = FALSE
    )
  }

  # Females are drawn first, then males; keep this order so results under a
  # caller-set seed stay the same as before.
  data_female <- sample_n(data_female, n_per_class)
  data_male <- sample_n(data_male, n_per_class)
  random_sample <- rbind(data_male, data_female)

  model_random <- naiveBayes(sex ~ age + income + educ, data = random_sample)
  print(model_random)

  pred_random <- predict(model_random, test_data, type = "class")
  # Confusion matrix: predicted classes in rows, actual classes in columns.
  tt <- table(Predicted = pred_random, Actual = test_data$sex)
  print(tt)
  misclass(tt)
}
# Run the balanced-sample experiment once. test_random() sets no seed, so
# each call draws a new random sample; uncomment the calls below to repeat it.
test_random()
# test_random()
# test_random()
# test_random()
|