#--------------------------------------------------
# Small utilities
#--------------------------------------------------

## Number of tips in `phy` (via ape's Ntip()), or 0 when `phy` is NULL.
## A plain if/else is used instead of ifelse(): the condition is a single
## value, and ifelse() would truncate any vector result of Ntip() to its
## first element.
Ntip0 <- function(phy)
{
    if (is.null(phy)) 0 else Ntip(phy)
}

## Plot `phy` without tip labels and mark every tip with a filled circle
## coloured by its state: black for state 0, red for state 1.
## Assumes phy$tip.state holds 0/1 values -- TODO confirm for other callers.
plotstates <- function(phy)
{
    plot(phy, show.tip.label=FALSE)
    tiplabels(pch=21, bg=c("black", "red")[phy$tip.state+1], cex=1.5)
}

## Rebuild the simulated binary (0/1) character history stored on `phy`
## with diversitree's history.from.sim.discrete() and draw it on the tree.
plothistory <- function(phy)
{
    plot(history.from.sim.discrete(phy, 0:1), phy)
}

## Expand a binary (0/1) tip-state vector into the 4-column state matrix of a
## two-hidden-class model, with columns "0a", "0b", "1a", "1b":
##   state 0 -> c(1, 1, 0, 0)   (either hidden class of 0)
##   state 1 -> c(0, 0, 1, 1)   (either hidden class of 1)
## Tips whose state is NA, or anything other than 0/1, get a row of NA.
##
## Rows are filled by position and named from names(st).  Unnamed input is
## therefore converted too, rather than coming back all NA, and duplicated
## names no longer leave rows unfilled.
binary.to.hidden <- function(st)
{
    stm <- matrix(NA_real_, nrow=length(st), ncol=4,
                  dimnames=list(names(st), c("0a", "0b", "1a", "1b")))
    ## `%in%` never returns NA, so NA states match neither pattern instead
    ## of breaking an if() as `==` did.
    is0 <- st %in% 0
    is1 <- st %in% 1
    stm[is0, ] <- rep(c(1, 1, 0, 0), each=sum(is0))
    stm[is1, ] <- rep(c(0, 0, 1, 1), each=sum(is1))
    stm
}

#--------------------------------------------------
# Wrapper for maximum likelihood estimation
#--------------------------------------------------

### Written by Dan Rabosky

## Maximise a log-likelihood function by repeated bounded optimisation.
##
## Runs `nopt` L-BFGS-B optimisations of `fx` (maximising, via
## fnscale = -1). Each starts from a point drawn uniformly from [lmin, lmax]
## for every parameter, and the run with the highest likelihood is kept.
## An attempt that raises an error is retried from a fresh random start.
##
## Arguments:
##   fx         likelihood function of a parameter vector; the number of
##              parameters is taken from argnames(fx) (diversitree).
##   nopt       number of successful optimisations to run (must be >= 1).
##   lmin, lmax lower/upper bounds for every parameter; also the range the
##              random starting values are drawn from.
##   MAXBAD     stop with an error once a single optimisation has been
##              retried more than MAXBAD times.
##   initscale  currently unused; kept for backward compatibility (only the
##              commented-out getStartingParamsDiversitree start used it).
##
## Returns a list with pars, loglik, AIC (= -2 * loglik + 2 * #parameters)
## and optim()'s counts, convergence and message for the best run.
fitML <- function(fx, nopt=5, lmin = 0.0001, lmax=20.0, MAXBAD = 1000, initscale = 0.1)
{
    ## 1:nopt would silently run two fits for nopt = 0.
    if (length(nopt) != 1 || is.na(nopt) || nopt < 1)
        stop("'nopt' must be a single number >= 1")

    npar <- length(argnames(fx))

    ## One optimisation from a fresh uniform random start.  Returns the
    ## optim() result, or a "try-error" object if optim() failed.
    attempt <- function() {
        # iv <- getStartingParamsDiversitree(fx, lmin=lmin, lmax=lmax*initscale)
        iv <- runif(n=npar, min=lmin, max=lmax)
        try(optim(iv, fx, method='L-BFGS-B',
                  control=list(maxit=1000, fnscale=-1),
                  lower=lmin, upper=lmax),
            silent=TRUE)
    }

    best <- NULL
    for (i in seq_len(nopt)) {
        resx <- attempt()
        badcount <- 0
        while (inherits(resx, "try-error")) {
            resx <- attempt()
            badcount <- badcount + 1
            if (badcount > MAXBAD)
                stop("Too many fails in fitML\n")
        }
        ## Keep the first result, then replace it only by strictly better
        ## ones.  With fnscale = -1, optim()'s $value is the unscaled fx value.
        if (is.null(best) || best$value < resx$value)
            best <- resx
    }

    fres <- list(pars=best$par, loglik=best$value)
    fres$AIC <- -2*fres$loglik + 2*npar
    fres$counts <- best$counts
    #fres$like_function <- fx
    fres$convergence <- best$convergence
    fres$message <- best$message
    fres
}

#--------------------------------------------------
# Fixes for diversitree
#--------------------------------------------------

### This modified version of "check.states" is needed to work with diversitree
### 0.9-7. It allows Mkn models to run when not all states are present.
### (See also mod to initial.tip.xxsse.R)

## Patched replacement for diversitree's internal check.states() (written
## against diversitree 0.9-7), meant to be installed into the diversitree
## namespace with assignInNamespace().  Validates tip states against `tree`
## and returns them in tree$tip.label order.
##
## Arguments:
##   tree          phylogeny; reads tree$tip.label and, for a "clade.tree",
##                 tree$clades.
##   states        a named vector of states, or a matrix with one row per
##                 taxon (row names = taxa) and one column per state.  A row
##                 with more than one positive entry marks an uncertain
##                 ("multistate") tip.  For a row with exactly one, the
##                 column index of its nonzero entry becomes the tip's state.
##   allow.unnamed if TRUE, an unnamed vector whose length equals the number
##                 of tips is assumed to be in tree$tip.label order (with a
##                 warning).  Otherwise unnamed states are an error.
##   strict        request that the observed states be exactly `strict.vals`
##                 (but see the `multicheck` note below).
##   strict.vals   allowed state values; NULL skips all value checks.
##   as.integer    if TRUE (and some state is non-NA), pass the states
##                 through diversitree:::check.integer().
##
## Returns states[tree$tip.label].  For matrix input it carries a
## "multistate" attribute list(i = positions in tree$tip.label of the
## multistate tips, states = their weight vectors).  For a clade tree it
## returns the states of the tip labels plus all clade species instead.
patched.check.states <- function(tree, states, allow.unnamed=FALSE,
                         strict=FALSE, strict.vals=NULL,
                         as.integer=TRUE) {
  ## multicheck stays TRUE unless some matrix row sums to zero.  The strict
  ## check below only stops when multicheck is FALSE, so in practice strict
  ## checking is switched off.  This is what lets Mkn models run when not
  ## every state is present.
  multicheck <- TRUE # for multistate strict checking
  if ( is.matrix(states) ) {
    ## Multistate characters (experimental).  This will not work with
    ## clade trees, but they are only interesting for BiSSE, which has
    ## NA values for multistate (even weight).
    if ( inherits(tree, "clade.tree") )
      stop("Clade trees won't work with multistate tips yet")
    n <- rowSums(states > 0)
    if ( any(n == 0) )
      stop(sprintf("No state found for taxa: %s",
                   paste(names(n)[n == 0], collapse=", ")))
    if (any(rowSums(states) == 0))
        multicheck <- FALSE

    i.mono <- which(n == 1)
    i.mult <- which(n >  1)

    tmp <- diversitree:::matrix.to.list(states)
    names(tmp) <- rownames(states)

    states.mult <- lapply(tmp[i.mult], as.numeric)

    ## Collapse to a vector: single-state tips get their state index, and
    ## multistate tips are NA with their weights kept in the attribute.
    states <- rep(NA, length(tmp))
    names(states) <- names(tmp)
    ## NOTE(review): when i.mono is empty, sapply() returns list() and this
    ## assignment appears to turn `states` into a list of NAs.  The checks
    ## below are written to tolerate that.
    states[i.mono] <- sapply(tmp[i.mono], function(x)
                             which(x != 0))

    ## `i` indexes rows of the input matrix here; it is re-expressed in
    ## tree$tip.label order before returning.
    attr(states, "multistate") <- list(i=i.mult, states=states.mult)
  }
  
  if ( is.null(names(states)) ) {
    if ( allow.unnamed ) {
      if ( length(states) == length(tree$tip.label) ) {
        names(states) <- tree$tip.label
        warning("Assuming states are in tree$tip.label order")
      } else {
        stop(sprintf("Invalid states length (expected %d)",
                     length(tree$tip.label)))
      }
    } else {
      stop("The states vector must contain names")
    }
  }
  
  if ( !all(tree$tip.label %in% names(states)) )
    stop("Not all species have state information")

  ## When multistate characters are present, this may fail even
  ## for cases where it should not.
  ## now, multicheck helps this
  if ( !is.null(strict.vals) ) {
    if ( isTRUE(all.equal(strict.vals, 0:1)) )
      if ( is.logical(states) )
        states[] <- as.integer(states)
    
    if ( strict ) {
      if ( !isTRUE(all.equal(sort(strict.vals),
                             sort(unique(na.omit(states))))) && !multicheck)
        stop("Because strict state checking requested, all (and only) ",
             sprintf("states in %s are allowed",
                     paste(strict.vals, collapse=", ")))
    } else {
      tmp <- unique(na.omit(states))
      ## Skip the unknown-state check when no state is observed.  The length
      ## test avoids a "subscript out of bounds" error on tmp[[1]] when all
      ## states are NA and `tmp` is therefore empty.
      if (length(tmp) > 0 && !is.na(tmp[[1]])) {
          extra <- setdiff(sort(tmp), strict.vals)
          if ( length(extra) > 0 )
            stop(sprintf("Unknown states %s not allowed in states vector",
                         paste(extra, collapse=", ")))
      }
    }
    if ( as.integer && any(!is.na(states)) )
      states <- diversitree:::check.integer(states)
  }

  if ( inherits(tree, "clade.tree") ) {
    spp.clades <- unlist(tree$clades)
    if ( !all(spp.clades %in% names(states)) )
      stop("Species in 'clades' do not have states information")
    states[union(tree$tip.label, spp.clades)]
  } else {
    ret <- states[tree$tip.label]
    ## Ugly hack: subsetting drops attributes, so re-attach "multistate".
    ## Its `i` referred to rows of the input matrix.  It is remapped to
    ## positions in tree$tip.label (the order of `ret`) so that callers
    ## indexing the returned states hit the right tips even when the matrix
    ## rows are in a different order.  Taxa absent from the tree are dropped.
    multistate <- attr(states, "multistate")
    if ( !is.null(multistate) ) {
      pos <- match(names(states)[multistate$i], tree$tip.label)
      keep <- !is.na(pos)
      multistate <- list(i=pos[keep], states=multistate$states[keep])
    }
    attr(ret, "multistate") <- multistate
    ret
  }
}

## Swap the patched function into diversitree's namespace so that
## diversitree's own internal calls use it, then drop the global copy.
assignInNamespace(x = "check.states", value = patched.check.states,
                  ns = "diversitree")
rm(list = "patched.check.states")

### This modified version of "initial.tip.xxsse" is needed to work with diversitree
### 0.9-7. It allows all tips to be uncertain/multiple-value.
###
### Even more importantly, it fixes a bug with the interaction of multistate
### characters and sampling incompleteness.

## Patched replacement for diversitree's internal initial.tip.xxsse(),
## meant to be installed into the diversitree namespace with
## assignInNamespace().  Builds the per-tip initial conditions for
## xxSSE-type models and passes them to diversitree:::dt.tips.grouped().
##
## Arguments:
##   cache      diversitree model cache.  Reads cache$info$k (number of
##              states), cache$sampling.f and cache$states, including its
##              optional "multistate" attribute list(i, states).
##   base.zero  if TRUE, cache$states are 0-based and are shifted up by one.
##
## Returns the result of diversitree:::dt.tips.grouped() for these
## initial conditions.
patched.initial.tip.xxsse <- function(cache, base.zero = FALSE) {
    k <- cache$info$k
    ## Sampling fraction per state; assumed to have length k (TODO confirm).
    f <- cache$sampling.f
    ## (k+1) x 2k table, filled by row.  After the next line it holds:
    ##   row i (1..k): c(1 - f, f[i] in column k+i, 0 in the other right cols)
    ##   row k+1:      c(1 - f, f)    -- used for tips whose state is NA
    ## Presumably columns 1..k are the E (extinction) values and k+1..2k the
    ## D (likelihood) values, as in diversitree's xxSSE models -- confirm.
    y <- matrix(rep(c(1 - f, rep(0, k)), k + 1), k + 1, 2 * k, 
        TRUE)
    ## The inner assignment puts f on the diagonal of the right-hand block of
    ## rows 1..k.  Its value (f) is then also written into row k+1.
    y[k + 1, (k + 1):(2 * k)] <- diag(y[1:k, (k + 1):(2 * k)]) <- f
    ## Presumably one list element per row of y.
    y <- diversitree:::matrix.to.list(y)
    ## y.i[j] selects which element of `y` tip j uses.
    y.i <- cache$states
    if (base.zero) 
        y.i <- y.i + 1L
    y.i[is.na(y.i)] <- k + 1
    if (!is.null(multistate <- attr(cache$states, "multistate"))) {
        ## Each distinct multistate weight vector x gets its own extra entry
        ## c(1 - f, x * f), appended after the k + 1 standard ones.  The
        ## multistate tips are pointed at their entry.  Assumes
        ## multistate$i are positions in cache$states -- confirm.
        y.multi <- unique(multistate$states)
        y.i.multi <- match(multistate$states, y.multi)
        ## This is the sampling-incompleteness fix: without "*f" the
        ## multistate entries ignored the sampling fraction.
        y <- c(y, lapply(y.multi, function(x) c(1-f, x*f))) # EEG: lacked "*f"
        y.i[multistate$i] <- y.i.multi + k + 1
    }
    diversitree:::dt.tips.grouped(y, as.numeric(y.i), cache)
}

## Swap the patched function into diversitree's namespace so that
## diversitree's own internal calls use it, then drop the global copy.
assignInNamespace(x = "initial.tip.xxsse", value = patched.initial.tip.xxsse,
                  ns = "diversitree")
rm(list = "patched.initial.tip.xxsse")
