Vishal Pandey | ML Research Engineer; Neuroscience

Dreaming in Latent Space: Deriving the Sequence-Level ELBO for World Models

Motivation: Why Latent World Models?

Imagine an RL agent that no longer needs to interact with the real environment at every timestep. Instead, it imagines futures inside its own learned model, a fast, compact, and differentiable simulator. This is the promise of latent world models.

But how do we train such models?
What’s the right objective to learn latent dynamics, reconstruct observations, and learn a policy, all at once?

This blog walks through the full derivation of the sequence-level Evidence Lower Bound (ELBO) used in world model training, inspired by methods like Dreamer and PlaNet.


Problem Setup

We deal with trajectories of observations $x_{1:T} = (x_1, \dots, x_T)$ and actions $a_{1:T} = (a_1, \dots, a_T)$, together with a sequence of latent states $z_{1:T}$ that is never observed directly.

The goal is to model the joint distribution:

$$
p(x_{1:T}, a_{1:T}) = \int p(x_{1:T}, a_{1:T}, z_{1:T}) \, dz_{1:T}
$$

Since this integral is intractable (due to the latent variables), we use variational inference to approximate it.


Latent Generative Model

We define the following generative process:

$$
p_\theta(x_{1:T}, a_{1:T}, z_{1:T}) = p(z_1) \prod_{t=1}^{T} p(x_t \mid z_t)\, \pi_\theta(a_t \mid z_t) \prod_{t=1}^{T-1} p(z_{t+1} \mid z_t, a_t)
$$

Each term corresponds to:

  • $p(z_1)$: the prior over the initial latent state.
  • $p(x_t \mid z_t)$: the observation decoder, which generates observations from latents.
  • $\pi_\theta(a_t \mid z_t)$: the policy, acting directly on latent states.
  • $p(z_{t+1} \mid z_t, a_t)$: the latent transition dynamics.

This is a latent controlled Markov process.
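
To make this factorization concrete, here is a minimal sketch of ancestral sampling from the generative model. The `prior`, `decoder`, `policy`, and `dynamics` callables are assumptions for illustration (each returning a `torch.distributions` object), not the API of any particular world-model library.

```python
import torch

def sample_trajectory(prior, decoder, policy, dynamics, T):
    """Ancestral sampling from p(z_1) * prod_t p(x_t|z_t) pi(a_t|z_t) p(z_{t+1}|z_t, a_t)."""
    xs, acts, zs = [], [], []
    z = prior().sample()                 # z_1 ~ p(z_1)
    for _ in range(T):
        x = decoder(z).sample()          # x_t ~ p(x_t | z_t)
        a = policy(z).sample()           # a_t ~ pi_theta(a_t | z_t)
        xs.append(x); acts.append(a); zs.append(z)
        z = dynamics(z, a).sample()      # z_{t+1} ~ p(z_{t+1} | z_t, a_t)
    return torch.stack(xs), torch.stack(acts), torch.stack(zs)
```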


Why Not Maximize the Log-Likelihood Directly?

Because:

$$
\log p(x_{1:T}, a_{1:T}) = \log \int p(x_{1:T}, a_{1:T}, z_{1:T}) \, dz_{1:T}
$$

is intractable due to the high-dimensional integration over $z_{1:T}$.


Variational Posterior (Approximate Inference)

We introduce a learned variational distribution:

$$
q_\phi(z_{1:T} \mid x_{1:T}, a_{1:T}) = \prod_{t=1}^{T} q_\phi(z_t \mid x_{\le t}, a_{<t})
$$

This acts like a Bayesian filter: each $z_t$ is inferred from the observations up to time $t$ and the actions taken before it.
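
One common way to implement such a filtering posterior is a recurrent encoder that summarizes the history $(x_{\le t}, a_{<t})$ and outputs the parameters of a diagonal Gaussian over $z_t$. The sketch below is a minimal illustration under that assumption; the module name, sizes, and architecture are hypothetical rather than a specific published design.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class FilteringPosterior(nn.Module):
    """q_phi(z_t | x_{<=t}, a_{<t}) implemented with a GRU over the history."""

    def __init__(self, obs_dim, act_dim, latent_dim=32, hidden_dim=128):
        super().__init__()
        self.rnn = nn.GRU(obs_dim + act_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 2 * latent_dim)  # mean and log-std

    def forward(self, x, a_prev):
        # x: [B, T, obs_dim]; a_prev: [B, T, act_dim], with a_prev[:, 0] = zeros
        # (there is no action before the first observation).
        h, _ = self.rnn(torch.cat([x, a_prev], dim=-1))     # h_t summarizes (x_{<=t}, a_{<t})
        mean, log_std = self.head(h).chunk(2, dim=-1)
        return Normal(mean, log_std.exp())                  # one Gaussian per timestep
```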


ELBO Derivation Step-by-Step

We rewrite the marginal likelihood as an expectation under $q$ and apply Jensen’s inequality:

$$
\log p(x_{1:T}, a_{1:T})
= \log \mathbb{E}_{q}\!\left[\frac{p(x_{1:T}, a_{1:T}, z_{1:T})}{q(z_{1:T})}\right]
\ge \mathbb{E}_{q}\!\left[\log \frac{p(x_{1:T}, a_{1:T}, z_{1:T})}{q(z_{1:T})}\right]
$$

Define this lower bound as the ELBO:

$$
\mathrm{ELBO} = \mathbb{E}_{q_\phi}\!\left[\log p_\theta(x_{1:T}, a_{1:T}, z_{1:T}) - \log q_\phi(z_{1:T} \mid x_{1:T}, a_{1:T})\right]
$$
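
For completeness, the gap between the log-likelihood and this bound is the KL divergence from $q_\phi$ to the true posterior over latents, so the bound is tight exactly when the variational posterior is exact:

$$
\log p(x_{1:T}, a_{1:T}) = \mathrm{ELBO} + \mathrm{KL}\big(q_\phi(z_{1:T} \mid x_{1:T}, a_{1:T}) \,\|\, p(z_{1:T} \mid x_{1:T}, a_{1:T})\big)
$$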


Expand the Joint Log Terms

Now expand:

$$
\begin{aligned}
\mathrm{ELBO} = \mathbb{E}_{q_\phi}\Big[
&\log p(z_1) + \sum_{t=1}^{T} \log p(x_t \mid z_t) + \sum_{t=1}^{T} \log \pi_\theta(a_t \mid z_t) \\
&+ \sum_{t=1}^{T-1} \log p(z_{t+1} \mid z_t, a_t) - \sum_{t=1}^{T} \log q_\phi(z_t \mid \cdot)
\Big]
\end{aligned}
$$

Let’s rearrange: each transition log-probability pairs with the matching posterior term, and each pair becomes a KL divergence (see the identity below).
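
Writing $p(z_1)$ as $p(z_1 \mid z_0, a_0)$ (a notational convention for the first step), each such pair collapses, in expectation under $q_\phi$, into a negative KL divergence:

$$
\mathbb{E}_{q_\phi}\big[\log p(z_t \mid z_{t-1}, a_{t-1}) - \log q_\phi(z_t \mid \cdot)\big]
= -\,\mathrm{KL}\big(q_\phi(z_t \mid \cdot) \,\|\, p(z_t \mid z_{t-1}, a_{t-1})\big)
$$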


Final Form of the Sequence-Level ELBO

$$
\mathrm{ELBO} = \mathbb{E}_{q_\phi}\!\left[\sum_{t=1}^{T} \Big( \log p(x_t \mid z_t) + \log \pi_\theta(a_t \mid z_t) - \mathrm{KL}\big(q_\phi(z_t \mid \cdot) \,\|\, p(z_t \mid z_{t-1}, a_{t-1})\big) \Big)\right]
$$

This is our training objective.


Interpretation of Each Term

  • $\log p(x_t \mid z_t)$: latent observation decoder; ensures the latent state can predict the real observation.
  • $\log \pi_\theta(a_t \mid z_t)$: optimizes the policy directly in latent space.
  • $\mathrm{KL}\big(q_\phi(z_t \mid \cdot) \,\|\, p(z_t \mid z_{t-1}, a_{t-1})\big)$: forces the inferred latents to match the world model’s transition dynamics.

What This Loss Does:

A single objective jointly trains the observation decoder, the latent transition dynamics, and the policy, end-to-end, by maximizing the ELBO over entire sequences.


Implementation Notes (PyTorch Pseudocode)

import torch  # for torch.distributions.kl_divergence

# Per-timestep ELBO terms inside a loop over t. `posterior`, `decoder`, `policy`,
# and `dynamics` are modules that return torch.distributions objects.

# Latent rollout
q_t = posterior(x[: t + 1], a[:t])                 # q_phi(z_t | x_{<=t}, a_{<t})
z_t = q_t.rsample()                                # reparameterized sample of z_t
p_t = dynamics(z_prev, a[t - 1])                   # prior p(z_t | z_{t-1}, a_{t-1}); use p(z_1) at t = 1

# ELBO components
log_px = decoder(z_t).log_prob(x[t])               # log p(x_t | z_t)
log_pi = policy(z_t).log_prob(a[t])                # log pi_theta(a_t | z_t)
kl = torch.distributions.kl_divergence(q_t, p_t)   # KL(q_phi(z_t | .) || p(z_t | z_{t-1}, a_{t-1}))

# Per-step ELBO (maximize); sum over t for the sequence-level objective
elbo = log_px + log_pi - kl
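
Putting the pieces together, here is a hedged sketch of a full sequence-level training step under the same assumptions: `posterior`, `decoder`, `policy`, `dynamics`, and `prior` are hypothetical modules that return `torch.distributions` objects. It is an illustration of the objective above, not Dreamer’s or PlaNet’s actual implementation.

```python
import torch

def sequence_elbo_loss(posterior, decoder, policy, dynamics, prior, x, a):
    """Negative sequence-level ELBO for a batch of trajectories.

    x: [B, T, obs_dim], a: [B, T, act_dim]. Each module returns a
    torch.distributions object; `prior(batch_size)` gives p(z_1).
    """
    B, T = x.shape[0], x.shape[1]
    elbo = x.new_zeros(B)
    z_prev, a_prev = None, None
    for t in range(T):
        q_t = posterior(x[:, : t + 1], a[:, :t])                  # q_phi(z_t | x_{<=t}, a_{<t})
        z_t = q_t.rsample()                                       # reparameterized sample
        p_t = prior(B) if t == 0 else dynamics(z_prev, a_prev)    # p(z_1) or p(z_t | z_{t-1}, a_{t-1})

        log_px = decoder(z_t).log_prob(x[:, t]).sum(-1)           # log p(x_t | z_t)
        log_pi = policy(z_t).log_prob(a[:, t]).sum(-1)            # log pi_theta(a_t | z_t)
        kl = torch.distributions.kl_divergence(q_t, p_t).sum(-1)  # KL(q_phi || p_theta)

        elbo = elbo + log_px + log_pi - kl
        z_prev, a_prev = z_t, a[:, t]
    return -elbo.mean()  # minimize the negative ELBO

# Usage: loss = sequence_elbo_loss(posterior, decoder, policy, dynamics, prior, x, a)
#        loss.backward(); optimizer.step()
```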

How This Enables "Imagination"

Once trained, the agent no longer needs to interact with the real environment at every step. It can simulate future rollouts inside its learned latent model:

  1. Sample a latent state:

    • $z_1 \sim q_\phi(z_1 \mid x_1)$
  2. Roll forward using the dynamics model:

    • $z_{t+1} \sim p(z_{t+1} \mid z_t, a_t)$
  3. Select actions from the latent policy:

    • $a_t \sim \pi_\theta(a_t \mid z_t)$
  4. Optionally decode the imagined observations:

    • $x_t \sim p(x_t \mid z_t)$

The agent is now dreaming futures and acting on them, enabling fast planning and efficient behavior without expensive environment interaction. A minimal sketch of such an imagined rollout follows.
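
The sketch below mirrors the four steps above, again assuming hypothetical distribution-valued callables (`encode`, `policy`, `dynamics`, `decoder`); it is an illustration, not any specific library’s API.

```python
import torch

def imagine(encode, policy, dynamics, decoder, x1, horizon, decode_obs=False):
    """Imagined rollout in latent space: encode once, then never touch the environment.

    `encode(x1)` returns q_phi(z_1 | x_1); `policy(z)`, `dynamics(z, a)`, and
    `decoder(z)` return torch.distributions objects. Wrap in torch.no_grad() for
    pure simulation, or keep gradients to train the policy through imagination.
    """
    z = encode(x1).sample()                      # 1. z_1 ~ q_phi(z_1 | x_1)
    latents, actions, imagined_obs = [z], [], []
    for _ in range(horizon):
        a = policy(z).sample()                   # 3. a_t ~ pi_theta(a_t | z_t)
        z = dynamics(z, a).sample()              # 2. z_{t+1} ~ p(z_{t+1} | z_t, a_t)
        actions.append(a); latents.append(z)
        if decode_obs:
            imagined_obs.append(decoder(z).sample())  # 4. optionally decode x_t
    return latents, actions, imagined_obs
```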


Summary

We’ve derived and interpreted the sequence-level Evidence Lower Bound (ELBO) that forms the foundation of latent world model approaches like Dreamer, PlaNet, and others.


Takeaways