# =====================================================================
# Toy gene-expression clustering, end to end (no file I/O):
#   PART 1 - simulate 100 genes x 3 time points from 3 cluster-specific
#            trivariate normal distributions (sizes 50 / 30 / 20).
#   PART 2 - fit a finite Gaussian mixture model (mclust) and check
#            recovery against the true cluster labels.
# The simulated data are passed directly in memory (no CSV write/read).
# =====================================================================

library(MASS)    # mvrnorm()
library(mclust)  # Mclust GMM + adjustedRandIndex()
set.seed(123)   # fixed seed so every run draws the same simulated data

# =====================================================================
# PART 1.  Simulate the toy dataset
# =====================================================================
# Column labels for the three expression measurements taken over time.
time_points <- paste0("t", 1:3)

# --- Every cluster gets its OWN trivariate normal distribution ---
# Mean vectors: the expected expression level at (t1, t2, t3).
mu1 <- c(2, 5, 8)   # cluster 1: expression increases over time
mu2 <- c(8, 5, 2)   # cluster 2: expression decreases over time
mu3 <- c(3, 9, 3)   # cluster 3: spike at t2, low at t1 and t3

# Covariance matrices, written one row at a time (each is symmetric).
# The clusters differ both in overall scale and in how time points co-vary.
Sigma1 <- rbind(c(1.0, 0.6, 0.3),
                c(0.6, 1.0, 0.6),
                c(0.3, 0.6, 1.0))      # mild AR-like: weaker at lag 2

Sigma2 <- rbind(c( 1.5, -0.4,  0.1),
                c(-0.4,  1.5, -0.4),
                c( 0.1, -0.4,  1.5))   # negative covariance at lag 1

Sigma3 <- rbind(c(0.8, 0.2, 0.5),
                c(0.2, 2.0, 0.2),
                c(0.5, 0.2, 0.8))      # largest variance at t2

# Number of genes simulated in cluster 1, 2 and 3 respectively.
n1 <- 50
n2 <- 30
n3 <- 20

# Draw each cluster's genes from its own trivariate normal and stack the
# draws in cluster order (cluster 1 rows first, then cluster 2, then 3).
expr <- rbind(mvrnorm(n1, mu = mu1, Sigma = Sigma1),
              mvrnorm(n2, mu = mu2, Sigma = Sigma2),
              mvrnorm(n3, mu = mu3, Sigma = Sigma3))
colnames(expr) <- time_points

# Assemble the dataset: gene ID, true cluster label and the expression
# columns. Gene IDs are derived from the number of simulated rows instead
# of a literal 1:100, so changing n1/n2/n3 cannot leave the ID column out
# of step with the label and expression columns.
dat <- data.frame(
  gene_id = sprintf("gene%03d", seq_len(nrow(expr))),
  cluster = factor(rep(c(1, 2, 3), times = c(n1, n2, n3))),
  expr,
  row.names = NULL
)

# Print a short summary of the simulated dataset.
cat("===== Simulated dataset =====\n")
cat("Dimensions:", nrow(dat), "genes x", length(time_points), "time points\n")

cat("Cluster sizes:\n")
print(table(dat[["cluster"]]))

# Empirical mean of every time point within each true cluster, rounded to
# two decimals for display only.
cat("\nObserved per-cluster mean profiles (should approximate mu1/mu2/mu3):\n")
cl_means <- aggregate(cbind(t1, t2, t3) ~ cluster, data = dat, FUN = mean)
cl_means[time_points] <- round(cl_means[time_points], 2)
print(cl_means)

# =====================================================================
# PART 2.  Fit a Gaussian mixture model and check recovery
# =====================================================================
# Feature matrix for the mixture model: one row per gene, one column per
# time point.
X     <- as.matrix(dat[time_points])
# True cluster labels (1/2/3); used only to evaluate the fit afterwards.
truth <- dat[["cluster"]]

# Let Mclust choose both the number of components G and the covariance
# parameterization by BIC. G ranges over 1..6 so that the number of
# clusters is decided by model selection rather than fixed by hand.
fit <- Mclust(X, G = 1:6)

# Report the model that Mclust selected.
cat("\n===== Model selected by BIC =====\n")
cat("Number of components (G):", fit[["G"]], "\n")
cat("Covariance model        :", fit[["modelName"]], "\n")
cat("BIC                     :", round(fit[["bic"]], digits = 2), "\n\n")

# Mixing proportions of the fitted components.
cat("Estimated mixing proportions:\n")
print(round(fit[["parameters"]][["pro"]], digits = 3))

# Fitted component means, printed to two decimals.
cat("\nEstimated component means (rows = t1,t2,t3 ; cols = components):\n")
print(round(fit[["parameters"]][["mean"]], digits = 2))

# --- Evaluate the recovered partition against the TRUE labels ---
# Hard cluster assignment of each gene, as stored on the fitted model.
pred <- fit[["classification"]]

cat("\n===== Cluster recovery =====\n")
cat("Cross-tabulation (rows = true cluster, cols = GMM cluster):\n")
print(table(truth, pred, dnn = c("true", "predicted")))

# Adjusted Rand Index between the two labelings: 1 means perfect
# agreement, 0 corresponds to random labeling.
ari <- adjustedRandIndex(truth, pred)
cat(sprintf("\nAdjusted Rand Index (true vs GMM): %.3f\n", ari))

# --- Per-gene results table, held in memory only ---
# One row per gene: its ID, true cluster, GMM cluster, and the fit's
# classification uncertainty (1 - max posterior prob) to 4 decimals.
out <- data.frame(
  gene_id      = dat[["gene_id"]],
  true_cluster = truth,
  gmm_cluster  = pred,
  uncertainty  = round(fit[["uncertainty"]], digits = 4)
)

cat("\nFirst rows of per-gene assignments:\n")
print(head(out))

# Optional diagnostic: BIC across G and the covariance models.
# NOTE(review): when run non-interactively (e.g. via Rscript) the default
# graphics device usually writes Rplots.pdf to the working directory --
# confirm that is acceptable given the script's "no file I/O" intent.
plot(x = fit, what = "BIC")
