library(rstan)
options(mc.cores=1)
library(brms)
library(rjags)
load.module("glm")
library(doBy)
# Extract the posterior draws of all parameters whose column name matches a
# regular expression.
#
# Arguments:
#   samples  either one two-dimensional object with column names (matrix,
#            data frame, ...) or a list of such objects, e.g. one per chain
#   regexp   regular expression matched against the column names via grep()
#   unlist   only used when `samples` is a list of chains: if TRUE the selected
#            columns of all chains are stacked with rbind(); if FALSE a list
#            (carrying the class of `samples`) with one element per chain is
#            returned
#
# Returns the selected columns; the two-dimensional shape is always kept, even
# when a single column (or none) matches.
getPosterior <- function(samples,regexp,unlist=TRUE) {
  # Column selection shared by all branches; drop=FALSE keeps the 2-d shape.
  pick <- function(draws) {
    draws[,grep(regexp,colnames(draws)),drop=FALSE]
  }
  # A data frame is technically a list, but it is a single table of draws and
  # must not be iterated column by column (that fails on the 2-d indexing).
  if (is.list(samples) && !is.data.frame(samples)) {
    res <- lapply(samples,pick)
    if (unlist) {
      res <- do.call(rbind,res)
    } else {
      class(res) <- class(samples)
    }
  } else {
    res <- pick(samples)
  }
  return(res)
}
# Convert a probability (or vector of probabilities) p into odds p/(1-p).
odds <- function(p) {
  complement <- 1 - p
  p / complement
}

# Work from the project directory that holds the data, Stan files and caches.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
setwd("C:/Current/EPHOR/Task4.2.3")

# The .Rdata file is expected to provide the object `all.subjects`.
load("all.subjects.ever6.Rdata")
str(all.subjects)

# Covariates: every column whose name does not start with "isco", minus the
# identifier/outcome/auxiliary columns listed explicitly.
covariates <- setdiff(
  grep("^isco",colnames(all.subjects),value=TRUE,invert=TRUE),
  c("subjctid","status","histotyp","A")
)
# Job indicators: columns named isco_..._ever.
all.iscos <- grep("_ever$",
                  grep("^isco_",colnames(all.subjects),value=TRUE),
                  value=TRUE)

# Analysis data set: id, outcome, covariates and job indicators (in that order).
data <- all.subjects[,c("subjctid","status",covariates,all.iscos)]
str(data)
nrow(data)
ncol(data)

set.seed(100)
# Row selection (currently all rows); seq_len() stays empty for empty data,
# whereas 1:nrow(data) would yield c(1, 0).
ix <- seq_len(nrow(data))
# Codes to leave out of the analysis (none at the moment).
excl.codes <- NULL
sel.codes <- intersect(setdiff(all.iscos,excl.codes),colnames(data))

# Outcome coded 0/1: first factor level of `status` -> 0, second -> 1.
Y <- as.numeric(factor(data$status[ix]))-1
# Job indicators as a numeric matrix (one column per selected code).
# drop=FALSE keeps the data-frame shape when only one code is selected.
X <- apply(data[ix,sel.codes,drop=FALSE],2,as.numeric)
# Covariate design matrix without the intercept column.
# drop=FALSE keeps a matrix even if a single column remains.
Z <- model.matrix(as.formula(paste0("~",paste0(covariates,collapse="+"))),data=data)[ix,-1,drop=FALSE]
# Remove the dummy columns time_quit1 / time_quit2 from the covariates.
Z <- Z[,setdiff(colnames(Z),paste0("time_quit",c(1,2))),drop=FALSE]


# Quick look at how a column name splits into its parts.
strsplit(sel.codes[1],"_")
# Numeric job code: second "_"-separated token of each column name.
# vapply() guarantees a (named) numeric vector, even for empty input.
ISCO5 <- vapply(sel.codes,function(curcode){
  floor(as.numeric(strsplit(curcode,"_")[[1]][2]))
},numeric(1))
# Parent code: the job code with its last two digits removed.
ISCO3 <- floor(ISCO5/100)
# Group index (1..number of distinct parent codes) for every job code.
G <- as.numeric(factor(ISCO3))
max(G)
table(table(G))
# Renumber the groups by increasing size, so the smallest groups (in
# particular the singletons) receive the lowest indices.
t <- table(G)
G <- as.numeric(factor(G,levels=order(t)))
t <- table(G)
table(G==1)
# Number of singleton groups. Because groups are ordered by size this equals
# the index of the last singleton group; unlike max(which(t==1)) it is 0
# (instead of -Inf with a warning) when there is no singleton group.
N1 <- sum(t==1)

# Spread of the column-wise standard deviations of X before rescaling.
range(apply(X, MARGIN=2, FUN=sd))
# Common scaling constant: the median column standard deviation.
FACTOR <- median(apply(X, MARGIN=2, FUN=sd))
# Divide all columns by the same constant.
X <- X / FACTOR

# ---- Stage 1: ordinary logistic regression, cached on disk ----
# The fit is stored in `fname`; delete that file to force a refit after the
# data or the selected codes have changed.
fname <- paste0("HM_synergy.ST1",".Rdata")
if (!file.exists(fname)) {
  M1 <- glm(Y ~ X + Z, family=binomial) #M1 is a simple glm of all 5D jobs and cov
  B <- coef(M1)
  V <- vcov(M1)
  # Keep only the job-code coefficients (their names start with "Xisco").
  sel <- grepl("Xisco",names(B))
  B <- B[sel]
  # excl.codes <- sel.codes[is.na(B)]
  V <- V[sel,sel,drop=FALSE]
  nms <- gsub("_dc","",gsub("Xisco_","",names(B)))
  names(B) <- rownames(V) <- colnames(V) <- nms
  # Diagnostic: range of the pairwise correlations between the estimates.
  # Printed explicitly, because values inside a braced block are not
  # auto-printed and the result would otherwise be discarded.
  R <- cov2cor(V)
  diag(R) <- NA
  print(range(R,na.rm=TRUE))
  B1 <- B
  SE1 <- sqrt(diag(V))
  save(M1,B1,SE1,file=fname)
} else {
  load(fname)
}
# A stale cache (fitted on a different set of codes) would silently misalign
# the estimates with the grouping built from the current codes.
if (length(B1)!=ncol(X) || length(SE1)!=ncol(X)) {
  stop("Stage-1 results in '",fname,"' do not match the current set of job codes; ",
       "delete the file to refit.",call.=FALSE)
}
# Aliased (non-estimable) coefficients show up as NA and cannot be used as
# data in stage 2.
if (anyNA(B1) || anyNA(SE1)) {
  warning("Stage-1 estimates contain NA values (aliased coefficients?); ",
          "consider adding the affected codes to excl.codes.",call.=FALSE)
}

# ---- Stage 2: shrinkage models on the stage-1 estimates ----
# The data and all prior hyperparameters are identical for every prior family,
# so they are assembled once instead of inside the loop.
stan.data <- list(Nx=length(B1),Ng=max(G),B=B1,SE=SE1,G=G)
stan.data <- append(stan.data,list(normal_scale=0.5,normal2_scale=0.1)) #normal prior
stan.data <- append(stan.data,list(ridge_scale=0.5,ridge2_scale=0.1)) #ridge prior
stan.data <- append(stan.data,list(lasso_df=1,lasso_scale=0.5,lasso2_df=1,lasso2_scale=0.1)) #lasso prior
stan.data <- append(stan.data,list(hs_df=1,hs_df_global=1,hs_scale_global=odds(1/10)*2*1/sqrt(nrow(X)),hs_df_slab=4,hs_scale_slab=2,
                                   hs2_df=1,hs2_df_global=1,hs2_scale_global=odds(1/10)*2*1/sqrt(nrow(X)),hs2_df_slab=4,hs2_scale_slab=2)) #horseshoe prior (for more on HS priors, see DOI: 10.1214/17-EJS1337SI)
str(stan.data)

# Sampler settings: Niter post-warmup draws per chain after Nburn warmup draws.
Nchains <- 4
Nburn <- 1000
Niter <- 3000
stan.control <- list(adapt_delta=0.95, max_treedepth=25)

# Set to FALSE to skip writing the per-model CSV summaries.
export.csv <- TRUE

# Pause before each new plot page so the figures can be inspected one by one.
devAskNewPage(ask=TRUE)
for (mod in c("NORMAL2","RIDGE2","LASSO2","HS2")) {
  # mod <- c("NORMAL2","RIDGE2","LASSO2","HS2")[1]
  mname <- paste0("HM_synergy_ST2.",mod)
  fname <- paste0(mname,".Rdata")
  # Fit the model only if no cached fit exists; otherwise reuse the stored M2.
  if (!file.exists(fname)) {
    # `chains` is spelled out in full rather than relying on partial
    # argument matching of `chain`.
    M2 <- stan(file=paste0(mname,".stan"),data=stan.data,control=stan.control,chains=Nchains,iter=Niter+Nburn,warmup=Nburn)
    #M2 is second level model with any selected priors (e.g. normal, ridge, lasso, horseshoe)
    save(M2,file=fname)
  } else {
    load(fname)
  }
  # Posterior summary of the parameters whose name contains "bX";
  # drop=FALSE keeps a matrix even when a single row matches.
  res <- summary(M2)$summary
  res <- res[grepl("bX",rownames(res)),,drop=FALSE]
  # Everything below pairs row i of `res` with stage-1 coefficient i.
  if (nrow(res)!=length(B1)) {
    stop("Model ",mod,": found ",nrow(res)," summary rows matching 'bX' but ",
         length(B1)," stage-1 coefficients.",call.=FALSE)
  }
  if (export.csv) {
    EXPORT <- cbind(data.frame(ISCO=ISCO5,ISCO3=ISCO3),res)
    rownames(EXPORT) <- NULL
    write.csv(EXPORT,file=paste0(mname,".csv"),row.names=FALSE)
  }
  # Highlight coefficients whose 95% interval lies entirely on one side of 0.
  selected <- sign(res[,"2.5%"])==sign(res[,"97.5%"])
  B2 <- res[,c("mean")]
  SE2 <- res[,c("sd")]
  # Both axes show estimate/standard-error ratios, hence the +-1.96 guides.
  xlim <- ylim <- range(B2/SE2,B1/SE1)
  plot(B2/SE2,B1/SE1,xlim=xlim,ylim=ylim,xlab="stage-2 coefficients",ylab="stage-1 coefficients",main=mod,col=ifelse(selected,"darkred","lightblue"),pch=19)
  abline(a=0,b=1)
  abline(h=c(-1.96,1.96),v=c(-1.96,1.96),lty=2,col="green")
}
devAskNewPage(ask=FALSE)
