bkmr and bkmrhat
bkmr is a package to implement Bayesian kernel machine regression (BKMR) using Markov chain Monte Carlo (MCMC). Notably, bkmr is missing some key features in Bayesian inference and MCMC diagnostics: 1) no facility for running multiple chains in parallel, 2) no inference across multiple chains, 3) limited posterior summaries of parameters, and 4) limited diagnostics. The bkmrhat package is a lightweight set of functions that fills each of those gaps by enabling post-processing of bkmr output in other packages and by building a small framework for parallel processing.
How to use the bkmrhat package:
1. Fit a BKMR model for a single chain using the kmbayes function from bkmr, or fit multiple parallel chains using kmbayes_parallel from bkmrhat.
2. Perform diagnostics on the single or multiple chains using the kmbayes_diagnose function (which uses functions from the rstan package), OR convert the BKMR fit(s) to mcmc (one chain) or mcmc.list (multiple chains) objects from the coda package using as.mcmc or as.mcmc.list from the bkmrhat package. The coda package has a whole host of inference and diagnostic procedures (but may lag behind some of the diagnostic functions from rstan).
3. Perform posterior summaries using coda functions, or combine chains from a kmbayes_parallel fit using kmbayes_combine. Final posterior inferences can be made on the combined object, which enables use of bkmr package functions for visual summaries of independent and joint effects of exposures in the bkmr model.

First, simulate some data using the SimData function from the bkmr package.
library("bkmr")
library("bkmrhat")
library("coda")
set.seed(111)
dat <- bkmr::SimData(n = 50, M = 5, ind=1:3, Zgen="realistic")
y <- dat$y
Z <- dat$Z
X <- cbind(dat$X, rnorm(50))
head(cbind(y,Z,X))
## y z1 z2 z3 z4 z5
## [1,] 4.0169017 -0.1158956 -0.09407257 -0.1588709 -0.42142542 0.1721683
## [2,] 10.6535912 -0.4967600 -0.19909875 0.9509691 0.45218228 1.4554858
## [3,] 8.0644572 0.3465623 -0.09978463 0.3812603 0.05898454 0.9558229
## [4,] 1.7274402 1.3671838 2.36393578 1.3669522 1.53321305 0.9574047
## [5,] 0.3112323 -0.4232780 -0.04696948 1.1634185 -0.03507992 0.1608525
## [6,] 5.7651402 -0.1378359 -0.21404393 0.0416876 0.22237734 -0.9101163
##
## [1,] 1.0428561 -1.0503824
## [2,] 4.4612601 0.3251424
## [3,] 3.0073031 -2.1048716
## [4,] -0.4592363 -0.9551027
## [5,] -0.5985688 -0.5306399
## [6,] 1.8228560 0.8274405
There is some overhead in parallel processing when using the future package, so the payoff of parallel processing varies by problem; the benefit grows with the number of iterations, and in a small example like this one the overhead can even dominate (compare the timings below). Note that parallel chains may not yield as many usable iterations as a single large chain if a substantial burn-in period is needed, but they enable useful convergence diagnostics. Note also that the future package can implement sequential processing, which effectively turns kmbayes_parallel into a loop but retains all the other advantages of multiple chains.
# enable parallel processing (up to 4 simultaneous processes here)
future::plan(strategy = future::multisession)
# single chain of 4000 iterations using the bkmr package
set.seed(111)
system.time(kmfit <- suppressMessages(kmbayes(y = y, Z = Z, X = X, iter = 4000, verbose = FALSE, varsel = FALSE)))
## user system elapsed
## 15.016 27.562 10.995
# 4 parallel chains of 1000 iterations each using the bkmrhat package
set.seed(111)
system.time(kmfit5 <- suppressMessages(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = FALSE)))
## Chain 1
## Chain 2
## Chain 3
## Chain 4
## user system elapsed
## 0.713 0.034 15.309
The diagnostics from the rstan package come from the monitor function (see the help files for that function in the rstan package).
# Using rstan functions (set burnin/warmup to zero for comparability with coda numbers given later
# posterior summaries should be performed after excluding warmup/burnin)
singlediag = kmbayes_diagnose(kmfit, warmup=0, digits_summary=2)
## Single chain
## Inference for the input samples (1 chains: each with iter = 4000; warmup = 0):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.0 2.0 0.0 1.00 2665 2969
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 3041 3693
## lambda 4.3 11.1 27.6 13.1 8.2 1.01 259 196
## r1 0.0 0.0 0.1 0.0 0.1 1.00 170 162
## r2 0.0 0.0 0.1 0.0 0.1 1.00 310 244
## r3 0.0 0.0 0.0 0.0 0.0 1.00 155 128
## r4 0.0 0.0 0.1 0.0 0.1 1.00 170 155
## r5 0.0 0.0 0.0 0.0 0.1 1.01 90 131
## sigsq.eps 0.2 0.3 0.5 0.4 0.1 1.00 1326 1631
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# Using rstan functions (multiple chains enable R-hat)
multidiag = kmbayes_diagnose(kmfit5, warmup=0, digits_summary=2)
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 0):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.0 2.0 0.0 1.00 2367 1711
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 2009 1954
## lambda 4.3 10.1 24.4 11.8 7.0 1.02 207 102
## r1 0.0 0.0 0.1 0.1 0.2 1.02 128 67
## r2 0.0 0.0 0.2 0.1 0.2 1.03 149 56
## r3 0.0 0.0 0.1 0.0 0.2 1.04 78 64
## r4 0.0 0.0 0.1 0.0 0.1 1.02 98 145
## r5 0.0 0.0 0.2 0.1 0.2 1.03 99 53
## sigsq.eps 0.2 0.3 0.5 0.4 0.1 1.01 439 407
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# using coda functions, not using any burnin (for demonstration only)
kmfitcoda = as.mcmc(kmfit, iterstart = 1)
kmfit5coda = as.mcmc.list(kmfit5, iterstart = 1)
# single chain trace plot
traceplot(kmfitcoda)
The trace plots look typical and raise no red flags, but trace plots alone don't give a full picture of convergence. Note the apparent quick convergence of a couple of parameters, demonstrated by movement away from the starting value followed by concentration of the remaining samples within a narrow band.
Seeing visual evidence that different chains are sampling from the same marginal distributions is reassuring about the stability of the results.
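That visual evidence presumably comes from a multi-chain trace plot, which is not reproduced here; a minimal sketch using coda's traceplot on the mcmc.list created above:
# multiple chain trace plot
traceplot(kmfit5coda)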
Now examine “cross-correlation,” which can help identify parameters that are highly correlated in the posterior; such correlation can be problematic for MCMC sampling. Here there is a block {r3, r4, r5} that appears to be highly correlated. All other things equal, highly correlated parameters in the posterior mean that more samples are needed than would be needed with uncorrelated parameters.
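The matrix below is presumably the output of coda's crosscorr on the single-chain object; a sketch:
# cross-correlations among posterior samples (coda)
crosscorr(kmfitcoda)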
## beta1 beta2 lambda r1 r2
## beta1 1.00000000 0.07733726 0.020816066 0.15791108 0.14489557
## beta2 0.07733726 1.00000000 -0.059673636 -0.11479475 -0.12891028
## lambda 0.02081607 -0.05967364 1.000000000 0.04963909 0.04178173
## r1 0.15791108 -0.11479475 0.049639091 1.00000000 0.79825315
## r2 0.14489557 -0.12891028 0.041781725 0.79825315 1.00000000
## r3 0.15952865 -0.08849151 0.006925921 0.50841621 0.68184307
## r4 0.23091705 -0.06415646 -0.002732878 0.67347750 0.71342585
## r5 0.15971059 -0.10964670 0.043259501 0.75711064 0.93661034
## sigsq.eps -0.03541153 0.07444138 -0.327970696 -0.23171096 -0.23704523
## r3 r4 r5 sigsq.eps
## beta1 0.159528646 0.230917049 0.1597106 -0.03541153
## beta2 -0.088491506 -0.064156458 -0.1096467 0.07444138
## lambda 0.006925921 -0.002732878 0.0432595 -0.32797070
## r1 0.508416212 0.673477505 0.7571106 -0.23171096
## r2 0.681843070 0.713425851 0.9366103 -0.23704523
## r3 1.000000000 0.689998376 0.7737409 -0.10128083
## r4 0.689998376 1.000000000 0.7600625 -0.13636196
## r5 0.773740905 0.760062544 1.0000000 -0.21093636
## sigsq.eps -0.101280830 -0.136361955 -0.2109364 1.00000000
Now examine “autocorrelation” to identify parameters that have high correlation between subsequent iterations of the MCMC sampler, which can lead to inefficient MCMC sampling. All other things equal, having highly autocorrelated parameters in the posterior means that more samples are needed than would be needed with low-autocorrelation parameters.
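The autocorrelation plots are not reproduced here; a sketch of how they could be produced with coda:
# autocorrelation plots (coda); autocorr.plot also accepts the mcmc.list object
autocorr.plot(kmfitcoda)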
Graphical tools can be limited, and are sometimes difficult to use
effectively with scale parameters (of which bkmr
has many).
Additionally, no single diagnostic is perfect, leading many authors to
advocate the use of multiple, complementary diagnostics. Thus, more
formal diagnostics are helpful.
Gelman’s r-hat diagnostic is interpretable: it estimates the expected reduction in the standard error of the posterior means if you could run the chains to an infinite size. It gives some idea of when it is reasonable to stop sampling. Rules of thumb for using r-hat to decide when to stop sampling are available from several authors (for example, consult the help files for rstan and coda).
Effective sample size is also useful: it estimates the amount of information in your chain, expressed as the number of independent posterior samples that would carry the same information (e.g., if we could sample from the posterior directly).
# Gelman's r-hat using coda estimator (will differ from rstan implementation)
gelman.diag(kmfit5coda)
## Potential scale reduction factors:
##
## Point est. Upper C.I.
## beta1 1.00 1.00
## beta2 1.00 1.00
## lambda 1.05 1.13
## r1 1.02 1.03
## r2 1.05 1.09
## r3 1.07 1.15
## r4 1.07 1.13
## r5 1.05 1.11
## sigsq.eps 1.00 1.01
##
## Multivariate psrf
##
## 1.07
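The effective sample sizes below were presumably computed with coda's effectiveSize, first for the single chain and then for the parallel chains; a sketch:
# effective sample size (coda estimator; will differ from rstan's ESS)
effectiveSize(kmfitcoda)
effectiveSize(kmfit5coda)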
## beta1 beta2 lambda r1 r2 r3 r4 r5
## 2892.7570 3198.3261 202.6782 118.1804 284.6426 295.7300 146.9502 115.5348
## sigsq.eps
## 1242.4771
## beta1 beta2 lambda r1 r2 r3 r4 r5
## 2143.5876 2880.0101 318.0519 123.0036 119.2547 136.5004 193.8485 105.6125
## sigsq.eps
## 850.5044
Posterior kernel marginal densities, 1 chain
Posterior kernel marginal densities, multiple chains combined. Look for multiple modes that may indicate non-convergence of some chains
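The captions above presumably accompany marginal density plots, which are not reproduced here; a sketch using coda:
# posterior kernel marginal densities
densplot(kmfitcoda)   # single chain
densplot(kmfit5coda)  # multiple chains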
Other diagnostics are available from the coda package; see its documentation.
Finally, the chains from the original kmbayes_parallel fit can be combined into a single chain (see the help files for how to deal with burn-in; the default in bkmr is to use the first half of the chain, which is respected here). The kmbayes_combine function first combines the burn-in iterations and then combines the iterations after burn-in, such that the burn-in rules of subsequent functions within the bkmr package are respected. Note that, unlike the as.mcmc.list function, this function combines all iterations into a single chain, so trace plots will not be good diagnostics for the combined object; it should be used once one is assured that all chains have converged and the burn-in is acceptable.
With this combined set of samples, you can use any of the post-processing functions from the bkmr package, which are described here: https://jenfb.github.io/bkmr/overview.html. For example, see below the estimation of the posterior mean difference along a series of quantiles of all exposures in Z.
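Before that, the output below presumably shows coda posterior summaries and 95% highest posterior density intervals, first for the single chain and then for the parallel chains; a sketch:
# posterior summaries and HPD intervals via coda
summary(kmfitcoda)
summary(kmfit5coda)
HPDinterval(kmfitcoda)
HPDinterval(kmfit5coda)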
##
## Iterations = 1:4000
## Thinning interval = 1
## Number of chains = 1
## Sample size per chain = 4000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## beta1 1.97357 0.04598 0.0007270 0.0008549
## beta2 0.13518 0.08710 0.0013771 0.0015401
## lambda 13.08699 8.23403 0.1301915 0.5783746
## r1 0.02829 0.06333 0.0010013 0.0058256
## r2 0.03552 0.05200 0.0008222 0.0030823
## r3 0.02140 0.04096 0.0006476 0.0023818
## r4 0.02931 0.06623 0.0010472 0.0054636
## r5 0.03055 0.10618 0.0016788 0.0098783
## sigsq.eps 0.35970 0.08495 0.0013432 0.0024101
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## beta1 1.88651 1.94226 1.97310 2.00380 2.06407
## beta2 -0.03357 0.07694 0.13345 0.19279 0.30592
## lambda 3.95746 7.68835 11.09318 16.25233 33.16143
## r1 0.01031 0.01229 0.01676 0.02568 0.08740
## r2 0.01056 0.01474 0.02290 0.04175 0.10979
## r3 0.01025 0.01194 0.01437 0.02179 0.06415
## r4 0.01021 0.01286 0.01763 0.02860 0.07571
## r5 0.01014 0.01198 0.01451 0.02031 0.06762
## sigsq.eps 0.23236 0.29752 0.34779 0.40793 0.55917
##
## Iterations = 1:1000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## beta1 1.97220 0.04757 0.0007521 0.001134
## beta2 0.12948 0.09183 0.0014520 0.001740
## lambda 11.76360 7.02139 0.1110180 0.410121
## r1 0.05423 0.16378 0.0025897 0.018419
## r2 0.07722 0.18602 0.0029412 0.021687
## r3 0.04928 0.16250 0.0025694 0.016215
## r4 0.04291 0.11688 0.0018480 0.008879
## r5 0.05352 0.16240 0.0025678 0.018949
## sigsq.eps 0.36171 0.09018 0.0014258 0.003302
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## beta1 1.88251 1.94143 1.97146 2.00155 2.0671
## beta2 -0.05827 0.07128 0.13124 0.19054 0.3044
## lambda 3.25664 7.11622 10.11615 14.53353 30.3804
## r1 0.01010 0.01223 0.01707 0.02712 0.7265
## r2 0.01084 0.01617 0.02846 0.05229 0.8885
## r3 0.01035 0.01198 0.01473 0.02104 0.6824
## r4 0.01025 0.01270 0.01763 0.02984 0.2784
## r5 0.01025 0.01159 0.01487 0.02220 0.7232
## sigsq.eps 0.21768 0.29936 0.34981 0.41255 0.5662
## lower upper
## beta1 1.88272469 2.05913505
## beta2 -0.03379651 0.30462773
## lambda 2.62734881 27.66498240
## r1 0.01003506 0.05828581
## r2 0.01000265 0.08925073
## r3 0.01009384 0.04319179
## r4 0.01001438 0.06122670
## r5 0.01003380 0.03968790
## sigsq.eps 0.21094543 0.52732628
## attr(,"Probability")
## [1] 0.95
## [[1]]
## lower upper
## beta1 1.87881466 2.05640734
## beta2 -0.06790542 0.29408811
## lambda 4.03387403 25.18297025
## r1 0.01049732 0.45345607
## r2 0.01013077 0.20479069
## r3 0.01007166 0.07531484
## r4 0.01027842 0.08344169
## r5 0.01044085 0.35237108
## sigsq.eps 0.20873868 0.55829021
## attr(,"Probability")
## [1] 0.95
##
## [[2]]
## lower upper
## beta1 1.89301316 2.06171257
## beta2 -0.03168293 0.33174439
## lambda 2.16371834 22.42148597
## r1 0.01010240 0.07291425
## r2 0.01031875 0.14093974
## r3 0.01014711 0.06031589
## r4 0.01007560 0.12774120
## r5 0.01031048 0.04156094
## sigsq.eps 0.22080250 0.56493912
## attr(,"Probability")
## [1] 0.95
##
## [[3]]
## lower upper
## beta1 1.86689907 2.06291839
## beta2 -0.04088964 0.31660729
## lambda 3.15093841 25.77385790
## r1 0.01006494 0.05443499
## r2 0.01024398 0.44847169
## r3 0.01005822 0.91002697
## r4 0.01010409 0.09978123
## r5 0.01006096 0.44352639
## sigsq.eps 0.18776118 0.52538413
## attr(,"Probability")
## [1] 0.95
##
## [[4]]
## lower upper
## beta1 1.87431092 2.06061875
## beta2 -0.04925949 0.30091574
## lambda 3.25664492 24.56205553
## r1 0.01002125 0.60943465
## r2 0.01067705 0.80759889
## r3 0.01009028 0.06928882
## r4 0.01013029 0.09261106
## r5 0.01009554 0.53367948
## sigsq.eps 0.20121944 0.52602621
## attr(,"Probability")
## [1] 0.95
# combine multiple chains into a single chain
fitkmccomb = kmbayes_combine(kmfit5)
# For example:
summary(fitkmccomb)
## Fitted object of class 'bkmrfit'
## Iterations: 4000
## Outcome family: gaussian
## Model fit on: 2025-03-09 04:12:30.708982
## Running time: 13.39785 secs
##
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4161040
## 2 r1 0.1897974
## 3 r2 0.3438360
## 4 r3 0.1472868
## 5 r4 0.1865466
## 6 r5 0.1370343
##
## Parameter estimates (based on iterations 2001-4000):
## param mean sd q_2.5 q_97.5
## 1 beta1 1.97170 0.04505 1.88241 2.06101
## 2 beta2 0.13308 0.08748 -0.04784 0.30174
## 3 sigsq.eps 0.36394 0.08443 0.22955 0.55831
## 4 r1 0.02266 0.01794 0.01041 0.06867
## 5 r2 0.03716 0.03347 0.01032 0.13517
## 6 r3 0.01822 0.01304 0.01019 0.05190
## 7 r4 0.02410 0.02436 0.01025 0.08196
## 8 r5 0.01802 0.01052 0.01025 0.04156
## 9 lambda 11.84660 6.34991 3.82256 28.37833
## NULL
mean.difference <- suppressWarnings(OverallRiskSummaries(fit = fitkmccomb, y = y, Z = Z, X = X,
qs = seq(0.25, 0.75, by = 0.05),
q.fixed = 0.5, method = "exact"))
mean.difference
## quantile est sd
## 1 0.25 -0.43409703 0.09329187
## 2 0.30 -0.37823674 0.06917024
## 3 0.35 -0.18115927 0.04230407
## 4 0.40 -0.14181918 0.03513813
## 5 0.45 -0.06350457 0.03092162
## 6 0.50 0.00000000 0.00000000
## 7 0.55 0.17379452 0.05137174
## 8 0.60 0.31492642 0.07284092
## 9 0.65 0.58316326 0.11592165
## 10 0.70 0.71234910 0.13873256
## 11 0.75 0.85574860 0.17996696
with(mean.difference, {
plot(quantile, est, pch=19, ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)),
axes=FALSE, ylab= "Mean difference", xlab = "Joint quantile")
segments(x0=quantile, x1=quantile, y0 = est - 1.96*sd, y1 = est + 1.96*sd)
abline(h=0)
axis(1)
axis(2)
box(bty='l')
})
These results parallel the previous section and are given here without comment, other than to note that variable selection is enabled here (varsel = TRUE), and that it is useful to check the posterior inclusion probabilities to ensure they are similar across chains.
set.seed(111)
system.time(kmfitbma.list <- suppressWarnings(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = TRUE)))
## Chain 1
## Iteration: 100 (10% completed; 0.30565 secs elapsed)
## Iteration: 200 (20% completed; 1.06899 secs elapsed)
## Iteration: 300 (30% completed; 1.8132 secs elapsed)
## Iteration: 400 (40% completed; 2.3549 secs elapsed)
## Iteration: 500 (50% completed; 3.17816 secs elapsed)
## Iteration: 600 (60% completed; 3.60108 secs elapsed)
## Iteration: 700 (70% completed; 4.25175 secs elapsed)
## Iteration: 800 (80% completed; 4.73596 secs elapsed)
## Iteration: 900 (90% completed; 5.26382 secs elapsed)
## Iteration: 1000 (100% completed; 5.77758 secs elapsed)
## Chain 2
## Iteration: 100 (10% completed; 0.33909 secs elapsed)
## Iteration: 200 (20% completed; 0.87785 secs elapsed)
## Iteration: 300 (30% completed; 1.62694 secs elapsed)
## Iteration: 400 (40% completed; 2.19425 secs elapsed)
## Iteration: 500 (50% completed; 2.82111 secs elapsed)
## Iteration: 600 (60% completed; 3.39094 secs elapsed)
## Iteration: 700 (70% completed; 4.10487 secs elapsed)
## Iteration: 800 (80% completed; 4.54474 secs elapsed)
## Iteration: 900 (90% completed; 5.04014 secs elapsed)
## Iteration: 1000 (100% completed; 5.52137 secs elapsed)
## Chain 3
## Iteration: 100 (10% completed; 0.49312 secs elapsed)
## Iteration: 200 (20% completed; 0.99414 secs elapsed)
## Iteration: 300 (30% completed; 1.67507 secs elapsed)
## Iteration: 400 (40% completed; 2.21084 secs elapsed)
## Iteration: 500 (50% completed; 2.66794 secs elapsed)
## Iteration: 600 (60% completed; 3.22445 secs elapsed)
## Iteration: 700 (70% completed; 4.19478 secs elapsed)
## Iteration: 800 (80% completed; 4.76103 secs elapsed)
## Iteration: 900 (90% completed; 5.32119 secs elapsed)
## Iteration: 1000 (100% completed; 5.76054 secs elapsed)
## Chain 4
## Iteration: 100 (10% completed; 0.56549 secs elapsed)
## Iteration: 200 (20% completed; 1.18778 secs elapsed)
## Iteration: 300 (30% completed; 1.85735 secs elapsed)
## Iteration: 400 (40% completed; 2.96883 secs elapsed)
## Iteration: 500 (50% completed; 3.50446 secs elapsed)
## Iteration: 600 (60% completed; 3.92451 secs elapsed)
## Iteration: 700 (70% completed; 4.48064 secs elapsed)
## Iteration: 800 (80% completed; 5.07473 secs elapsed)
## Iteration: 900 (90% completed; 5.52158 secs elapsed)
## Iteration: 1000 (100% completed; 5.71411 secs elapsed)
## user system elapsed
## 0.379 0.006 6.292
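The diagnostics below presumably come from running kmbayes_diagnose on the parallel variable-selection fit, this time keeping the default burn-in; a sketch:
kmbayes_diagnose(kmfitbma.list, digits_summary = 2)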
## Parallel chains
## Inference for the input samples (4 chains: each with iter = 1000; warmup = 500):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 2.0 2.0 2.0 0.0 1.00 607 1786
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 1758 1697
## lambda 4.6 11.7 26.7 13.2 7.1 1.07 54 93
## r1 0.0 0.0 0.0 0.0 0.0 1.08 44 125
## r2 0.0 0.0 0.1 0.1 0.0 1.11 37 69
## r3 0.0 0.0 0.0 0.0 0.0 1.02 142 119
## r4 0.0 0.0 0.1 0.0 0.0 1.05 76 48
## r5 0.0 0.0 0.0 0.0 0.0 1.05 117 92
## sigsq.eps 0.3 0.4 0.6 0.4 0.1 1.01 439 844
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
# posterior inclusion probabilities (PIPs) from each chain
lapply(kmfitbma.list, function(x) t(ExtractPIPs(x)))
## [[1]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.460" "1.000" "0.302" "0.350" "0.636"
##
## [[2]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.644" "0.978" "0.284" "0.460" "0.632"
##
## [[3]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.780" "0.992" "0.390" "0.498" "0.686"
##
## [[4]]
## [,1] [,2] [,3] [,4] [,5]
## variable "z1" "z2" "z3" "z4" "z5"
## PIP "0.590" "0.950" "0.418" "0.606" "0.680"
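The combined-chain summary and inclusion probabilities below presumably come from code like the following (the object name kmfitbma.comb matches its use further below):
# combine the variable-selection chains, then summarize
kmfitbma.comb = kmbayes_combine(kmfitbma.list)
summary(kmfitbma.comb)
ExtractPIPs(kmfitbma.comb)  # posterior inclusion probabilities from the combined chain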
## Fitted object of class 'bkmrfit'
## Iterations: 4000
## Outcome family: gaussian
## Model fit on: 2025-03-09 04:12:45.362653
## Running time: 5.77791 secs
##
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4671168
## 2 r/delta (overall) 0.3325831
## 3 r/delta (move 1) 0.4068966
## 4 r/delta (move 2) 0.2563581
##
## Parameter estimates (based on iterations 2001-4000):
## param mean sd q_2.5 q_97.5
## 1 beta1 1.95413 0.04833 1.85772 2.04989
## 2 beta2 0.12335 0.09093 -0.05970 0.30086
## 3 sigsq.eps 0.39193 0.09797 0.24245 0.63211
## 4 r1 0.01248 0.01452 0.00000 0.04656
## 5 r2 0.05151 0.03300 0.01176 0.13335
## 6 r3 0.00784 0.01455 0.00000 0.04212
## 7 r4 0.01232 0.02202 0.00000 0.06878
## 8 r5 0.01270 0.01514 0.00000 0.05259
## 9 lambda 13.22433 7.13686 3.72303 31.18482
##
## Posterior inclusion probabilities:
## variable PIP
## 1 z1 0.6185
## 2 z2 0.9800
## 3 z3 0.3485
## 4 z4 0.4785
## 5 z5 0.6585
## NULL
## variable PIP
## 1 z1 0.6185
## 2 z2 0.9800
## 3 z3 0.3485
## 4 z4 0.4785
## 5 z5 0.6585
mean.difference2 <- suppressWarnings(OverallRiskSummaries(fit = kmfitbma.comb, y = y, Z = Z, X = X, qs = seq(0.25, 0.75, by = 0.05),
q.fixed = 0.5, method = "exact"))
mean.difference2
## quantile est sd
## 1 0.25 -0.43740334 0.09277264
## 2 0.30 -0.36830105 0.06982347
## 3 0.35 -0.18123941 0.04456475
## 4 0.40 -0.14994018 0.03840903
## 5 0.45 -0.05793824 0.03508668
## 6 0.50 0.00000000 0.00000000
## 7 0.55 0.12416353 0.05876532
## 8 0.60 0.33019585 0.07770291
## 9 0.65 0.59447956 0.12018455
## 10 0.70 0.72115132 0.14487676
## 11 0.75 0.88076768 0.18720987
with(mean.difference2, {
plot(quantile, est, pch=19, ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)),
axes=FALSE, ylab= "Mean difference", xlab = "Joint quantile")
segments(x0=quantile, x1=quantile, y0 = est - 1.96*sd, y1 = est + 1.96*sd)
abline(h=0)
axis(1)
axis(2)
box(bty='l')
})
bkmrhat also has ported versions of the native posterior summarization functions, which compare how these summaries vary across parallel chains. Note that these should serve as diagnostics, and final posterior inference should be done on the combined chain. The easiest of these functions to demonstrate is OverallRiskSummaries_parallel, which simply runs OverallRiskSummaries (from the bkmr package) on each chain and combines the results. Notably, this function fixes the y-axis at zero for the median, so it under-represents overall predictive variation across chains, but it captures variation in effect estimates across the chains. Ideally, that variation is negligible: if you see differences between chains that would result in different interpretations, you should re-fit the model with more iterations. In this example, the results are reasonably consistent across chains, but one might want to run more iterations if, say, the differences seen across the upper error bounds are of a practically meaningful magnitude.
set.seed(111)
system.time(kmfitbma.list <- suppressWarnings(kmbayes_parallel(nchains=4, y = y, Z = Z, X = X, iter = 1000, verbose = FALSE, varsel = TRUE)))
## Chain 1
## Iteration: 100 (10% completed; 0.20548 secs elapsed)
## Iteration: 200 (20% completed; 0.76547 secs elapsed)
## Iteration: 300 (30% completed; 1.3945 secs elapsed)
## Iteration: 400 (40% completed; 1.99192 secs elapsed)
## Iteration: 500 (50% completed; 2.65605 secs elapsed)
## Iteration: 600 (60% completed; 3.2037 secs elapsed)
## Iteration: 700 (70% completed; 3.80628 secs elapsed)
## Iteration: 800 (80% completed; 4.37069 secs elapsed)
## Iteration: 900 (90% completed; 5.00516 secs elapsed)
## Iteration: 1000 (100% completed; 5.56893 secs elapsed)
## Chain 2
## Iteration: 100 (10% completed; 0.3113 secs elapsed)
## Iteration: 200 (20% completed; 0.85779 secs elapsed)
## Iteration: 300 (30% completed; 1.53719 secs elapsed)
## Iteration: 400 (40% completed; 2.13433 secs elapsed)
## Iteration: 500 (50% completed; 2.6291 secs elapsed)
## Iteration: 600 (60% completed; 3.17072 secs elapsed)
## Iteration: 700 (70% completed; 3.74548 secs elapsed)
## Iteration: 800 (80% completed; 4.76075 secs elapsed)
## Iteration: 900 (90% completed; 5.23622 secs elapsed)
## Iteration: 1000 (100% completed; 5.90344 secs elapsed)
## Chain 3
## Iteration: 100 (10% completed; 0.53021 secs elapsed)
## Iteration: 200 (20% completed; 1.17362 secs elapsed)
## Iteration: 300 (30% completed; 1.87651 secs elapsed)
## Iteration: 400 (40% completed; 2.46672 secs elapsed)
## Iteration: 500 (50% completed; 3.00859 secs elapsed)
## Iteration: 600 (60% completed; 3.6063 secs elapsed)
## Iteration: 700 (70% completed; 4.16999 secs elapsed)
## Iteration: 800 (80% completed; 4.73244 secs elapsed)
## Iteration: 900 (90% completed; 5.32029 secs elapsed)
## Iteration: 1000 (100% completed; 6.02845 secs elapsed)
## Chain 4
## Iteration: 100 (10% completed; 0.60657 secs elapsed)
## Iteration: 200 (20% completed; 1.61362 secs elapsed)
## Iteration: 300 (30% completed; 2.27458 secs elapsed)
## Iteration: 400 (40% completed; 2.81039 secs elapsed)
## Iteration: 500 (50% completed; 3.39985 secs elapsed)
## Iteration: 600 (60% completed; 3.96603 secs elapsed)
## Iteration: 700 (70% completed; 4.68908 secs elapsed)
## Iteration: 800 (80% completed; 5.25211 secs elapsed)
## Iteration: 900 (90% completed; 5.61278 secs elapsed)
## Iteration: 1000 (100% completed; 5.86945 secs elapsed)
## user system elapsed
## 0.383 0.011 6.425
meandifference_par = OverallRiskSummaries_parallel(kmfitbma.list, y = y, Z = Z, X = X ,qs = seq(0.25, 0.75, by = 0.05), q.fixed = 0.5, method = "exact")
## Chain 1
## Chain 2
## Chain 3
## Chain 4
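Presumably only the first rows of the result are displayed, via head:
head(meandifference_par)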
## quantile est sd chain
## 1 0.25 -0.43399519 0.09677667 1
## 2 0.30 -0.36352052 0.07236802 1
## 3 0.35 -0.17642660 0.04594553 1
## 4 0.40 -0.14811018 0.03930450 1
## 5 0.45 -0.04764707 0.03047827 1
## 6 0.50 0.00000000 0.00000000 1
nchains = length(unique(meandifference_par$chain))
with(meandifference_par, {
plot.new()
plot.window(ylim=c(min(est - 1.96*sd), max(est + 1.96*sd)),
xlim=c(min(quantile), max(quantile)),
ylab= "Mean difference", xlab = "Joint quantile")
for(cch in seq_len(nchains)){
width = diff(quantile)[1]
jit = runif(1, -width/5, width/5)
points(jit+quantile[chain==cch], est[chain==cch], pch=19, col=cch)
segments(x0=jit+quantile[chain==cch], x1=jit+quantile[chain==cch], y0 = est[chain==cch] - 1.96*sd[chain==cch], y1 = est[chain==cch] + 1.96*sd[chain==cch], col=cch)
}
abline(h=0)
axis(1)
axis(2)
box(bty='l')
legend("bottom", col=1:nchains, pch=19, lty=1, legend=paste("chain", 1:nchains), bty="n")
})
regfuns_par = PredictorResponseUnivar_parallel(kmfitbma.list, y = y, Z = Z, X = X ,qs = seq(0.25, 0.75, by = 0.05), q.fixed = 0.5, method = "exact")
## Chain 1
## Chain 2
## Chain 3
## Chain 4
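Again, presumably only the first rows are displayed:
head(regfuns_par)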
## variable z est se chain
## 1 z1 -2.159373 -0.3297703 0.5129312 1
## 2 z1 -2.048986 -0.3141977 0.4904068 1
## 3 z1 -1.938600 -0.2985819 0.4680000 1
## 4 z1 -1.828214 -0.2829365 0.4457449 1
## 5 z1 -1.717827 -0.2672755 0.4236794 1
## 6 z1 -1.607441 -0.2516131 0.4018457 1
nchains = length(unique(meandifference_par$chain))
# single variable
with(regfuns_par[regfuns_par$variable=="z1",], {
plot.new()
plot.window(ylim=c(min(est - 1.96*se), max(est + 1.96*se)),
xlim=c(min(z), max(z)),
ylab= "Predicted Y", xlab = "Z")
pc = c("#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")
pc2 = c("#0000001A", "#E69F001A", "#56B4E91A", "#009E731A", "#F0E4421A", "#0072B21A", "#D55E001A", "#CC79A71A", "#9999991A")
for(cch in seq_len(nchains)){
ribbonX = c(z[chain==cch], rev(z[chain==cch]))
ribbonY = c(est[chain==cch] + 1.96*se[chain==cch], rev(est[chain==cch] - 1.96*se[chain==cch]))
polygon(x=ribbonX, y = ribbonY, col=pc2[cch], border=NA)
lines(z[chain==cch], est[chain==cch], pch=19, col=pc[cch])
}
axis(1)
axis(2)
box(bty='l')
legend("bottom", col=1:nchains, pch=19, lty=1, legend=paste("chain", 1:nchains), bty="n")
})
Sometimes you just need to run more samples in an existing chain. For example, you run a bkmr fit for 3 days, only to find you don't have enough samples. A "continued" fit means that you can start off at the last iteration and keep building on an existing set of results by lengthening the Markov chain. Unfortunately, due to how the kmbayes function accepts starting values (in the official release version), you can't quite do this exactly in many cases (the function will relay a message and possible solutions, if any; the bkmr package authors are aware of this issue). The kmbayes_continue function continues a bkmr fit as well as the bkmr package will allow. The r parameters from the fit must all be initialized at the same value, so kmbayes_continue starts a new MCMC fit at the final values of all parameters from the prior bkmr fit, but sets all of the r parameters to their mean at the last iteration of the previous fit. Additionally, if h.hat parameters are estimated, these are fixed to be above zero to meet similar constraints, either by fixing them at their posterior mean or by setting them to a small positive value. One should inspect trace plots to see whether this causes issues (e.g., if the trace plots show different patterns in the samples before and after the continuation). Here's an example with a quick check of diagnostics of the first part of the chain and of the combined chain (which could be used for inference or extended again, if necessary). We caution users that this function creates two distinct, if very similar, Markov chains, and appropriate caution should be used if trace plots differ before and after each continuation. Nonetheless, in many cases one can act as though all samples are from the same Markov chain.
Note that if you install the development version of the bkmr package, you can continue fits from exactly where they left off, so you get a true, single Markov chain. You can install it via the commented code below.
# install dev version of bkmr to allow true continued fits.
#install.packages("devtools")
#devtools::install_github("jenfb/bkmr")
set.seed(111)
# run 500 initial iterations for a model with only 2 exposures
Z2 = Z[,1:2]
kmfitbma.start <- suppressWarnings(kmbayes(y = y, Z = Z2, X = X, iter = 500, verbose = FALSE, varsel = FALSE))
## Iteration: 50 (10% completed; 0.05753 secs elapsed)
## Iteration: 100 (20% completed; 0.12966 secs elapsed)
## Iteration: 150 (30% completed; 0.18762 secs elapsed)
## Iteration: 200 (40% completed; 0.25323 secs elapsed)
## Iteration: 250 (50% completed; 0.31158 secs elapsed)
## Iteration: 300 (60% completed; 0.37598 secs elapsed)
## Iteration: 350 (70% completed; 0.44495 secs elapsed)
## Iteration: 400 (80% completed; 0.53054 secs elapsed)
## Iteration: 450 (90% completed; 0.58863 secs elapsed)
## Iteration: 500 (100% completed; 0.65308 secs elapsed)
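The diagnostics below presumably come from kmbayes_diagnose on this initial fit; the detailed table following the short summary is the underlying rstan monitor output. A sketch:
kmbayes_diagnose(kmfitbma.start)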
## Single chain
## Inference for the input samples (1 chains: each with iter = 500; warmup = 250):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 1.9 2.0 1.9 0.0 1.00 279 217
## beta2 -0.1 0.1 0.3 0.1 0.1 1.00 238 206
## lambda 4.6 13.0 31.3 14.9 7.8 1.08 13 16
## r1 0.0 0.0 0.0 0.0 0.0 1.17 5 7
## r2 0.0 0.0 0.1 0.1 0.0 1.05 10 16
## sigsq.eps 0.3 0.4 0.6 0.4 0.1 1.00 87 219
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
## mean se_mean sd 2.5% 25% 50%
## beta1 1.94302414 0.002744458 0.04643461 1.84782583 1.91330142 1.94655854
## beta2 0.10483405 0.006191034 0.09620251 -0.07303313 0.04304370 0.10467199
## lambda 14.90060019 1.869438290 7.81070475 4.60230691 8.65519030 13.04053754
## r1 0.01476940 0.002347690 0.00449794 0.01078574 0.01235942 0.01320025
## r2 0.05291092 0.009623519 0.04332210 0.01010751 0.01924294 0.03815578
## sigsq.eps 0.41328122 0.010095801 0.09214694 0.27259916 0.35051490 0.39841682
## 75% 97.5% n_eff Rhat valid Q5 Q50
## beta1 1.97193285 2.02693739 282 1.004931 1 1.86400517 1.94655854
## beta2 0.16804699 0.27700828 236 0.997853 1 -0.05155040 0.10467199
## lambda 18.65686378 34.19469565 18 1.084344 1 4.60230691 13.04053754
## r1 0.01703200 0.02877183 9 1.167807 1 0.01078574 0.01320025
## r2 0.07782864 0.15639005 21 1.052302 1 0.01010751 0.03815578
## sigsq.eps 0.46665521 0.62932750 94 1.002317 1 0.28064363 0.39841682
## Q95 MCSE_Q2.5 MCSE_Q25 MCSE_Q50 MCSE_Q75
## beta1 2.01504706 0.0098237593 0.0048432127 0.003083228 0.004034983
## beta2 0.26110337 0.0166696373 0.0093358891 0.007387626 0.010829577
## lambda 31.32637623 1.1971720562 2.0194134613 2.532679950 1.535664875
## r1 0.02153674 0.0002641772 0.0009430751 0.001429907 0.004168245
## r2 0.14102786 0.0028040064 0.0071858659 0.011789241 0.013548261
## sigsq.eps 0.60258149 0.0068027724 0.0052486067 0.005123756 0.008923788
## MCSE_Q97.5 MCSE_SD Bulk_ESS Tail_ESS
## beta1 0.006818301 0.001946731 279 217
## beta2 0.008917280 0.005057796 238 206
## lambda 1.781705411 1.344630129 13 16
## r1 0.004563786 0.001813399 5 7
## r2 0.010363171 0.006905270 10 16
## sigsq.eps 0.013662390 0.007248003 87 219
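The messages below presumably come from continuing the fit for 2,000 further iterations with kmbayes_continue (the object name and iteration count are assumptions inferred from the output):
kmfitbma.cont = kmbayes_continue(kmfitbma.start, iter = 2000)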
## Validating control.params...
## Validating starting.values...
## Iteration: 201 (10% completed; 0.26671 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.425
## 2 r1 0.140
## 3 r2 0.475
## Iteration: 401 (20% completed; 0.53244 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4500
## 2 r1 0.1475
## 3 r2 0.3975
## Iteration: 601 (30% completed; 0.84683 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4750000
## 2 r1 0.1400000
## 3 r2 0.4083333
## Iteration: 801 (40% completed; 1.11445 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.48250
## 2 r1 0.13875
## 3 r2 0.40125
## Iteration: 1001 (50% completed; 1.38025 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.457
## 2 r1 0.137
## 3 r2 0.412
## Iteration: 1201 (60% completed; 1.64621 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4683333
## 2 r1 0.1483333
## 3 r2 0.3975000
## Iteration: 1401 (70% completed; 1.95021 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4885714
## 2 r1 0.1500000
## 3 r2 0.4028571
## Iteration: 1601 (80% completed; 2.21696 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.484375
## 2 r1 0.155625
## 3 r2 0.408750
## Iteration: 1801 (90% completed; 2.6581 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4911111
## 2 r1 0.1627778
## 3 r2 0.4061111
## Iteration: 2001 (100% completed; 2.95569 secs elapsed)
## Acceptance rates for Metropolis-Hastings algorithm:
## param rate
## 1 lambda 0.4935
## 2 r1 0.1590
## 3 r2 0.4055
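And the diagnostics below presumably come from kmbayes_diagnose on the continued fit (2,500 total iterations):
kmbayes_diagnose(kmfitbma.cont)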
## Single chain
## Inference for the input samples (1 chains: each with iter = 2500; warmup = 1250):
##
## Q5 Q50 Q95 Mean SD Rhat Bulk_ESS Tail_ESS
## beta1 1.9 1.9 2.0 1.9 0.0 1.00 1195 1239
## beta2 0.0 0.1 0.3 0.1 0.1 1.00 861 1219
## lambda 4.9 12.1 30.5 14.3 8.5 1.00 126 93
## r1 0.0 0.0 0.1 0.0 0.0 1.03 77 78
## r2 0.0 0.0 0.2 0.1 0.0 1.00 112 100
## sigsq.eps 0.3 0.4 0.6 0.4 0.1 1.00 514 890
##
## For each parameter, Bulk_ESS and Tail_ESS are crude measures of
## effective sample size for bulk and tail quantities respectively (an ESS > 100
## per chain is considered good), and Rhat is the potential scale reduction
## factor on rank normalized split chains (at convergence, Rhat <= 1.05).
## mean se_mean sd 2.5% 25% 50%
## beta1 1.94451464 0.001354599 0.04679152 1.84957428 1.91337229 1.94425420
## beta2 0.10591205 0.003098008 0.09190734 -0.07244492 0.04307369 0.10417774
## lambda 14.34371537 0.741761549 8.52169987 4.05070504 7.75788463 12.05628895
## r1 0.02298157 0.001565890 0.01454607 0.01069964 0.01263807 0.01684080
## r2 0.05795423 0.004171398 0.04711284 0.01155283 0.02140563 0.04287896
## sigsq.eps 0.40958211 0.004090066 0.09282983 0.26471559 0.34187272 0.39924364
## 75% 97.5% n_eff Rhat valid Q5 Q50
## beta1 1.97704630 2.03445739 1186 0.9998051 1 1.86691892 1.94425420
## beta2 0.16721410 0.29665914 874 0.9996921 1 -0.04406332 0.10417774
## lambda 18.43684053 35.00798030 130 1.0031465 1 4.88809530 12.05628895
## r1 0.02691796 0.06024277 87 1.0285881 1 0.01105215 0.01684080
## r2 0.08141999 0.18286772 128 1.0017863 1 0.01242185 0.04287896
## sigsq.eps 0.46162498 0.62034160 515 0.9992769 1 0.28188526 0.39924364
## Q95 MCSE_Q2.5 MCSE_Q25 MCSE_Q50 MCSE_Q75
## beta1 2.02154219 0.0053480030 0.0018254915 0.0017327069 0.001809085
## beta2 0.25197027 0.0060693244 0.0032074359 0.0037835712 0.004042631
## lambda 30.54061347 0.6391431563 0.5063700931 0.7650849805 1.072408138
## r1 0.05704602 0.0002311747 0.0003811785 0.0007469754 0.002080573
## r2 0.16096251 0.0004350924 0.0020817839 0.0040577461 0.005804157
## sigsq.eps 0.58415925 0.0046659948 0.0037588012 0.0037328009 0.005981460
## MCSE_Q97.5 MCSE_SD Bulk_ESS Tail_ESS
## beta1 0.003404986 0.0009580805 1195 1239
## beta2 0.010762886 0.0021913487 861 1219
## lambda 2.300751507 0.5314745343 126 93
## r1 0.001167341 0.0011110153 77 78
## r2 0.005557337 0.0029563949 112 100
## sigsq.eps 0.004102877 0.0028937522 514 890
Thanks to Haotian “Howie” Wu for invaluable feedback on early versions of the package.