Samples of Thoughts

about data, statistics and everything in between

Markov Chain Monte Carlo

8.1 King Markov and His island kingdom

A simple example of a Markov chain Monte Carlo algorithm (the Metropolis algorithm):

num_weeks <- 1e5
positions <- rep(0, num_weeks)
current <- 10
for (i in 1:num_weeks) {
  # record current position
  positions[i] <- current
  
  # flip coin to generate proposal
  proposal <- current + sample( c(-1, 1), size=1)
  if ( proposal < 1 ) proposal <- 10
  if ( proposal > 10 ) proposal <- 1
  
  # move?
  prob_move <- proposal / current
  current <- ifelse( runif(1) < prob_move , proposal, current)
}
par(mfrow=c(1,2))
plot( (1:100), positions[1:100], xlab="week", ylab="island", col="midnightblue")
plot(table(positions), col="midnightblue", xlab="island", ylab="number of weeks")
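
As a quick sanity check on the simulation above, the share of weeks spent on each island should end up roughly proportional to the island number, since that is the target the move rule `proposal / current` encodes. The sketch below wraps the same loop in a function so it can be rerun with a seed; the helper name `simulate_king` and its arguments are my own, hypothetical additions.

```r
# Hypothetical helper: the same island-hopping loop as above, wrapped in a function
simulate_king <- function(num_weeks, start = 10, n_islands = 10) {
  positions <- integer(num_weeks)
  current <- start
  for (i in seq_len(num_weeks)) {
    positions[i] <- current
    # propose the neighbouring island to the left or right, wrapping around the ring
    proposal <- current + sample(c(-1, 1), size = 1)
    if (proposal < 1) proposal <- n_islands
    if (proposal > n_islands) proposal <- 1
    # accept the move with probability proposal / current (capped at 1 implicitly)
    if (runif(1) < proposal / current) current <- proposal
  }
  positions
}

set.seed(8)
positions <- simulate_king(1e5)

# observed share of weeks per island vs. the target share i / sum(1:10)
observed <- as.numeric(table(factor(positions, levels = 1:10))) / length(positions)
expected <- (1:10) / sum(1:10)
round(cbind(island = 1:10, observed, expected), 3)
```

The two columns should agree up to simulation noise; the longer the chain, the closer they get.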

8.3 Easy HMC: map2stan

Using the terrain ruggedness data from Chapter 7:

library(rethinking)
data(rugged)
d <- rugged
d$log_gdp <- log(d$rgdppc_2000)
dd <- d[ complete.cases(d$rgdppc_2000), ]

Fitting the old way using map:

m8.1 <- map(
  alist(
    log_gdp ~ dnorm( mu, sigma ),
    mu <- a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa ,
    a ~ dnorm( 0, 100),
    bR ~ dnorm(0, 10),
    bA ~ dnorm(0 , 10),
    bAR ~ dnorm(0, 10),
    sigma ~ dunif(0, 10)
  ),
  data = dd
)
precis(m8.1)

Before using Stan, some preprocessing is advisable: do all variable transformations up front and build a trimmed data frame that contains only the variables used in the model.

dd.trim <- dd[ , c("log_gdp", "rugged", "cont_africa")]
str(dd.trim)

Using Stan:

m8.1stan <- map2stan(
  alist(
    log_gdp ~ dnorm( mu, sigma) ,
    mu <- a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa,
    a ~ dnorm(0, 100),
    bR ~ dnorm(0, 10),
    bA ~ dnorm(0, 10),
    bAR ~ dnorm(0, 10),
    sigma ~ dcauchy(0, 2)
  ), 
  data=dd.trim,
  start=list(a=5, bR=0, bA=0, bAR=0, sigma=1)
)
precis(m8.1stan)

We can draw more samples from the Stan model, also using more chains, by passing the fitted model back to map2stan:

m8.1stan_4chains <- map2stan( m8.1stan, chains=4, cores=4)
precis(m8.1stan_4chains)

To visualize the results, you can plot the samples. To pull out samples, use

post <- extract.samples( m8.1stan )
str(post)
pairs(post)

A prettier plot is also available directly on the Stan model:

pairs( m8.1stan )

By default, map2stan computes DIC and WAIC. We can extract them with

DIC(m8.1stan)

and

WAIC(m8.1stan)

Alternatively, both are also displayed in the default show output:

show(m8.1stan)
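
To get a feeling for what WAIC measures, here is a rough base-R sketch that computes it by hand for a toy model. Everything in it is an assumption for illustration and not the ruggedness model above: the data are simulated, the likelihood is normal with sigma fixed at 1, and the posterior draws for mu come from the analytic posterior under a flat prior instead of from Stan. The formula is the usual one: the log pointwise predictive density (lppd) minus a penalty given by the summed posterior variances of the pointwise log-likelihoods, multiplied by -2.

```r
set.seed(1)
y <- rnorm(50, mean = 1, sd = 1)   # simulated toy data (assumption)
n <- length(y)

# posterior draws for mu: with sigma = 1 known and a flat prior, mu | y ~ Normal(mean(y), 1/sqrt(n))
mu_post <- rnorm(4000, mean = mean(y), sd = 1 / sqrt(n))

# matrix of pointwise log-likelihoods: one row per posterior draw, one column per observation
ll <- sapply(y, function(yi) dnorm(yi, mean = mu_post, sd = 1, log = TRUE))

lppd   <- sum(log(colMeans(exp(ll))))   # log pointwise predictive density
p_waic <- sum(apply(ll, 2, var))        # effective number of parameters
waic   <- -2 * (lppd - p_waic)
c(lppd = lppd, p_waic = p_waic, WAIC = waic)
```

With a single free parameter, the penalty term `p_waic` should come out in the neighbourhood of 1.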

To get the trace plots of the Markov Chain:

plot( m8.1stan, window=c(100,2000), col="royalblue4", n_cols=2)

To get a glimpse of the raw Stan code, we can use stancode():

stancode(m8.1stan)

8.4 Care and feeding of your Markov chain

Example of non-convergent chain:

y <- c(-1, 1)
m8.2 <- map2stan(
  alist(
    y ~ dnorm( mu, sigma),
    mu <- alpha
  ),
  data=list(y=y), start=list(alpha=0, sigma=1),
  chains=2, iter=4000, warmup=1000
)

There are quite a few warnings about divergent transitions. Let’s have a look at the estimates:

precis(m8.2)

This doesn’t look right: the estimates are implausibly extreme, the effective number of samples is relatively low, and Rhat is above 1. In my case Rhat is only around 1.01, but even such a value is already suspicious. Let’s have a look at the trace plots.

plot(m8.2, col=c("black", "royalblue4"), n_cols=1)
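
To make the Rhat diagnostic a bit more concrete, here is a minimal base-R sketch of the classic Gelman-Rubin statistic, which compares the variance between chains with the variance within chains. The helper `rhat_basic` is hypothetical and deliberately simplified; Stan reports a more refined variant, so the numbers will not match precis exactly. The "chains" here are just simulated draws, not output from the model above.

```r
# Hypothetical helper: classic (non-split) Gelman-Rubin statistic
# chains: numeric matrix with one column per chain and one row per draw
rhat_basic <- function(chains) {
  n <- nrow(chains)
  chain_means <- colMeans(chains)
  chain_vars  <- apply(chains, 2, var)
  B <- n * var(chain_means)            # between-chain variance
  W <- mean(chain_vars)                # average within-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(84)
# two chains exploring the same distribution
good <- cbind(rnorm(3000), rnorm(3000))
# two chains stuck in different regions
bad  <- cbind(rnorm(3000, mean = 0), rnorm(3000, mean = 3))

rhat_good <- rhat_basic(good)
rhat_bad  <- rhat_basic(bad)
c(good = rhat_good, bad = rhat_bad)
```

When the chains agree, the two variance estimates coincide and the ratio is close to 1; when the chains sit in different places, the between-chain variance inflates the ratio.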

The problem: the priors are completely flat (the model specifies none), which means that even values of 500 million are considered plausible. We can fix this by adding weakly informative priors:

m8.3 <- map2stan(
  alist(
    y ~ dnorm( mu, sigma),
    mu <- alpha,
    alpha ~ dnorm(1, 10),
    sigma ~ dcauchy( 0, 1)
  ),
  data=list(y=y), start=list(alpha=0, sigma=1),
  chains=2, iter=4000, warmup=1000
)
precis(m8.3)

The estimates seem much more reasonable and the Rhat value is now 1.

plot(m8.3, col=c("black", "royalblue4"), n_cols=1)

The chains also look good now.
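
The effective number of samples can be motivated with a small base-R sketch: autocorrelated draws carry less information than independent ones, so the raw sample size is discounted by the sum of the autocorrelations. The helper `n_eff_basic` below is a hypothetical, simplified estimator (it truncates the sum at the first negative autocorrelation) and is not the one Stan uses; the inputs are simulated series, not the chains above.

```r
# Hypothetical helper: crude effective sample size from the autocorrelation function
n_eff_basic <- function(x, lag.max = 200) {
  rho <- acf(x, lag.max = lag.max, plot = FALSE)$acf[-1]  # drop lag 0
  # keep only the lags before the first negative autocorrelation
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(x) / (1 + 2 * sum(rho))
}

set.seed(83)
N <- 1e4
white  <- rnorm(N)                                          # independent draws
sticky <- as.numeric(arima.sim(model = list(ar = 0.9), n = N))  # strongly autocorrelated draws

n_eff_white  <- n_eff_basic(white)
n_eff_sticky <- n_eff_basic(sticky)
c(white = n_eff_white, sticky = n_eff_sticky)
```

The independent series keeps nearly all of its nominal sample size, while the autocorrelated series is worth only a small fraction of it.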

Comparing the prior and posterior distributions shows that even two data points are enough to overcome these weakly informative priors, which is why they lead to better results than flat priors.

post <- extract.samples(m8.3)
par(mfrow=c(1, 2))
sq <- seq(-15, 20, length.out = 100)
plot( density(post$alpha, from=-15, to=20, adj=1),
      lwd=2, col="royalblue4", xlab="alpha",
      main="")
points(sq, dnorm(sq, 1, 10), type="l", lty=2)
text(4.5, 0.3, labels = "Posterior")
text(8, 0.06, labels="Prior")

sq <- seq(0, 10, length.out = 100)
plot( density( post$sigma, from=0, to=10, adj=1.5),
      lwd=2, col="royalblue4", xlab="sigma", 
      main="")
points(sq, 2*dcauchy(sq, 0, 1), type="l", lty=2)
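
The factor 2 in `2*dcauchy(sq, 0, 1)` is there because sigma is restricted to positive values, so the prior is a half-Cauchy: the Cauchy density folded onto the positive axis. A small sketch to check that this is a proper density and to see how much mass it leaves for large values (the helper name `half_cauchy` is my own):

```r
# Hypothetical helper: half-Cauchy density on [0, Inf)
half_cauchy <- function(x, scale = 1) {
  ifelse(x >= 0, 2 * dcauchy(x, location = 0, scale = scale), 0)
}

# the density should integrate to 1 over the positive axis
total <- integrate(half_cauchy, lower = 0, upper = Inf)$value

# prior probability that sigma exceeds 10 under a half-Cauchy(0, 1)
tail_mass <- 2 * pcauchy(10, location = 0, scale = 1, lower.tail = FALSE)
c(total = total, tail_mass = tail_mass)
```

The tail mass above 10 is roughly 6%, which illustrates why this prior counts as weakly informative: it favours small values of sigma but does not rule out large ones.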

Non-identifiable parameters

We’ve learned before how highly correlated predictors lead to non-identifiable parameters. Let’s see what these look like inside a Markov chain.

y <- rnorm( 100, mean=0, sd=1 )

We fit the following unidentifiable model:

m8.4 <- map2stan(
  alist(
    y ~ dnorm( mu, sigma),
    mu <- a1 + a2,
    sigma ~ dcauchy( 0, 1)
  ), 
  data=list(y=y), start=list(a1=0, a2=0, sigma=1),
  chains=2, iter=4000, warmup=1000
)
precis(m8.4)

These estimates of a1 and a2 look suspicious. Also, n_eff and Rhat have terrible values.

plot(m8.4, col=c("black", "royalblue4"), n_cols=1)

The trace plots don’t look good either: the two chains are not mixing and are definitely not stationary. Again, we can use weak priors to solve this problem:

m8.5 <- map2stan(
  alist(
    y ~ dnorm( mu, sigma),
    mu <- a1 + a2,
    a1 ~ dnorm(0, 10),
    a2 ~ dnorm(0, 10),
    sigma ~ dcauchy(0, 1)
  ),
  data=list(y=y), start=list(a1=0, a2=0, sigma=1),
  chains=2, iter=4000, warmup=1000
)
precis(m8.5)

Not only did the model sample much faster, but both the estimates and the values for n_eff and Rhat look much better.

plot(m8.5, col=c("black", "royalblue4"), n_cols=1)

The trace plots look very good as well: stationary and well mixed.
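
Why the model is unidentifiable, and why priors help, can be seen without running Stan at all. In the sketch below (base R, simulated data, hypothetical helper names `log_lik` and `log_prior`), the likelihood depends only on the sum a1 + a2, so wildly different pairs with the same sum fit the data equally well. The Normal(0, 10) priors are what break the tie.

```r
set.seed(85)
y <- rnorm(100, mean = 0, sd = 1)

# log-likelihood of the model mu <- a1 + a2 (sigma fixed at 1 for simplicity)
log_lik <- function(a1, a2, sigma = 1) {
  sum(dnorm(y, mean = a1 + a2, sd = sigma, log = TRUE))
}

# log-density of independent Normal(0, 10) priors on a1 and a2
log_prior <- function(a1, a2) {
  dnorm(a1, 0, 10, log = TRUE) + dnorm(a2, 0, 10, log = TRUE)
}

# two parameter pairs with the same sum
ll_small <- log_lik(0.1, -0.1)
ll_huge  <- log_lik(1000.1, -1000.1)
lp_small <- log_prior(0.1, -0.1)
lp_huge  <- log_prior(1000.1, -1000.1)

c(ll_small = ll_small, ll_huge = ll_huge, lp_small = lp_small, lp_huge = lp_huge)
```

The two log-likelihoods are identical, so with flat priors the sampler has nothing to stop a1 and a2 from drifting off in opposite directions. With the priors added, the extreme pair gets a far lower log-prior and the chains stay in a sensible region.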

Overthinking: Cauchy distribution

The Cauchy distribution does not have a mean because its tails are very thick. At any point in a Cauchy sampling process, an extremely large value can be drawn that overwhelms all of the previous draws, and hence the running average never converges to a mean.

set.seed(13)
y <- rcauchy(1e4, 0, 5)
# running mean of the first i draws, for i = 1, ..., length(y)
mu <- cumsum(y) / seq_along(y)
plot(mu, type="l")
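
Another way to see the same thing is to compare how the spread of sample means changes with the sample size. For normal draws, the means concentrate as n grows; for Cauchy draws they do not, because the mean of n independent Cauchy draws is itself Cauchy with the same scale. The sketch below is my own illustration with assumed settings (2000 replicates, scale 1, and a hypothetical helper `spread_of_means`), using the interquartile range as a measure of spread since the variance does not exist for the Cauchy.

```r
set.seed(13)
n_rep <- 2000

# Hypothetical helper: IQR of the sample mean across n_rep simulated samples of size n
spread_of_means <- function(rdist, n) {
  means <- replicate(n_rep, mean(rdist(n)))
  IQR(means)
}

iqr_norm_10     <- spread_of_means(rnorm, 10)
iqr_norm_1000   <- spread_of_means(rnorm, 1000)
iqr_cauchy_10   <- spread_of_means(rcauchy, 10)
iqr_cauchy_1000 <- spread_of_means(rcauchy, 1000)

rbind(normal = c(n10 = iqr_norm_10,   n1000 = iqr_norm_1000),
      cauchy = c(n10 = iqr_cauchy_10, n1000 = iqr_cauchy_1000))
```

The normal row shrinks by about a factor of ten when going from 10 to 1000 draws, while the Cauchy row stays near 2, the interquartile range of a single standard Cauchy draw.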

Corrie