
#' Linear mixed models for data matrix
#'
#' Fits many linear mixed effects models for analysis of gaussian data with
#' random effects, with parallelisation and optimisation for speed. It is
#' suitable for longitudinal analysis of high dimensional data. By default a
#' Wald type 2 Chi-squared test is used to calculate p-values; conditional F
#' tests are available via `test.stat`.
#'
#' @param modelFormula the model formula. This must be of the form `"~ ..."`
#'   where the structure is assumed to be `"gene ~ ..."`. The formula must
#'   include a random effects term. See formula structure for random effects in
#'   \code{\link[lme4:lmer]{lme4::lmer()}}
#' @param maindata data matrix with genes in rows and samples in columns
#' @param metadata a dataframe of sample information with variables in columns
#'   and samples in rows
#' @param id Optional. Used to specify the column in metadata which contains the
#'   sample IDs to be used in repeated samples for random effects. If not
#'   specified, the function defaults to using the variable after the "|" in the
#'   random effects term in the formula.
#' @param offset Vector containing model offsets (default = NULL). If provided
#'   the `lmer()` offset is set to `offset`. See
#'   \code{\link[lme4:lmer]{lme4::lmer()}}
#' @param test.stat Character value specifying test statistic. Current options
#'   are "Wald" for type 2 Wald Chi square test using code derived and modified
#'   from [car::Anova] to improve speed for matrix tests. Or "F" for conditional
#'   F tests using Satterthwaite's method of approximated Df. This uses
#'   [lmerTest::lmer] and is somewhat slower.
#' @param reducedFormula Optional design formula without random effects. If not
#'   given, it is automatically generated by removing the random effects from
#'   the main formula. Used to calculate confidence intervals for final fitted
#'   models on each gene for plotting purposes.
#' @param modelData Optional dataframe. Default is generated by call to
#'   `expand.grid` using levels of variables in the formula. Used to calculate
#'   model predictions (estimated means & 95% CI) for plotting via [modelPlot].
#'   It can therefore be used to add/remove points in [modelPlot].
#' @param designMatrix Optional custom design matrix generated by call to
#'   `model.matrix` using `modelData` and `reducedFormula`. Used to calculate
#'   model predictions for plotting.
#' @param control the `lmer` optimizer control (default = `lmerControl()`). See
#'   \code{\link[lme4:lmerControl]{lme4::lmerControl()}}.
#' @param cores number of cores to use for parallelisation. Default = 1. 
#' @param removeSingles whether to remove individuals with no repeated measures
#'   (default = FALSE)
#' @param verbose Logical whether to display messaging (default = TRUE)
#' @param returnList Logical whether to return results as a list or lmmSeq
#'   object (default = FALSE). Helpful for debugging.
#' @param progress Logical whether to display a progress bar
#' @param ... Other parameters passed to \code{\link[lme4:lmer]{lme4::lmer()}}
#' @return Returns an S4 class `lmmSeq` object with results for gene-wise
#'   linear mixed models; or a list of results if `returnList` is `TRUE`.
#'   
#' @details
#' Two key methods are used to speed up computation above and beyond simple
#' parallelisation. The first is to speed up [lme4::lmer()] by calling
#' [lme4::lFormula] once at the start and then updating the `lFormula` output
#' with new data. The 2nd speed up is through optimised code for repeated type 2
#' Wald Chi-squared tests (original code was derived from [car::Anova]). For
#' example, elements such as the hypothesis matrices are generated only once to
#' reduce unnecessarily repetitive computation, and the generation of p-values
#' from Chi-squared values is vectorised and performed at the end. F-tests using
#' the `lmerTest` package have not been optimised and are therefore slower.
#' 
#' Parallelisation is performed using [parallel::mclapply] on unix/mac and
#' [parallel::parLapply] on windows. Progress bars use [pbmcapply::pbmclapply]
#' on unix/mac and [pbapply::pblapply] on windows.
#'   
#' @importFrom lme4 subbars findbars fixef lmerControl nobars isSingular
#'   lFormula
#' @importFrom lmerTest lmer
#' @importFrom parallel mclapply detectCores parLapply makeCluster clusterEvalQ
#'   clusterExport stopCluster
#' @importFrom pbmcapply pbmclapply
#' @importFrom pbapply pblapply
#' @importFrom methods slot new
#' @importFrom stats AIC complete.cases logLik reshape terms vcov pchisq
#'   update.formula model.matrix predict setNames anova coef
#' @export
#' @examples
#' data(PEAC_minimal_load)
#' logtpm <- log2(tpm +1)
#' lmmtest <- lmmSeq(~ Timepoint * EULAR_6m + (1 | PATID),
#'                      maindata = logtpm[1:2, ],
#'                      metadata = metadata,
#'                      verbose = FALSE)
#' names(attributes(lmmtest))


lmmSeq <- function(modelFormula,
                   maindata,
                   metadata,
                   id = NULL,
                   offset = NULL,
                   test.stat = c("Wald", "F"),
                   reducedFormula = "",
                   modelData = NULL,
                   designMatrix = NULL,
                   control = lmerControl(),
                   cores = 1,
                   removeSingles = FALSE,
                   verbose = TRUE,
                   returnList = FALSE,
                   progress = FALSE,
                   ...) {
  lmmcall <- match.call(expand.dots = TRUE)
  test.stat <- match.arg(test.stat)
  maindata <- as.matrix(maindata)

  # Catch errors
  if (length(findbars(modelFormula)) == 0) {
    stop("No random effects terms specified in formula")
  }
  if (ncol(maindata) != nrow(metadata)) {
    stop("maindata columns different size to metadata rows")
  }
  # Scalar condition: use the short-circuiting operator
  if (!is.null(offset) && ncol(maindata) != length(offset)) {
    stop("Different offset length")
  }

  # Manipulate formulae
  fullFormula <- update.formula(modelFormula, gene ~ ., simplify = FALSE)
  nonRandomFormula <- subbars(modelFormula)
  variables <- rownames(attr(terms(nonRandomFormula), "factors"))
  # drop = FALSE keeps a data frame even when the formula only involves a
  # single variable (e.g. `~ 1 + (1 | id)`); later code assigns a `gene`
  # column and subsets rows, which needs a data frame
  subsetMetadata <- metadata[, variables, drop = FALSE]
  if (is.null(id)) {
    # Default id: whatever follows "|" in the random effects term
    fb <- findbars(modelFormula)
    id <- sub(".*[|]", "", fb)
    id <- gsub(" ", "", id)
  }
  ids <- as.character(metadata[, id])

  # Option to subset to remove unpaired samples
  if (removeSingles) {
    idCounts <- table(ids)
    nonSingle <- names(idCounts)[idCounts > 1]
    pairedIndex <- ids %in% nonSingle
    # drop = FALSE so a single-gene matrix / single-column metadata is not
    # collapsed to a plain vector
    maindata <- maindata[, pairedIndex, drop = FALSE]
    subsetMetadata <- subsetMetadata[pairedIndex, , drop = FALSE]
    ids <- ids[pairedIndex]
    offset <- offset[pairedIndex]
  }

  if (verbose) cat(paste0("\nn = ", length(ids), " samples, ",
                          length(unique(ids)), " individuals\n"))

  # setup model prediction
  # "" (the default) or NULL means: derive the reduced formula by removing
  # the random effects terms. identical() avoids comparing a formula object
  # with a string.
  if (is.null(reducedFormula) || identical(reducedFormula, "")) {
    reducedFormula <- nobars(modelFormula)
  }
  if (is.null(modelData)) {
    reducedVars <- rownames(attr(terms(reducedFormula), "factors"))
    varLevels <- lapply(reducedVars, function(x) {
      if (is.factor(subsetMetadata[, x])) {
        levels(subsetMetadata[, x])
      } else {
        sort(unique(subsetMetadata[, x]))
      }
    })
    modelData <- expand.grid(varLevels)
    colnames(modelData) <- reducedVars
  }

  if (is.null(designMatrix)) {
    designMatrix <- model.matrix(reducedFormula, modelData)
  }

  start <- Sys.time()
  # One response vector per gene. Index by position rather than by rowname
  # so duplicated or missing rownames cannot select the wrong row.
  fullList <- lapply(seq_len(nrow(maindata)), function(i) maindata[i, ])

  # Build the per-gene fitting function and the list of objects exported to
  # cluster workers for the chosen test statistic
  if (test.stat == "Wald") {
    # Adapted from lme4::modular / lme4::lmer
    # lFormula is called once, using the first gene as a placeholder
    # response; lmerFast swaps in each gene's values afterwards
    subsetMetadata$gene <- maindata[1, ]
    lmod <- lFormula(fullFormula, subsetMetadata,
                     offset = offset, control = control, ...)

    hyp.matrix <- hyp_matrix(fullFormula, metadata, "gene")

    fitGene <- function(geneRow) {
      lmerFast(geneRow, lmod, control,
               modelData, designMatrix, hyp.matrix)
    }
    exportVars <- c("lmerFast", "lmod", "control", "modelData",
                    "designMatrix", "hyp.matrix")
  } else {
    # lmerTest
    fitGene <- function(geneRow) {
      lmerTestCore(geneRow, fullFormula = fullFormula, data = subsetMetadata,
                   control = control, modelData = modelData, offset = offset,
                   designMatrix = designMatrix, ...)
    }
    # Only object *names* belong here; the values in `...` reach the workers
    # through the closure above
    exportVars <- c("lmerTestCore", "fullList", "fullFormula",
                    "subsetMetadata", "control", "modelData",
                    "offset", "designMatrix")
  }

  # For each gene perform a fit
  useCluster <- .Platform$OS.type == "windows" && cores > 1
  if (useCluster) {
    cl <- makeCluster(cores)
    # Guarantee the workers are shut down even if fitting fails
    on.exit(stopCluster(cl), add = TRUE)
    clusterExport(cl, varlist = exportVars, envir = environment())
    if (progress) {
      resultList <- pblapply(fullList, fitGene, cl = cl)
    } else {
      resultList <- parLapply(cl = cl, fullList, fitGene)
    }
  } else if (progress) {
    resultList <- pbmclapply(fullList, fitGene, mc.cores = cores)
    # Unwrap when the results come back nested under a `value` element
    if ("value" %in% names(resultList)) resultList <- resultList$value
  } else {
    resultList <- mclapply(fullList, fitGene, mc.cores = cores)
  }

  if (returnList) return(resultList)

  # Print timing if verbose
  end <- Sys.time()
  if (verbose) print(end - start)

  # Output
  names(resultList) <- rownames(maindata)
  noErr <- vapply(resultList, function(x) x$tryErrors == "", FUN.VALUE = TRUE)
  if (length(which(noErr)) == 0) {
    stop("All genes returned an error. Check sufficient data in each group")
  }

  predList <- lapply(resultList[noErr], "[[", "predict")
  outputPredict <- do.call(rbind, predList)

  # Column labels: one y/LCI/UCI triplet per row of modelData
  outLabels <- apply(modelData, 1, function(x) paste(x, collapse = "_"))
  colnames(outputPredict) <- c(paste0("y_", outLabels),
                               paste0("LCI_", outLabels),
                               paste0("UCI_", outLabels))

  if (sum(!noErr) != 0) {
    if (verbose) cat(paste0("Errors in ", sum(!noErr), " gene(s): ",
                            paste0(names(noErr)[! noErr], collapse = ", ")))
    outputErrors <- vapply(resultList[!noErr], function(x) {x$tryErrors},
                           FUN.VALUE = character(1))
  } else {
    outputErrors <- c("No errors")
  }

  optInfo <- t(vapply(resultList[noErr], function(x) {
    setNames(x$optinfo, c("Singular", "Conv"))
  }, FUN.VALUE = c(1, 1)))

  s <- organiseStats(resultList[noErr], test.stat)
  meanExp <- rowMeans(maindata[noErr, , drop = FALSE])
  s$res <- cbind(s$res, meanExp)

  # Create lmmSeq object with results
  new("lmmSeq",
      info = list(call = lmmcall,
                  offset = offset,
                  designMatrix = designMatrix,
                  control = substitute(control),
                  test.stat = test.stat),
      formula = fullFormula,
      stats = s,
      predict = outputPredict,
      reducedFormula = reducedFormula,
      maindata = maindata,
      metadata = subsetMetadata,
      modelData = modelData,
      optInfo = optInfo,
      errors = outputErrors,
      vars = list(id = id,
                  removeSingles = removeSingles)
  )
}


## see lme4::modular
#' @importFrom lme4 mkLmerDevfun optimizeLmer checkConv mkMerMod
#' 
lmerFast <- function(geneRow,
                     lmod,
                     control,
                     modelData,
                     designMatrix,
                     hyp.matrix) {
  # Fit a linear mixed model for a single gene by reusing a pre-built
  # `lFormula` result (`lmod`) and replacing only the response values.
  #
  # geneRow:      response values for this gene, one per row of `lmod$fr`
  # lmod:         output of lme4::lFormula, whose frame has a `gene` column
  # control:      lmerControl object supplying the optimizer settings
  # modelData:    data frame at which fixed-effect predictions are made
  # designMatrix: design matrix used for the prediction variances
  # hyp.matrix:   hypothesis matrices passed on to lmer_wald
  #
  # Returns a list with stats (AIC, logLik), coef, stdErr, chisq, df,
  # predict (fitted values, lower and upper limits), optinfo (singular
  # flag, number of convergence messages) and tryErrors ("" on success,
  # otherwise the error or rank-drop message).

  # Run the modular lme4 steps inside try() so that a failure for one gene
  # is reported through `tryErrors` instead of aborting the whole run
  fit <- try({
    lmod$fr$gene <- geneRow
    devfun <- do.call(mkLmerDevfun, c(lmod, list(control = control)))
    opt <- optimizeLmer(devfun,
                        optimizer = control$optimizer,
                        restart_edge = control$restart_edge,
                        boundary.tol = control$boundary.tol,
                        control = control$optCtrl,
                        calc.derivs = control$calc.derivs,
                        use.last.params = control$use.last.params)
    cc <- try(suppressMessages(suppressWarnings(
      checkConv(attr(opt, "derivs"), opt$par,
                ctrl = control$checkConv,
                lbound = environment(devfun)$lower)
    )), silent = TRUE)
    mkMerMod(environment(devfun), opt, lmod$reTrms, fr = lmod$fr,
             lme4conv = cc)
  }, silent = TRUE)

  if (inherits(fit, "try-error")) {
    return(list(stats = NA, coef = NA, stdErr = NA, chisq = NA, df = NA,
                predict = NA, optinfo = NA, tryErrors = fit[1]))
  }

  # intercept dropped genes
  if (length(attr(fit@pp$X, "msgRankdrop")) > 0)  {
    return( list(stats = NA, predict = NA, optinfo = NA,
                 tryErrors = attr(fit@pp$X, "msgRankdrop")) )
  }

  stats <- setNames(c(AIC(fit), as.numeric(logLik(fit))),
                    c("AIC", "logLik"))
  fixedEffects <- lme4::fixef(fit)
  stdErr <- coef(summary(fit))[, 2]
  vcov. <- suppressWarnings(vcov(fit, complete = FALSE))
  vcov. <- as.matrix(vcov.)
  waldtest <- lmer_wald(fixedEffects, hyp.matrix, vcov.)

  # Fixed-effect-only predictions (re.form = NA) with limits at +/- 1.96 SE,
  # where the SE comes from the diagonal of X V X'
  newY <- predict(fit, newdata = modelData, re.form = NA)
  a <- designMatrix %*% vcov.
  b <- as.matrix(a %*% t(designMatrix))
  predVar <- diag(b)
  newSE <- sqrt(predVar)
  newLCI <- newY - newSE * 1.96
  newUCI <- newY + newSE * 1.96
  predictdf <- c(newY, newLCI, newUCI)
  singular <- as.numeric(lme4::isSingular(fit))
  conv <- length(slot(fit, "optinfo")$conv$lme4$messages)

  list(stats = stats,
       coef = fixedEffects,
       stdErr = stdErr,
       chisq = waldtest$chisq,
       df = waldtest$df,
       predict = predictdf,
       optinfo = c(singular, conv),
       tryErrors = "")
}


lmerTestCore <- function(geneRow,
                         fullFormula,
                         data,
                         control,
                         modelData,
                         designMatrix,
                         offset,
                         ...) {
  # Fit one gene with lmerTest::lmer and collect the summary statistics,
  # F tests and fixed-effect predictions.
  #
  # geneRow is written into the `gene` column of `data` and used as the
  # response. Returns a list whose `tryErrors` element is "" on success, or
  # holds the error / rank-drop message otherwise.
  data[, "gene"] <- geneRow
  model <- try(
    suppressMessages(suppressWarnings(
      lmerTest::lmer(fullFormula, data = data, control = control,
                     offset = offset, ...)
    )),
    silent = TRUE
  )

  # Fitting failed: hand back the error text
  if (inherits(model, "try-error")) {
    return(list(stats = NA, coef = NA, stdErr = NA, Ftest = NA,
                predict = NA, optinfo = NA, tryErrors = model[1]))
  }

  # intercept dropped genes
  rankDropMsg <- attr(model@pp$X, "msgRankdrop")
  if (length(rankDropMsg) > 0) {
    return(list(stats = NA, predict = NA, optinfo = NA,
                tryErrors = rankDropMsg))
  }

  fitStats <- c(AIC(model), as.numeric(logLik(model)))
  names(fitStats) <- c("AIC", "logLik")
  betas <- lme4::fixef(model)
  betaSE <- coef(summary(model))[, 2]
  covMat <- as.matrix(suppressWarnings(vcov(model, complete = FALSE)))
  # Drop the first two columns of the anova table
  fTable <- as.matrix(anova(model)[, -c(1, 2)])

  # Fixed-effect-only predictions with limits at +/- 1.96 standard errors
  fitted <- predict(model, newdata = modelData, re.form = NA)
  predCov <- as.matrix(designMatrix %*% covMat %*% t(designMatrix))
  predSE <- sqrt(diag(predCov))
  lower <- fitted - predSE * 1.96
  upper <- fitted + predSE * 1.96

  isSing <- as.numeric(lme4::isSingular(model))
  nConvMsg <- length(slot(model, "optinfo")$conv$lme4$messages)

  list(stats = fitStats,
       coef = betas,
       stdErr = betaSE,
       Ftest = fTable,
       predict = c(fitted, lower, upper),
       optinfo = c(isSing, nConvMsg),
       tryErrors = "")
}


organiseStats <- function(resultList, test.stat) {
  # Combine per-gene results into matrices with one row per gene.
  #
  # resultList: list of per-gene result lists, each holding `stats`, `coef`,
  #             `stdErr` and either `chisq` + `df` (Wald) or an `Ftest`
  #             matrix (F) with terms in rows
  # test.stat:  "Wald" for the Chi-squared outputs; anything else is
  #             treated as the F-test layout
  #
  # Returns a list of matrices: res, coef, stdErr, then Chisq/Df/pvals
  # (Wald) or NumDF/DenDF/Fval/pvals (F).

  # Row-bind one named element across all genes
  bindField <- function(field) {
    do.call(rbind, lapply(resultList, "[[", field))
  }

  res <- bindField("stats")
  cf <- bindField("coef")
  stdErr <- bindField("stdErr")

  if (test.stat == "Wald") {
    chisq <- bindField("chisq")
    df <- bindField("df")
    # p-values are computed in one vectorised call over all genes and terms
    pvals <- pchisq(chisq, df = df, lower.tail = FALSE)
    colnames(df) <- colnames(chisq)
    colnames(pvals) <- colnames(chisq)
    return(list(res = res, coef = cf, stdErr = stdErr, Chisq = chisq,
                Df = df, pvals = pvals))
  }

  # Row-bind column j of each gene's F-test matrix. The term names are
  # re-attached explicitly because `m[, j]` on a one-row matrix drops them,
  # which would leave single-term models with unlabelled columns.
  bindFtest <- function(j) {
    do.call(rbind, lapply(resultList, function(x) {
      setNames(x$Ftest[, j], rownames(x$Ftest))
    }))
  }

  list(res = res, coef = cf, stdErr = stdErr,
       NumDF = bindFtest(1), DenDF = bindFtest(2),
       Fval = bindFtest(3), pvals = bindFtest(4))
}
