aboutsummaryrefslogtreecommitdiff
path: root/RModel/main.r
blob: 42792520f7eee237ea55aa67f9eeef30d1a1b08e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env Rscript

#' K-means clustering (Lloyd's algorithm) with random initial centroids
#'
#' Rows containing `NA`/`NaN` are removed before clustering, so the returned
#' `clusters` vector refers to the remaining complete rows only.
#'
#' @param data Numeric matrix, data frame or vector; one observation per row.
#' @param k Number of clusters. Must lie between 1 and the number of complete
#'   rows; fractional values are truncated towards zero.
#' @param max_iters Maximum number of assignment/update iterations (>= 1).
#' @return A list with elements
#'   `clusters` (cluster index for every complete row),
#'   `centers` (k x p matrix of centroids with row names `1..k`),
#'   `withinss` (total squared distance of the points to their centroid),
#'   `iter` (number of iterations performed) and
#'   `converged` (`TRUE` when the centroids moved less than 1e-4 during the
#'   last iteration).
my_kmeans <- function(data, k, max_iters = 1000) {
  data <- as.matrix(na.omit(data))

  # Fail early with a clear message instead of an obscure error (or silently
  # meaningless output) further down.
  if (!is.numeric(data)) {
    stop("`data` must contain only numeric values", call. = FALSE)
  }
  if (!all(is.finite(data))) {
    stop("`data` must contain only finite values", call. = FALSE)
  }
  n <- nrow(data)

  if (!is.numeric(k) || length(k) != 1L || is.na(k)) {
    stop("`k` must be a single number", call. = FALSE)
  }
  k <- as.integer(k)
  if (is.na(k) || k < 1L || k > n) {
    stop(
      "`k` must be between 1 and the number of complete rows (", n, ")",
      call. = FALSE
    )
  }
  if (!is.numeric(max_iters) || length(max_iters) != 1L ||
        is.na(max_iters) || max_iters < 1) {
    stop("`max_iters` must be a single number >= 1", call. = FALSE)
  }

  # Initial centroids: k distinct observations picked at random.
  centroids <- data[sample.int(n, k), , drop = FALSE]
  rownames(centroids) <- seq_len(k)

  clusters <- integer(n)
  # Allocated once; every column is overwritten in each iteration.
  distances <- matrix(0, nrow = n, ncol = k)
  iter <- 0
  converged <- FALSE

  while (iter < max_iters && !converged) {
    # Euclidean distance from every point to every centroid.
    for (i in seq_len(k)) {
      offsets <- sweep(data, 2, centroids[i, ], "-")
      distances[, i] <- sqrt(rowSums(offsets^2))
    }

    # Assign each point to its nearest centroid (first one wins on ties).
    clusters <- apply(distances, 1, which.min)

    prev_centroids <- centroids

    # Move each centroid to the mean of its members; a centroid that lost
    # all of its members keeps its previous position.
    for (i in seq_len(k)) {
      members <- clusters == i
      if (any(members)) {
        centroids[i, ] <- colMeans(data[members, , drop = FALSE])
      }
    }

    # Converged once the total centroid movement is negligible.
    converged <- sqrt(sum((centroids - prev_centroids)^2)) < 1e-4

    iter <- iter + 1
  }

  # Total within-cluster sum of squares: each point against its own centroid.
  wss <- sum((data - centroids[clusters, , drop = FALSE])^2)

  list(
    clusters = clusters,
    centers = centroids,
    withinss = wss,
    iter = iter,
    converged = converged
  )
}


# Load the data set; the file is expected to define `income_elec_state`
# (the columns `income` and `elec` are used below).
load("income_elec_state.rdata")

head(income_elec_state)

# Work on a log10 scale.
income_elec_state <- log10(income_elec_state)

# Keep only rows that are fully finite and whose log10 electricity usage
# exceeds 2.83 (10^2.83 is roughly 676 on the original scale).
# Filtering non-finite rows here matters for two reasons:
#  * a logical row index containing NA would insert all-NA rows, and
#  * my_kmeans() drops incomplete rows, so without this step the cluster
#    labels could end up shorter than the data frame that is plotted.
finite_rows <- rowSums(!is.finite(as.matrix(income_elec_state))) == 0
keep_rows <- finite_rows & income_elec_state$elec > 2.83
income_elec_state <- income_elec_state[keep_rows, ]

# Cluster into k groups; initial centroids are random, so results can vary
# from run to run.
k <- 3
km <- my_kmeans(income_elec_state, k)

# Centroids as a data frame so they can be used as a ggplot layer.
km_centers <- data.frame(km$centers)
head(km_centers)

# ggplot2 is needed for both plots. requireNamespace() only checks whether
# the package is installed; library() then attaches it and fails loudly if
# the installation did not succeed.
if (!requireNamespace("ggplot2", quietly = TRUE)) {
  install.packages("ggplot2", repos = "https://cran.r-project.org/")
}
library(ggplot2)

# Scatter plot of the observations coloured by cluster, with the centroids
# drawn on top as larger crossed circles.
ggplot(
  data = income_elec_state,
  # Use the full element name: `km$cluster` only worked through partial
  # matching of the list element `clusters`.
  mapping = aes(x = income, y = elec, color = factor(km$clusters))
) +
  labs(x = "income", y = "electricity usage", color = "cluster") +
  geom_point(shape = 1) +
  geom_point(
    data = km_centers,
    mapping = aes(
      x = income,
      y = elec,
      # Row names of the centroid table are the cluster indices.
      color = factor(rownames(km_centers))
    ),
    shape = 13,
    size = 4
  )


# Elbow plot: total within-cluster sum of squares for k = 1..10.
# `k_values` replaces a variable that was called `range`, which masked the
# name of base::range() and was picked up by aes() from the global
# environment rather than from the plotted data frame.
k_values <- 1:10

# One k-means run per candidate k; vapply() returns a preallocated numeric
# vector instead of growing one with c() inside a loop.
wss <- vapply(
  k_values,
  function(n_clusters) my_kmeans(income_elec_state, n_clusters)$withinss,
  numeric(1)
)

wss_df <- data.frame(k = k_values, wss = wss)
ggplot(wss_df, aes(x = k, y = wss)) +
  geom_path() +
  geom_point() +
  scale_x_continuous(breaks = k_values) +
  labs(x = "number of clusters (k)", y = "within-cluster sum of squares")