# Work from the project root. `main_path` is not defined in this file; it is
# expected to exist already (presumably set by a master script -- confirm).
setwd(main_path)

### Import relevant libraries
# NOTE: data.table is attached after tidyverse, so data.table's first()/last()
# mask the dplyr versions (first() is used in prep()).
library(tidyverse)
library(haven)
library(stargazer)
library(xtable)
library(data.table)
library(foreach)
##

# Record the names that already exist so the cleanup at the end of the script
# removes only the objects this script creates.
to_keep <- ls()
# Start the wall-clock timer that is printed at the end of the script.
start_time <- Sys.time()

# Set seed for replication: both the fold assignment (sample()) and the
# k-means random starts (kmeans(nstart = ...)) draw from this RNG stream.
set.seed(123)

## Define all functions needed
# Returns an extractor function that picks the m-th element of its input via
# nth(); used with across() to spread a person's first observations into
# separate columns.
elem <- function(m){
  # Force `m` now so each returned closure keeps the value it was created
  # with. Otherwise a lazy promise could later resolve to a variable that has
  # since changed (e.g. when extractors are built in a loop).
  force(m)
  function(x) nth(x, n = m)
}

# Turns a list of id tables (one per fold) into a held-out / remainder pair.
# Returns list(<n-th table>, <all other tables full-joined on rahhidpn>).
reducesplit <- function(n, dta){
  held_out <- dta[[n]]
  remaining_folds <- dta[-n]
  combined <- reduce(remaining_folds, full_join, by = "rahhidpn")
  list(held_out, combined)
}
# Converts a marital-status code to a binary indicator.
# Returns 1 for codes below 4, 0 for any other non-missing code and NA for a
# missing code. Works element-wise, so it also accepts a vector of codes; a
# single code gives the same result as before.
# NOTE(review): presumably codes 1-3 are the married/partnered categories of
# the status variable -- confirm against the codebook.
trans <- function(num){
  ifelse(is.na(num), NA, as.numeric(num < 4))
}

# Finds the row of `mat` (e.g. a matrix of cluster centres) nearest to the
# point `vec` in squared Euclidean distance; returns its row index.
# `vec` may be a numeric vector or a one-row data frame / matrix.
dmin <- function(vec, mat) {
  point <- as.vector(t(as.matrix(vec)))
  # Columns of t(mat) are the candidate rows; subtracting `point` recycles it
  # down each column.
  sq_dist <- colSums((t(mat) - point)^2)
  which.min(sq_dist)
}

# Returns the extra baseline-control terms for the regression formula:
# "+frailty_bln52+shltn52" when `bool` is TRUE, "" when FALSE.
# `age` is unused; it is kept so existing calls ctrls(controls, age) still work.
# A missing or non-scalar flag is rejected instead of leaking an "NA" term into
# the formula string.
ctrls <- function(bool, age) {
  stopifnot(is.logical(bool), length(bool) == 1L, !is.na(bool))
  if (bool) "+frailty_bln52+shltn52" else ""
}
# Returns the cluster-dummy term for the regression formula, e.g.
# "+k4_cluster_frailty_bl_5262" for (age = 52, var = "frailty_bl", k = 4).
# var == "C" means "no cluster term" and yields "".
# The column name matches the one cluster() creates.
cname <- function(age, var, k) {
  stopifnot(is.character(var), length(var) == 1L, !is.na(var))
  if (var == "C") {
    return("")
  }
  paste0("+k", k, "_cluster_", var, "_", age, age + 10)
}

# Collapses a person-level panel to one row per rahhidpn.
# Steps, per person:
#   - drop the first row if it is a death record (Dead == 1);
#   - drop rows with missing frailty_bl;
#   - keep only people with at least 5 remaining rows;
#   - store the first five values of shlt and frailty_bl as shltn1..shltn5 and
#     frailty_bln1..frailty_bln5;
#   - keep the first row.
# `x` is accepted but never used. (prepare() is not called in the visible part
# of this script.)
prepare <- function(dta, x){
  by_person <- group_by(dta, rahhidpn)
  by_person <- filter(by_person, !(Dead == 1 & row_number() == 1), !is.na(frailty_bl))
  by_person <- filter(by_person, n() >= 5)
  by_person <- mutate(by_person,
                      across(c("shlt", "frailty_bl"),
                             list(n1 = elem(1), n2 = elem(2), n3 = elem(3),
                                  n4 = elem(4), n5 = elem(5)),
                             .names = "{col}{fn}"))
  by_person <- filter(by_person, row_number() == 1)
  ungroup(by_person)
}

# Prepares the panel for the prediction regressions.
#   - Missing hcpl is filled from trans() of the person's first remstat.
#   - Rows where Dead switches from 0 to 1 take nrshom from renrshom.
#     NOTE(review): on a person's first row lag(Dead) is NA, so a first-row
#     Dead == 1 yields NA nrshom -- confirm this is intended.
#   - Rows with agey_e below Initial_Age are dropped.
#   - frailty_bln52 / shltn52 hold each person's first remaining value of
#     frailty_bl / shlt (presumably the age-52 baseline -- confirm).
#   - Adds hhearn (own + spouse earnings when the latter is present), surv,
#     an age cubic and factor versions of cohort, education and race.
# Returns only the columns used downstream (select() with starts_with("age")
# also picks up the Age1-Age3 columns and other age* columns).
prep <- function(df){
  shaped <- df %>%
    group_by(rahhidpn) %>%
    mutate(hcpl = ifelse(is.na(hcpl), trans(first(remstat)), hcpl),
           nrshom = ifelse(Dead == 1 & lag(Dead) == 0, renrshom, nrshom)) %>%
    ungroup() %>%
    filter(agey_e >= Initial_Age) %>%
    group_by(rahhidpn) %>%
    mutate(across(c("frailty_bl", "shlt"), list(n52 = elem(1)), .names = "{col}{fn}")) %>%
    ungroup() %>%
    mutate(hhearn = ifelse(is.na(iearnspouse), iearn, iearn + iearnspouse),
           surv = 1 - Dead,
           Age1 = agey_e,
           Age2 = agey_e^2,
           Age3 = agey_e^3,
           cohort.f = factor(cohort),
           raedegrm.f = factor(raedegrm),
           race = factor(raracem))
  select(shaped, "rahhidpn", "nrshom", "frailty_bl", "surv", "Age1", "Age2", "Age3",
         "cohort.f", "hcpl", "raedegrm.f", "race", "ragender", "hhearn",
         "frailty_bln52", "shltn52", starts_with("age"), "Initial_Age", "Final_Age")
}

# Clusters two datasets, the training and test sets, at the same time.
# k-means is fitted on the training ids (x[[2]]). Each test id (x[[1]]) is
# then assigned to the nearest training centroid, so test rows never move the
# centroids.
#
# var: prefix of the clustering feature columns in the global
#      `data_clustering` (columns containing "thru" are excluded).
# x:   list(test ids, training ids), as built by reducesplit().
# k:   number of clusters.
# age: only used to build the output column name.
# Returns list(training ids + cluster factor, test ids + cluster factor). The
# factor column is named k<k>_cluster_<var>_<age><age+10>.
cluster <- function(var, x, k, age){
  # Join explicitly on the id so the key cannot silently change if the id
  # tables ever carry extra columns shared with data_clustering.
  df_reg <- inner_join(data_clustering, x[[2]], by = "rahhidpn")
  df_fit <- inner_join(data_clustering, x[[1]], by = "rahhidpn")
  subdfr <- df_reg %>% select("rahhidpn")
  subdff <- df_fit %>% select("rahhidpn")
  filteredr <- df_reg %>% select(starts_with(var)) %>% select(!contains("thru"))
  filteredf <- df_fit %>% select(starts_with(var)) %>% select(!contains("thru"))
  cresultr <- kmeans(filteredr, k, iter.max = 50, nstart = 1000)
  centr <- cresultr$centers
  # Nearest-centroid assignment for the test rows.
  # - Converting once to a matrix avoids splitting the data frame into
  #   one-row data frames.
  # - seq_len() handles an empty fold, where 1:nrow() would give 1:0.
  # - vapply() stops with an error if a row cannot be assigned (e.g. all
  #   distances NA) instead of silently returning a shorter vector.
  fit_mat <- as.matrix(filteredf)
  fit_clusters <- vapply(seq_len(nrow(fit_mat)),
                         function(i) unname(dmin(fit_mat[i, , drop = FALSE], centr)),
                         integer(1))
  name <- paste0("k",k,"_cluster_", var, "_", age, age+10)
  subdfr <- subdfr %>% mutate(temp = cresultr$cluster) %>%
    mutate(temp = factor(temp)) %>% rename_at("temp", ~name)
  subdff <- subdff %>% mutate(temp = fit_clusters) %>%
    mutate(temp = factor(temp)) %>% rename_at("temp", ~name)
  list(subdfr, subdff)
}


# Runs one cross-validation fold.
# Clusters the fold's ids (age fixed at 52) and attaches the regression panel
# from the global `df_prep1_select`. It then returns the held-out mean absolute
# prediction error of the model with cluster dummies divided by that of the
# same model without them (var2 = "C"). Values below 1 mean the clusters help.
fold <- function(x, k, controls, var, regand) {
  clustered <- cluster(var, x, k, 52)
  with_panel <- lapply(clustered, function(ids) {
    inner_join(as.data.frame(ids), df_prep1_select, by = "rahhidpn")
  })
  train_panel <- with_panel[[1]]
  test_panel <- with_panel[[2]]
  err_clusters <- getpreds(train_panel, test_panel, 52, controls, var, var, k, regand)
  err_baseline <- getpreds(train_panel, test_panel, 52, controls, "C", var, k, regand)
  as.numeric(err_clusters) / as.numeric(err_baseline)
}

# For a fixed number of clusters k, runs fold() on each of the 10 splits in
# the global `splits` and returns the mean of the per-fold error ratios.
# It also prints k as a progress marker and triggers garbage collection.
runitall <- function(k, controls, var, regand) {
  fold_ratios <- vapply(1:10, function(n) {
    fold(splits[[n]], k, controls, var, regand)
  }, numeric(1))
  avg_ratio <- mean(fold_ratios)
  cat("\014")  # form feed: clears the console in RStudio
  print(k)
  gc()
  avg_ratio
}

# Fits a model for `regand` on `df` and returns its mean absolute prediction
# error on `df2`, ignoring NAs.
# - Both data sets are restricted to rows with Age1 <= Final_Age.
# - regand == "frailty_bl" is fitted by OLS; any other outcome by a
#   binomial glm, with predictions on the response scale.
# - Cohort dummies enter only when age == 52 and regand != "nrshom".
# - Baseline controls come from ctrls(controls, age).
# - Cluster dummies come from cname(age, var2, k); var2 == "C" means none.
# `var` is not used here.
getpreds <- function(df, df2, age, controls, var2, var, k, regand) {
  fit_rows <- filter(df, Age1 <= Final_Age)
  eval_rows <- filter(df2, Age1 <= Final_Age)
  cohort_term <- ifelse(age == 52 & regand != "nrshom", "cohort.f+", "")
  rhs <- paste0("Age1+Age2+Age3+", cohort_term,
                "hcpl+raedegrm.f+race+ragender",
                ctrls(controls, age), cname(age, var2, k))
  f <- as.formula(paste0(regand, "~", rhs))

  model <- if (regand == "frailty_bl") {
    lm(f, data = fit_rows)
  } else {
    glm(f, data = fit_rows, family = binomial)
  }

  predicted <- predict(model, newdata = eval_rows, type = "response")
  abs_err <- abs(eval_rows[[regand]] - predicted)
  mean(abs_err, na.rm = TRUE)
}



###################        Keep the sample that was used in clustering ####################
# Get the original database used for clustering

# Set to 1 to rerun the (slow) cross-validation and regenerate the CSV files
# read by the figure code below. Any other value skips it and the figures are
# built from the CSVs already on disk.
cross_val <- 1
if (cross_val == 1) {

  df <- read_dta("dtafiles/P52_5_Clusters.dta")
  data_clustering <- read_dta("dtafiles/data_clustering.dta")

  # Get ids
  Ind_to_select <- df %>% select(rahhidpn) %>% unique()

  # Split individuals from the main exercise into 10 folds.
  # sample() must return one fold label per individual. Drawing only ten
  # labels (sample(nrow(.), 10)) makes split() recycle them by row position,
  # so fold membership would be a fixed interleaving rather than random.
  grouping <- Ind_to_select %>% split(sample(rep_len(1:10, nrow(.))))
  splits <- lapply(1:10, function(x) reducesplit(x, grouping))

  # Prepare data set for predictions
  df_prep1 <- prep(df)
  # Keep only the individuals used for clustering
  df_prep1_select <- df_prep1 %>% filter(rahhidpn %in% Ind_to_select$rahhidpn)

  ## Clustering structure: [rahhidpn, frailty, F_frailty, F2_frailty, F3_frailty, F4_frailty]
  ## Rename so all feature columns share the "frailty_bl" prefix that cluster()
  ## selects with starts_with(var).
  data_clustering <- data_clustering %>%
    rename(frailty_bl = frailty,
           frailty_bl_F = F_frailty,
           frailty_bl_F2 = F2_frailty,
           frailty_bl_F3 = F3_frailty,
           frailty_bl_F4 = F4_frailty)

  # Computes the mean cross-validated error ratio for k = 2, ..., 15 clusters
  # (sequentially) and writes it, with the matching k in `numclust`, to a CSV
  # whose name encodes var, regand and the controls flag (1/0).
  finalgraph <- function(controls, var, regand){
    datamat <- lapply(2:15, function(x) runitall(x, controls, var, regand))
    # Keep this exact expression: the resulting column header
    # "unlist(datamat)" is read back as `unlist.datamat.` by the figure code.
    datamat <- as.data.frame(unlist(datamat)) %>% mutate(numclust = 2:15)
    name <- paste0("output/Part2_output/Part2_c_number_cluster/datamat_15_",var,"_",regand,ifelse(controls,1,0),".csv")
    dir.create(dirname(name), recursive = TRUE, showWarnings = FALSE)
    write.csv(datamat, file = name, row.names = FALSE)
  }

  setwd(main_path)

  ## Run Cross-validation exercise: It takes a few minutes
  finalgraph(TRUE, "frailty_bl", "frailty_bl")
  finalgraph(FALSE, "frailty_bl", "frailty_bl")
  finalgraph(TRUE, "frailty_bl", "surv")
  finalgraph(FALSE, "frailty_bl", "surv")
}

################################################################################
## Gather Results and Export Figures
################################################################################

# Plots one cross-validation curve and saves it as PDF, plus EPS when the
# global flag `eps` equals 1 (`eps` is not defined in this file).
# The curve is "predictive power" = 1 - mean error ratio (model with cluster
# dummies / model without), against the number of clusters.
# csv_file: CSV written by finalgraph(), relative to out_dir.
# fig_stub: output file name without extension.
plot_predictive_power <- function(csv_file, fig_stub,
                                  out_dir = "output/Part2_output/Part2_c_number_cluster/") {
  Da <- read.csv(paste0(out_dir, csv_file))
  # write.csv stored the ratios under the header "unlist(datamat)", which
  # read.csv turns into the syntactic name `unlist.datamat.`.
  Da$inv <- 1 - Da$unlist.datamat.

  p <- ggplot(aes(x = numclust, y = inv), data = Da) +
    geom_line(size = 1.15) +  # thicker line (`linewidth` in ggplot2 >= 3.4)
    ylab("Predictive Power of Health Types") +
    xlab("Number of Clusters") +
    scale_y_continuous(breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0), limits = c(0, 1)) +
    scale_x_continuous(breaks = unique(Da$numclust)) +
    theme_bw() +
    theme(axis.title = element_text(size = 16),   # larger axis labels
          axis.text = element_text(size = 14),    # larger tick labels
          panel.grid.major.x = element_line(color = "gray", size = 0.2),  # light gray x gridlines
          panel.grid.minor.x = element_blank(),   # no minor x gridlines
          panel.grid.major.y = element_line(color = "gray", size = 0.1),  # lighter gray y gridlines
          panel.grid.minor.y = element_blank()) + # no minor y gridlines
    # Gray shaded bands over x in [4, 6] and [6, 10].
    # NOTE(review): geom_rect inherits `Da`, so each band is drawn once per row
    # and the alphas compound. The alpha values were presumably tuned with that
    # in mind; use annotate("rect", ...) with a larger alpha if a single band
    # is intended.
    geom_rect(xmin = 4, xmax = 6, ymin = 0, ymax = 1, fill = "gray", alpha = 0.03) +
    geom_rect(xmin = 6, xmax = 10, ymin = 0, ymax = 1, fill = "gray", alpha = 0.01)

  ggsave(filename = paste0(fig_stub, ".pdf"), plot = p, path = out_dir,
         device = "pdf", width = 6, height = 6)

  # Conditional EPS export
  if (eps == 1) {
    ggsave(filename = paste0(fig_stub, ".eps"), plot = p, path = out_dir,
           device = cairo_ps, width = 6, height = 6, dpi = 1000)
  }
  invisible(p)
}

################################################################################
## Appendix C: Figure 7 (without baseline controls)
################################################################################
setwd(main_path)
plot_predictive_power("datamat_15_frailty_bl_frailty_bl0.csv", "fig7-1")
plot_predictive_power("datamat_15_frailty_bl_surv0.csv", "fig7-2")

################################################################################
## Appendix C:  Figure 8 (with baseline controls)
################################################################################
plot_predictive_power("datamat_15_frailty_bl_frailty_bl1.csv", "fig8-1")
plot_predictive_power("datamat_15_frailty_bl_surv1.csv", "fig8-2")

# Report the total wall-clock runtime of this script.
runtime <- Sys.time() - start_time
print(runtime)

# Clear environment: remove every object not present when `to_keep` was
# recorded at the top of the script.
rm(list = setdiff(ls(), to_keep))
