Diffusion models are trained to reverse a stochastic process through score matching. However, many diffusion models rely on a small but critical implementation detail called thresholding. Thresholding projects the sampling process back onto the data support after each discretized diffusion step, stabilizing generation at the cost of breaking the theoretical framework. Interestingly, as the number of discretization steps goes to infinity, thresholded sampling converges to a reflected stochastic differential equation. In this blog post, we discuss our recent work on Reflected Diffusion Models, which explores this connection to develop a new class of diffusion models that correctly trains for thresholded sampling and respects general boundary constraints.
Diffusion models
At their core, diffusion models start by perturbing data over a time interval $t \in [0, T]$ with a forward stochastic differential equation (SDE)
$$\mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t + g(t)\,\mathrm{d}B_t,$$
taking our initial data distribution $p_0 = p_{\mathrm{data}}$ to a simple terminal distribution $p_T$ (approximately Gaussian noise). This defines a stochastic transport from our simple distribution $p_T$ back to the data distribution $p_0$, realized by the reverse-time generative SDE
$$\mathrm{d}x_t = \big[f(x_t, t) - g(t)^2 \nabla_{x_t} \log p_t(x_t)\big]\,\mathrm{d}t + g(t)\,\mathrm{d}\bar{B}_t.$$
The only unknown component is the score function $\nabla_x \log p_t$, which we approximate with a neural network $s_\theta(x, t)$ trained to minimize the score matching objective
$$\mathbb{E}_{t}\,\mathbb{E}_{x_t \sim p_t}\big[\lambda_t\,\|s_\theta(x_t, t) - \nabla_{x_t} \log p_t(x_t)\|^2\big],$$
which has several equivalent (but tractable) alternative forms (see Yang's excellent blog post for an overview). Here, $\lambda_t$ is a positive weighting function.
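To make this concrete, here is a minimal sketch of the denoising form of this objective, assuming a simple variance-exploding perturbation kernel $x_t = x_0 + \sigma(t) z$; the names `score_net` and `sigma` are illustrative placeholders, not from any particular codebase:

```python
import torch

def dsm_loss(score_net, x0, t, sigma):
    """Denoising score matching sketch: for x_t = x0 + sigma(t) * z with
    z ~ N(0, I), the conditional score grad log p(x_t | x0) equals
    -z / sigma(t), giving a tractable regression target."""
    z = torch.randn_like(x0)
    x_t = x0 + sigma(t) * z
    target = -z / sigma(t)
    weight = sigma(t) ** 2  # one common choice of positive weighting lambda_t
    return (weight * (score_net(x_t, t) - target) ** 2).mean()
```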
Unfortunately, the devil always lies in the details. To simulate the generative SDE, we need to discretize time, resulting in an Euler-Maruyama scheme with i.i.d. Brownian increments $\Delta B_t \sim \mathcal{N}(0, \Delta t)$:
$$x_{t - \Delta t} = x_t - \big[f(x_t, t) - g(t)^2 s_\theta(x_t, t)\big]\,\Delta t + g(t)\,\Delta B_t.$$
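Concretely, a sampling loop might look like the following sketch, where `f`, `g`, and the trained `score_net` are assumed to be given (the names are ours, not from a specific implementation):

```python
import torch

@torch.no_grad()
def euler_maruyama_sample(score_net, f, g, shape, n_steps=1000, T=1.0):
    """Simulate the reverse SDE from t = T down to t = 0 with
    Euler-Maruyama steps and i.i.d. Gaussian increments."""
    dt = T / n_steps
    x = torch.randn(shape)  # start from the approximate prior p_T
    for i in range(n_steps):
        t = T - i * dt
        z = torch.randn_like(x)
        drift = f(x, t) - g(t) ** 2 * score_net(x, t)
        x = x - drift * dt + g(t) * (dt ** 0.5) * z
    return x
```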
Numerical error can arise from the discretization, the learned score function, or just plain bad luck when sampling the increments, causing our trajectory to drift into low-probability regions. Since the score matching loss is weighted by the density $p_t$, it is even possible that low-probability regions are never optimized, so the score network's behavior there tends to be undefined. Previously, this problem appeared for vanilla score matching, degrading the performance of naive Langevin dynamics sampling.
Diffusion models were initially motivated as a stack of VAEs that gradually denoise the input, inspiring a natural fix. We seek to generate images, so the predicted mean of each denoising step is thresholded to lie in the valid pixel range $[-1, 1]$. This detail appears across popular diffusion codebases:
```python
# thresholding in TensorFlow implementations:
x_recon = tf.clip_by_value(x_recon, -1., 1.)

# and the equivalent in PyTorch implementations:
return x.clamp(-1, 1)
```
Thresholding avoids the divergent behavior we saw previously, generating much nicer samples.
Unfortunately, diffusion models are supposed to reverse the forward corrupting process, and changing the generative process like this breaks that fundamental assumption, resulting in a mismatch between training and sampling. This mismatch has been linked with phenomena like oversaturation when using large amounts of diffusion guidance (such as in the Imagen-generated example), necessitating task-specific techniques that don't generalize, such as dynamic thresholding.
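For reference, dynamic thresholding replaces the fixed clip with a per-sample, percentile-based one. The following is only a sketch of the idea (the `percentile` default and function name are illustrative):

```python
import torch

def dynamic_threshold(x0_pred, percentile=0.995):
    """Clip the predicted clean image to [-s, s], where s is a high
    percentile of its absolute pixel values, then rescale by s so the
    result lies in [-1, 1] without piling up at the boundary."""
    flat = x0_pred.abs().flatten(start_dim=1)
    s = torch.quantile(flat, percentile, dim=1)
    s = s.clamp(min=1.0).view(-1, *([1] * (x0_pred.dim() - 1)))
    return torch.maximum(torch.minimum(x0_pred, s), -s) / s
```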
As we take the step size $\Delta t \to 0$, this thresholded sampling scheme converges to what's known as a reflected stochastic differential equation:
$$\mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t + g(t)\,\mathrm{d}B_t + \mathrm{d}L_t.$$
The behavior is exactly the same on the interior of our domain, but, at the boundary, the new term $L_t$ acts as a localized force that pushes the trajectory back inside, ensuring it never escapes.
Similar to standard stochastic differential equations, it is possible to reverse this with a reverse reflected stochastic differential equation:
$$\mathrm{d}x_t = \big[f(x_t, t) - g(t)^2 \nabla_{x_t} \log p_t(x_t)\big]\,\mathrm{d}t + g(t)\,\mathrm{d}\bar{B}_t + \mathrm{d}\bar{L}_t.$$
This forms the same forward/reverse coupling that undergirds standard diffusion models, so we can use this principle to define Reflected Diffusion Models.
To construct our Reflected Diffusion Models, we need to learn the score function $\nabla_x \log p_t$ of the reflected forward process, which we again do by minimizing a (denoising) score matching objective.
One still needs to accurately compute the transition density $p(x_t \mid x_0)$ of the reflected process to generate training targets. On the hypercube, this density factorizes across coordinates, and each one-dimensional factor can be computed with two strategies: (1) the method of images, which sums the Gaussian kernel over all reflected copies of the source point, and (2) an eigenfunction expansion, which writes the density as a cosine (harmonic) series.
Strategy 1 is accurate for small times, since we don't need to compute as many reflections, while strategy 2 is accurate for large times, since the distribution approaches uniform and requires fewer harmonic components. Each strategy shores up the other's weaknesses, so we combine them to efficiently compute the transition density.
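As a concrete one-dimensional sketch for standard Brownian motion reflected on $[0, 1]$ (the function names are ours, and a real implementation would vectorize and handle drift and scaling):

```python
import numpy as np

def kernel_images(x, y, t, n_images=10):
    """Strategy 1 (method of images): sum the Gaussian heat kernel over
    the reflected copies {2k + y, 2k - y} of the source point y.
    Few terms are needed when t is small."""
    density = 0.0
    for k in range(-n_images, n_images + 1):
        density += np.exp(-(x - (2 * k + y)) ** 2 / (2 * t))
        density += np.exp(-(x - (2 * k - y)) ** 2 / (2 * t))
    return density / np.sqrt(2 * np.pi * t)

def kernel_eigen(x, y, t, n_terms=10):
    """Strategy 2 (eigenfunction expansion): a cosine series satisfying
    the Neumann boundary condition. Few terms are needed when t is
    large, since the density is then close to uniform."""
    n = np.arange(1, n_terms + 1)
    terms = np.exp(-((n * np.pi) ** 2) * t / 2) * np.cos(n * np.pi * x) * np.cos(n * np.pi * y)
    return 1.0 + 2.0 * terms.sum()

# Both expansions agree (each prints roughly 0.800 here):
print(kernel_images(0.3, 0.7, 0.25), kernel_eigen(0.3, 0.7, 0.25))
```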
We have already seen that thresholding provides an Euler-Maruyama-type discretization. The core idea is that it approximates the reflection term: taking a standard Euler-Maruyama step and then projecting back onto the domain recovers the reflected dynamics as $\Delta t \to 0$.
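In code, this is a one-line change to the Euler-Maruyama loop sketched earlier (again with assumed `f`, `g`, and `score_net`):

```python
import torch

@torch.no_grad()
def reflected_em_step(x, t, dt, score_net, f, g):
    """One reverse-SDE Euler-Maruyama step followed by a projection back
    onto [-1, 1]; the projection is the discretized reflection term."""
    z = torch.randn_like(x)
    drift = f(x, t) - g(t) ** 2 * score_net(x, t)
    x = x - drift * dt + g(t) * (dt ** 0.5) * z
    return x.clamp(-1, 1)
```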
This produces reasonable samples, but we can actually further augment the sampling procedure. Since our network $s_\theta$ approximates the score $\nabla_x \log p_t$ at every noise level, we can additionally correct intermediate samples with reflected Langevin dynamics, yielding a predictor-corrector scheme.
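A corrector step might look like the following sketch of projected Langevin dynamics (the step size schedule and projection-based reflection are illustrative simplifications):

```python
import torch

@torch.no_grad()
def reflected_langevin_correct(x, t, score_net, step_size, n_steps=1):
    """Refine a sample at noise level t with a few Langevin steps,
    projecting back onto the domain after each one."""
    for _ in range(n_steps):
        z = torch.randn_like(x)
        x = x + step_size * score_net(x, t) + (2 * step_size) ** 0.5 * z
        x = x.clamp(-1, 1)  # keep the iterate inside the domain
    return x
```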
With this component, we match all of the constructs from standard diffusion models. We can achieve state-of-the-art perceptual quality (as measured by Inception score) without modifying the architecture or any other components. Unfortunately, the FID score, another common metric, tends to lag behind because our generated samples have noise (at the scale of 1-2 pixels) that FID is notoriously sensitive to.
| Method | Inception score (↑) |
|---|---|
| NCSN++ | 9.89 |
| Subspace Diffusion | 9.99 |
| Ours | 10.42 |
Analogous to the widely used DDIM scheme, we can also construct a family of alternative generative processes with tunable stochasticity. This results in a reflected diffusion process that has the same marginal probabilities as our original SDE and allows us to sample with lower variance. Amazingly, as we take the stochasticity to zero, the reflection term vanishes entirely, leaving a deterministic probability flow ODE.
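A minimal Euler discretization of this ODE could look as follows (a sketch: the uniform prior matches the reflected setup on $[-1, 1]^d$, and a practical implementation would use an adaptive ODE solver):

```python
import torch

@torch.no_grad()
def probability_flow_sample(score_net, f, g, shape, n_steps=100, T=1.0):
    """Deterministic sampling via dx/dt = f(x, t) - 0.5 g(t)^2 score(x, t).
    No noise and no reflection term are needed: the Neumann boundary
    condition means no probability flows across the boundary."""
    dt = T / n_steps
    x = 2 * torch.rand(shape) - 1  # prior p_T is approximately uniform on [-1, 1]^d
    for i in range(n_steps):
        t = T - i * dt
        x = x - (f(x, t) - 0.5 * g(t) ** 2 * score_net(x, t)) * dt
    return x
```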
We can replace the true score $\nabla_x \log p_t$ in this ODE with our learned network $s_\theta$ and compute exact log-likelihoods through the instantaneous change-of-variables formula, just as for standard diffusion models. Interestingly, learning with the likelihood weighting $\lambda_t = g(t)^2$ again corresponds to (approximate) maximum likelihood training, mirroring the ScoreFlow analysis, and the resulting models achieve competitive likelihoods in bits per dimension (BPD):
| Method | CIFAR-10 BPD (↓) | ImageNet-32 BPD (↓) |
|---|---|---|
| ScoreFlow | 2.86 | 3.83 |
| ScoreFlow (with importance sampling) | 2.83 | 3.76 |
| VDM | 2.70 | — |
| VDM (with learned noise) | 2.65 | 3.72 |
| Ours | 2.68 | 3.74 |
One of the major perks of diffusion models is their controllability. Using some conditional information $y$ (such as a class label or text prompt), we can guide generation by sampling with the conditional score $\nabla_x \log p_t(x \mid y)$.
Currently, this notion of controllable diffusion normally appears as classifier-free diffusion guidance, which combines the conditional and unconditional scores of a jointly trained network into the guided score
$$(1 + w)\,\nabla_x \log p_t(x \mid y) - w\,\nabla_x \log p_t(x),$$
where $w$ is the guidance weight.
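In score-network form, the combination is a two-line function; this sketch uses a `None` condition for the unconditional branch, whereas real codebases typically use a learned null token or label dropout:

```python
def guided_score(score_net, x, t, y, w):
    """Classifier-free guidance: extrapolate from the unconditional score
    toward the conditional one. w = 0 recovers the conditional model;
    larger w strengthens adherence to the condition y."""
    s_cond = score_net(x, t, y)
    s_uncond = score_net(x, t, None)  # unconditional branch
    return (1 + w) * s_cond - w * s_uncond
```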
In the literature, increasing the guidance weight $w$ is known to improve sample fidelity at the cost of diversity, but it also amplifies the drift, pushing samples out of bounds faster and making thresholding even more critical.
Additionally, we can't use quick deterministic ODE sampling methods, since deterministic samplers can't mimic the effect of thresholding. In fact, ODE sampling without thresholding seems to cause samples to diverge even more:
Furthermore, although a large guidance weight $w$ sharpens sample quality, it also tends to produce the oversaturated images seen above. We hypothesize that these artifacts are due to the mismatch between training and sampling. In particular, the trained behavior seeks to push samples out-of-bounds, and the sampling procedure then clips these out-of-bounds pixels to the boundary of $[-1, 1]$, saturating them at the extreme values. Because Reflected Diffusion Models train with the boundary constraint built in, they avoid this mismatch even for large guidance weights.
Lastly, when we combine score networks for classifier-free guidance, the Neumann boundary condition is maintained. As such, it is possible to sample from classifier-free guided diffusion models using our Probability Flow ODE, requiring far fewer evaluations.
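Composing the earlier sketches, guided deterministic sampling would then be (assuming the hypothetical `score_net`, `f`, `g`, condition `y`, and the functions defined above):

```python
# Guidance preserves the Neumann boundary condition, so the guided score
# can be plugged directly into the probability flow ODE sampler.
guided = lambda x, t: guided_score(score_net, x, t, y, w=3.0)
samples = probability_flow_sample(guided, f, g, shape=(16, 3, 32, 32), n_steps=50)
```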
We have constructed our framework to be completely general with respect to the underlying domain $\Omega$.
For instance, applying Reflected Diffusion Models to simplices results in a simplex diffusion method that scales to high dimensions. We did not explore this further, but in principle, it opens up potential applications in fields such as language modeling.
This blog post presented a detailed overview of our recent work, “Reflected Diffusion Models”. Our full paper can be found on arXiv and contains many more mathematical details, additional results, and deeper explanations. We have also released code.