#!/usr/bin/env Rscript

# K-means clustering of income vs. electricity usage (both log10-scaled),
# followed by an elbow plot of total within-cluster sum of squares for
# k = 1..10.

#' Cluster the rows of a numeric matrix with Lloyd's k-means algorithm
#'
#' Rows containing `NA` are dropped before clustering. The initial centroids
#' are `k` rows drawn at random without replacement, so results depend on the
#' RNG state (call `set.seed()` beforehand for reproducible output).
#'
#' @param data Numeric matrix or data frame; one observation per row.
#' @param k Number of clusters, between 1 and the number of complete rows.
#' @param max_iters Maximum number of assign/update iterations.
#' @param tol Convergence threshold on the Euclidean norm of the total
#'   centroid movement during one iteration.
#' @return A list with elements `clusters` (integer cluster index for each
#'   complete row of `data`), `centers` (k x p centroid matrix with rownames
#'   `1..k`), `withinss` (total within-cluster sum of squares, summed over
#'   all clusters), `iter` (number of iterations performed) and `converged`
#'   (TRUE if the last update moved the centroids by less than `tol`).
my_kmeans <- function(data, k, max_iters = 1000, tol = 1e-4) {
  data <- as.matrix(na.omit(data))
  n <- nrow(data)
  p <- ncol(data)

  # Inf values would turn centroids into Inf/NaN and break the loop condition;
  # non-numeric (e.g. character) input also fails this check.
  if (!all(is.finite(data))) {
    stop("`data` must contain only finite numeric values", call. = FALSE)
  }
  if (length(k) != 1L || is.na(k) || k < 1 || k > n) {
    stop(
      "`k` must be a single number between 1 and the number of complete ",
      "rows (", n, ")",
      call. = FALSE
    )
  }

  centroids <- data[sample(n, k), , drop = FALSE]
  rownames(centroids) <- seq_len(k)
  clusters <- integer(n)
  iter <- 0L
  converged <- FALSE

  while (iter < max_iters && !converged) {
    # Squared Euclidean distance from every point (rows) to every centroid
    # (columns); sqrt is omitted because it does not change the argmin.
    distances <- matrix(0, nrow = n, ncol = k)
    for (i in seq_len(k)) {
      diff <- sweep(data, 2, centroids[i, ], "-")
      distances[, i] <- rowSums(diff^2)
    }

    # Assign each point to its nearest centroid (first one on ties, matching
    # which.min).
    clusters <- max.col(-distances, ties.method = "first")

    # Move each centroid to the mean of its members; an empty cluster keeps
    # its previous centroid.
    prev_centroids <- centroids
    for (i in seq_len(k)) {
      members <- clusters == i
      if (any(members)) {
        centroids[i, ] <- colMeans(data[members, , drop = FALSE])
      }
    }

    converged <- sqrt(sum((centroids - prev_centroids)^2)) < tol
    iter <- iter + 1L
  }

  # Total within-cluster sum of squares; empty clusters contribute 0.
  wss <- sum(vapply(
    seq_len(k),
    function(i) {
      members <- clusters == i
      if (!any(members)) {
        return(0)
      }
      sum(sweep(data[members, , drop = FALSE], 2, centroids[i, ], "-")^2)
    },
    numeric(1)
  ))

  list(
    clusters = clusters,
    centers = centroids,
    withinss = wss,
    iter = iter,
    converged = converged
  )
}

# Expects the .rdata file to define a numeric data frame `income_elec_state`
# with (at least) `income` and `elec` columns.
load("income_elec_state.rdata")
head(income_elec_state)

income_elec_state <- log10(income_elec_state)

# Keep complete, finite rows whose log10(elec) exceeds 2.83 (elec > ~676).
# Dropping NA/NaN/Inf rows here, rather than leaving it to my_kmeans() (which
# drops incomplete rows internally), keeps the cluster labels aligned row for
# row with the data frame plotted below. A bare `elec > 2.83` index would
# also turn NA comparisons into all-NA rows.
keep <- is.finite(rowSums(income_elec_state)) & income_elec_state$elec > 2.83
income_elec_state <- income_elec_state[keep, ]

k <- 3
km <- my_kmeans(income_elec_state, k)
km_centers <- data.frame(km$centers)
head(km_centers)

if (!requireNamespace("ggplot2", quietly = TRUE)) {
  install.packages("ggplot2", repos = "https://cran.r-project.org/")
}
library(ggplot2)

# Scatter of the points coloured by cluster, with the centroids overlaid.
ggplot(
  data = income_elec_state,
  mapping = aes(x = income, y = elec, color = factor(km$clusters))
) +
  labs(x = "income", y = "electricity usage", color = "cluster") +
  geom_point(shape = 1) +
  geom_point(
    data = km_centers,
    mapping = aes(x = income, y = elec, color = factor(rownames(km_centers))),
    shape = 13,
    size = 4
  )

# Elbow plot: total within-cluster sum of squares for k = 1..10.
k_range <- 1:10
wss <- vapply(
  k_range,
  function(n_clusters) my_kmeans(income_elec_state, n_clusters)$withinss,
  numeric(1)
)
wss_df <- data.frame(k = k_range, wss = wss)
ggplot(wss_df, aes(x = k, y = wss)) +
  geom_path() +
  geom_point() +
  scale_x_continuous(breaks = k_range)