aboutsummaryrefslogtreecommitdiff
path: root/RModel/main.r
diff options
context:
space:
mode:
Diffstat (limited to 'RModel/main.r')
-rwxr-xr-xRModel/main.r107
1 files changed, 107 insertions, 0 deletions
diff --git a/RModel/main.r b/RModel/main.r
new file mode 100755
index 0000000..4279252
--- /dev/null
+++ b/RModel/main.r
@@ -0,0 +1,107 @@
+#!/usr/bin/env Rscript
+
+my_kmeans <- function(data, k, max_iters = 1000) {
+ data <- as.matrix(na.omit(data))
+ n <- nrow(data)
+ p <- ncol(data)
+
+ centroids <- data[sample(n, k), , drop = FALSE]
+ rownames(centroids) <- 1:k
+
+ clusters <- integer(n)
+ prev_centroids <- matrix(0, nrow = k, ncol = p)
+ iter <- 0
+ converged <- FALSE
+
+ while (iter < max_iters && !converged) {
+ distances <- matrix(0, nrow = n, ncol = k)
+
+ # find distance of cluster centers to a each point
+ for (i in 1:k) {
+ # sweep by cols
+ diff <- sweep(data, 2, centroids[i, ], "-")
+ distances[, i] <- sqrt(rowSums(diff^2))
+ }
+
+ # assign to closest centroid by rows
+ clusters <- apply(distances, 1, which.min)
+
+ prev_centroids <- centroids
+
+ for (i in 1:k) {
+ if (sum(clusters == i) > 0) {
+ centroids[i, ] <- colMeans(data[clusters == i, , drop = FALSE])
+ }
+ }
+
+ converged <- sqrt(sum((centroids - prev_centroids)^2)) < 1e-4
+
+ iter <- iter + 1
+ }
+
+ wss <- 0
+ for (i in 1:k) {
+ cluster_points <- data[clusters == i, , drop = FALSE]
+ if (nrow(cluster_points) > 0) {
+ cluster_center <- centroids[i, ]
+ differences <- sweep(cluster_points, 2, cluster_center, "-")
+ wss <- wss + sum(differences^2)
+ }
+ }
+
+ return(list(
+ clusters = clusters,
+ centers = centroids,
+ withinss = wss,
+ iter = iter,
+ converged = converged
+ ))
+}
+
+
+load("income_elec_state.rdata")
+
+head(income_elec_state)
+
+income_elec_state <- log10(income_elec_state)
+income_elec_state <- income_elec_state[income_elec_state$elec > 2.83, ]
+
+k <- 3
+km <- my_kmeans(income_elec_state, k)
+
+km_centers <- data.frame(km$centers)
+head(km_centers)
+
+if (!require(ggplot2)) install.packages("ggplot2", repos = "https://cran.r-project.org/")
+library(ggplot2)
+
+ggplot(
+ data = income_elec_state,
+ mapping = aes(x = income, y = elec, color = factor(km$cluster)),
+) +
+ labs(x = "income", y = "electricity usage") +
+ geom_point(shape = 1) +
+ geom_point(
+ data = km_centers,
+ mapping = aes(
+ x = income,
+ y = elec,
+ color = factor(rownames(km_centers)),
+ label = NULL
+ ),
+ shape = 13,
+ size = 4
+ )
+
+
+wss <- NULL
+range <- 1:10
+for (i in range) {
+ res <- my_kmeans(income_elec_state, i)
+ wss <- c(wss, res$withinss)
+}
+wss_df <- data.frame(wss)
+ggplot(wss_df, aes(x = range, y = wss)) +
+ geom_path() +
+ geom_point() +
+ scale_x_continuous(breaks = range) \ No newline at end of file