import Comments from '../../components/Comments'
## Introduction
A central challenge in machine learning, particularly in generative modeling, is to model complex datasets using highly flexible families of probability distributions that maintain analytical or computational tractability for learning, sampling, inference, and evaluation.
This post summarizes the fundamental concepts of diffusion models, their optimization strategies, and applications, focusing on the mathematical foundations and practical implications.
## Denoising Diffusion Probabilistic Models (DDPM)
Diffusion models are a class of generative models first proposed by Sohl-Dickstein et al. (Sohl-Dickstein et al., 2015). Inspired by nonequilibrium thermodynamics, the method systematically and gradually destroys data structure through a forward diffusion process, then learns a reverse process to restore structure and yield a highly flexible and tractable generative model.
### Forward Process
A diffusion model formulates the learned data distribution as $p_\theta(\mathbf{x}_0) = \int p_\theta(\mathbf{x}_{0:T})\, d\mathbf{x}_{1:T}$, where $\mathbf{x}_1, \dots, \mathbf{x}_T$ are latent variables with the same dimensionality as the real data $\mathbf{x}_0 \sim q(\mathbf{x}_0)$.
The forward process is a Markov chain where transitions from $\mathbf{x}_{t-1}$ to $\mathbf{x}_t$ follow multivariate Gaussian distributions. The joint distribution of the latent variables ($\mathbf{x}_1, \dots, \mathbf{x}_T$) given the real data $\mathbf{x}_0$ is:

$$
q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) = \prod_{t=1}^{T} q(\mathbf{x}_t \mid \mathbf{x}_{t-1}), \qquad q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\big(\mathbf{x}_t;\, \sqrt{1 - \beta_t}\, \mathbf{x}_{t-1},\, \beta_t \mathbf{I}\big)
\label{eq:diff_model_origin}
$$
Key properties:
- The coefficients $\{\beta_t\}_{t=1}^{T}$ are pre-defined and determine the "velocity" of diffusion.
- After sufficient diffusion steps, the final state $\mathbf{x}_T$ approaches an isotropic Gaussian distribution when $T$ is large enough.
- For clarity, we use $p_\theta$ to represent the learned distribution and $q$ to represent the real data distribution.
A notable property of this forward process, as mentioned in Ho et al. (Ho et al., 2020) (Section 2), is that $q(\mathbf{x}_t \mid \mathbf{x}_0)$ has a closed-form expression derived using the reparameterization trick. Let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$:

$$
q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_t;\, \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0,\, (1 - \bar{\alpha}_t)\mathbf{I}\big), \qquad \text{i.e.} \quad \mathbf{x}_t = \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
\label{eq:xt_x0_relation}
$$
This relies on the fact that the sum of two uncorrelated multivariate normal distributions $\mathcal{N}(\mathbf{0}, \sigma_1^2 \mathbf{I})$ and $\mathcal{N}(\mathbf{0}, \sigma_2^2 \mathbf{I})$ is also a multivariate normal distribution, $\mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I})$.
Typically, we can afford larger update steps when the sample becomes noisier, so $\beta_1 < \beta_2 < \cdots < \beta_T$ and therefore $\bar{\alpha}_1 > \bar{\alpha}_2 > \cdots > \bar{\alpha}_T$.
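To make the notation concrete, here is a minimal PyTorch sketch of the forward process, assuming the linear schedule from Ho et al. ($\beta_1 = 10^{-4}$, $\beta_T = 0.02$, $T = 1000$); the schedule values are a design choice, not fixed by the theory, and array index `t` corresponds to timestep $t+1$:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # beta_t, increasing in t
alphas = 1.0 - betas                         # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)    # alpha_bar_t = prod_{s<=t} alpha_s

def q_sample(x0: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) in closed form via the reparameterization trick:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I)."""
    ab = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))  # broadcast over batch
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * eps
```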
### Reverse Process
Ideally, if we knew $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$, we could gradually remove noise from corrupted samples to recover the original image. However, this conditional distribution is not readily available and its computation requires the entire dataset. Specifically:

$$
q(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})\, q(\mathbf{x}_{t-1})}{q(\mathbf{x}_t)}, \qquad q(\mathbf{x}_t) = \int q(\mathbf{x}_t \mid \mathbf{x}_0)\, q(\mathbf{x}_0)\, d\mathbf{x}_0
\label{eq:imprac_cond_prob_expr}
$$
Computing $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ requires evaluating the integrals in Eq. (\ref{eq:imprac_cond_prob_expr}) over the whole data distribution, which is computationally intractable. Instead, we use the diffusion model to learn $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ and approximate the true conditional distribution. When $\beta_t$ is sufficiently small, $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ is also Gaussian.
The joint distribution of the diffusion model is:

$$
p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod_{t=1}^{T} p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t), \qquad p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\big(\mathbf{x}_{t-1};\, \boldsymbol{\mu}_\theta(\mathbf{x}_t, t),\, \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)\big)
$$

where $p(\mathbf{x}_T) = \mathcal{N}(\mathbf{x}_T; \mathbf{0}, \mathbf{I})$.
With this distribution, we can sample $\mathbf{x}_T$ from an isotropic Gaussian distribution and expect the reverse process to gradually transform it into samples that follow $q(\mathbf{x}_0)$.
### Loss Function
The loss function for training the diffusion model is the standard variational bound on the negative log likelihood:

$$
-\log p_\theta(\mathbf{x}_0) \le -\log p_\theta(\mathbf{x}_0) + D_{KL}\big(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\big)
$$

Therefore, the expected negative log likelihood is upper bounded by the variational bound:

$$
\mathbb{E}_q\big[-\log p_\theta(\mathbf{x}_0)\big] \le \mathbb{E}_q\left[-\log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\right] =: L
$$
To make each term analytically computable, the objective can be further rewritten as a combination of several KL-divergence and entropy terms (see the detailed step-by-step derivation in Appendix B of Sohl-Dickstein et al. (Sohl-Dickstein et al., 2015)):

$$
L = \mathbb{E}_q\left[ D_{KL}\big(q(\mathbf{x}_T \mid \mathbf{x}_0) \,\|\, p(\mathbf{x}_T)\big) + \sum_{t=2}^{T} D_{KL}\big(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\big) - \log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) \right]
$$
In summary, we can split the variational lower bound into components and label them as follows:

$$
\begin{aligned}
L &= L_T + \sum_{t=2}^{T} L_{t-1} + L_0 \\
L_T &= D_{KL}\big(q(\mathbf{x}_T \mid \mathbf{x}_0) \,\|\, p(\mathbf{x}_T)\big) \\
L_{t-1} &= D_{KL}\big(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\big) \\
L_0 &= -\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)
\end{aligned}
$$
In the above decomposition:
- $L_T$ can be computed from Eq. (\ref{eq:xt_x0_relation})
- $L_0, L_1, \dots, L_{T-1}$ are parameterized and learned
Next, we show that $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ can be computed in closed form even though $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ can't. By Bayes' rule and the Markov property:

$$
q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0)\, \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)} \propto \exp\left( -\frac{1}{2}\left( \frac{\|\mathbf{x}_t - \sqrt{\alpha_t}\,\mathbf{x}_{t-1}\|^2}{\beta_t} + \frac{\|\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0\|^2}{1 - \bar{\alpha}_{t-1}} - \frac{\|\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0\|^2}{1 - \bar{\alpha}_t} \right) \right)
$$
From the above derivation, we observe that the conditional distribution $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ is also Gaussian and can be written in standard multivariate normal form as $\mathcal{N}\big(\mathbf{x}_{t-1};\, \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0),\, \tilde{\beta}_t \mathbf{I}\big)$. Completing the square in the exponent above, $\tilde{\boldsymbol{\mu}}_t$ and $\tilde{\beta}_t$ can be computed as follows:

$$
\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, \mathbf{x}_0, \qquad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t
\label{eq:standard_form_mu_t}
$$
Recalling the relation between $\mathbf{x}_t$ and $\mathbf{x}_0$ deduced from Eq. (\ref{eq:xt_x0_relation}), namely $\mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\big(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}_t\big)$, Eq. (\ref{eq:standard_form_mu_t}) can be further rewritten as follows:

$$
\tilde{\boldsymbol{\mu}}_t = \frac{1}{\sqrt{\alpha_t}}\left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \boldsymbol{\epsilon}_t \right)
$$
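As a sanity check on these formulas, here is a small sketch (continuing the schedule variables defined earlier) that computes the posterior parameters $\tilde{\boldsymbol{\mu}}_t$ and $\tilde{\beta}_t$; the function name and the integer-timestep convention are ours:

```python
def q_posterior_params(x0: torch.Tensor, xt: torch.Tensor, t: int):
    """tilde_mu_t and tilde_beta_t of q(x_{t-1} | x_t, x_0) for an integer
    timestep t in {1, ..., T}; arrays are 0-indexed, so beta_t sits at t - 1."""
    beta_t = betas[t - 1]
    ab_t = alpha_bars[t - 1]
    ab_prev = alpha_bars[t - 2] if t > 1 else torch.tensor(1.0)  # alpha_bar_0 := 1
    tilde_beta = (1.0 - ab_prev) / (1.0 - ab_t) * beta_t
    tilde_mu = ((1.0 - beta_t).sqrt() * (1.0 - ab_prev) * xt     # sqrt(alpha_t) x_t term
                + ab_prev.sqrt() * beta_t * x0) / (1.0 - ab_t)   # sqrt(ab_{t-1}) x_0 term
    return tilde_mu, tilde_beta
```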
### Parameterization of the reverse diffusion process $\boldsymbol{\mu}_\theta$ and $\boldsymbol{\Sigma}_\theta$
Recall our previous decomposition of the variational bound: we have closed-form expressions for the real-data conditionals (i.e. $\tilde{\boldsymbol{\mu}}_t$ and $\tilde{\beta}_t$), but we still need a parameterization of $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$. As discussed previously, when $\beta_t$ is small enough, we can approximate $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ by the Gaussian distribution $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\big(\mathbf{x}_{t-1};\, \boldsymbol{\mu}_\theta(\mathbf{x}_t, t),\, \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)\big)$. We expect the training process to drive $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)$ to predict $\tilde{\boldsymbol{\mu}}_t$. With this parameterization, each component $L_{t-1}$ of the loss function is the KL divergence between two multivariate Gaussian distributions and has a relatively simple closed form. The loss term becomes:

$$
L_{t-1} = \mathbb{E}_q\Big[ D_{KL}\big( \mathcal{N}(\tilde{\boldsymbol{\mu}}_t, \tilde{\beta}_t \mathbf{I}) \,\|\, \mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)) \big) \Big]
\label{eq:lt_before_simple_sigma}
$$
In practice, we can further simplify the loss function Eq. (\ref{eq:lt_before_simple_sigma}) by predefining the variance matrix as $\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) = \sigma_t^2 \mathbf{I}$; experimentally, $\sigma_t^2 = \beta_t$ or $\sigma_t^2 = \tilde{\beta}_t$ had similar results. Therefore, we can write:

$$
L_{t-1} = \mathbb{E}_q\left[ \frac{1}{2\sigma_t^2} \big\| \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}_\theta(\mathbf{x}_t, t) \big\|^2 \right] + C
\label{eq:lt_mu_t_not_param}
$$

where $C$ is a constant that does not depend on $\theta$.
Furthermore, $\boldsymbol{\mu}_\theta$ is parameterized in Section 3.2 of Ho et al. (Ho et al., 2020) to correspond with the form of $\tilde{\boldsymbol{\mu}}_t$ in Eq. (\ref{eq:standard_form_mu_t}) as follows:

$$
\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \right)
$$
In this parameterization, the neural network approximates the noise $\boldsymbol{\epsilon}$ instead of $\tilde{\boldsymbol{\mu}}_t$ directly. This further simplifies $L_{t-1}$ in Eq. (\ref{eq:lt_mu_t_not_param}) into:

$$
L_{t-1} = \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[ \frac{\beta_t^2}{2\sigma_t^2\, \alpha_t\, (1 - \bar{\alpha}_t)} \big\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon},\, t\big) \big\|^2 \right]
$$
At this point, every term in $L$ can be computed in explicit closed form and is ready for training. However, empirically, Ho et al. (Ho et al., 2020) found that training the diffusion model works better with a simplified objective that ignores the weighting term:

$$
L_{\text{simple}} = \mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}}\left[ \big\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon},\, t\big) \big\|^2 \right]
$$
This simplification enables us to train on arbitrary time steps $t$ for each sample $\mathbf{x}_0$, instead of computing the entire series $L_0, \dots, L_T$ as in $L$. The full DDPM procedure follows Algorithms 1 and 2 of Ho et al. (Ho et al., 2020): repeatedly sample $\mathbf{x}_0$, $t$, and $\boldsymbol{\epsilon}$ and take a gradient step on $L_{\text{simple}}$; then generate by starting from $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and iterating $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$.
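The following is a minimal PyTorch sketch of this training/sampling loop, continuing the schedule variables and `q_sample` defined earlier. Here `eps_model` is a stand-in for the noise-prediction network $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ (a U-Net in Ho et al.); the function names and optimizer choice are ours:

```python
def train_step(eps_model, x0, optimizer):
    """One gradient step on L_simple: draw t uniformly at random, corrupt x0
    via q_sample, and regress the injected noise."""
    t = torch.randint(0, T, (x0.shape[0],))
    eps = torch.randn_like(x0)
    loss = ((eps - eps_model(q_sample(x0, t, eps), t)) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def p_sample_loop(eps_model, shape):
    """Ancestral sampling with sigma_t^2 = beta_t: start from x_T ~ N(0, I)
    and apply the learned reverse kernel T times."""
    x = torch.randn(shape)
    for t in reversed(range(T)):  # array index t corresponds to timestep t + 1
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        eps_hat = eps_model(x, torch.full((shape[0],), t, dtype=torch.long))
        # x_{t-1} = (x_t - beta_t / sqrt(1 - ab_t) * eps_hat) / sqrt(alpha_t) + sigma_t z
        x = (x - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps_hat) \
            / alphas[t].sqrt() + betas[t].sqrt() * z
    return x
```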
Side note: the timestep $t$ is also an input to the neural network, and it is typically encoded into a vector. For instance, in DDPM, integer timesteps are encoded into float vectors through sinusoidal positional encodings.
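A minimal sketch of such an encoding is below, following the common transformer-style convention (geometrically spaced frequencies with base 10000); the exact constants vary between implementations:

```python
import math

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Map a batch of integer timesteps to dim-dimensional float vectors
    using sinusoids of geometrically spaced frequencies."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None, :]       # (batch, dim // 2)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
```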
## Denoising Diffusion Implicit Models (DDIM)
Though diffusion models like DDPM have demonstrated the ability to produce high-quality samples comparable to state-of-the-art generative models such as GANs, the computational cost of the sampling process is a critical drawback. These diffusion-based models typically require many iterations to produce a high-quality sample, whereas models like GANs need only one forward pass. A quantitative experiment in Song et al. (Song et al., 2020) shows that, with the same GPU setup and similar neural network complexity, it takes around 20 hours to sample 50k images of size 32×32 from a DDPM, but less than a minute to do so from a GAN. To resolve this high computational cost without losing too much generation quality, Song et al. (Song et al., 2020) proposed Denoising Diffusion Implicit Models (DDIM). The algorithm is based on two observations/intuitions:
- The derivation of the loss function depends only on the marginals $q(\mathbf{x}_t \mid \mathbf{x}_0)$, and the sampling process depends only on $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$. More specifically, the loss function keeps the same form as long as the relation in Eq. (\ref{eq:xt_x0_relation}) still holds.
- A DDPM trained on the full timestep sequence $\{1, \dots, T\}$ has, in fact, included the "knowledge" needed to train a DDPM on any subsequence of timesteps. This can be observed naturally from the training process with the simplified loss function. It gives us the intuition that we can use only a subset of the timesteps during the sampling process and thereby reduce the computational cost.
Based on the first observation, we can build different conditional distributions $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ that induce the same marginal distribution $q(\mathbf{x}_t \mid \mathbf{x}_0)$. The same marginal distribution results in the same loss function, while different choices of conditional distribution result in different sampling procedures. In fact, without the Markovian constraint on $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ as in Eq. (\ref{eq:diff_model_origin}), we have a broader choice (i.e. a larger solution space) of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$.
Based on the second observation, the sampling process can use only a subset of the steps used in the training process. By reducing the number of update steps, sampling can be greatly sped up.
### Non-Markovian Forward Processes
The key observation here is that the DDPM loss function depends only on the marginals $q(\mathbf{x}_t \mid \mathbf{x}_0)$, not directly on the joint distribution $q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$ or the transition distribution $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$. Following the derivation on Spaces.Ac.cn, we can use the method of undetermined coefficients to compute the form of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$, which we also assume to be a normal distribution. We first summarize the conditions as follows:
- To maintain the same loss function, we need the same marginal distribution $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0,\, (1 - \bar{\alpha}_t)\mathbf{I}\big)$. The corresponding sampling formula is $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}$, $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$.
- Assume $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\big(k_t \mathbf{x}_t + \lambda_t \mathbf{x}_0,\, \sigma_t^2 \mathbf{I}\big)$, where $k_t$ and $\lambda_t$ are coefficients to be determined. The corresponding sampling process is $\mathbf{x}_{t-1} = k_t \mathbf{x}_t + \lambda_t \mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}$.
By combining the marginal distribution of $\mathbf{x}_t$ with the assumed form of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$, we can compute the marginal distribution of $\mathbf{x}_{t-1}$ as follows:

$$
\begin{aligned}
\mathbf{x}_{t-1} &= k_t \mathbf{x}_t + \lambda_t \mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}_1 \\
&= k_t \big( \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}_2 \big) + \lambda_t \mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}_1 \\
&= \big( k_t \sqrt{\bar{\alpha}_t} + \lambda_t \big)\, \mathbf{x}_0 + \sqrt{k_t^2 (1 - \bar{\alpha}_t) + \sigma_t^2}\; \boldsymbol{\epsilon}
\end{aligned}
\label{eq:sample_with_undeter_coefs}
$$

where $\boldsymbol{\epsilon}_1, \boldsymbol{\epsilon}_2, \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and the last step merges the two independent Gaussians.
Comparing Eq. (\ref{eq:sample_with_undeter_coefs}) with Eq. (\ref{eq:xt_x0_relation}) and requiring the marginal distributions to match, i.e. $\mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\, \boldsymbol{\epsilon}$, we have the following relations:

$$
k_t \sqrt{\bar{\alpha}_t} + \lambda_t = \sqrt{\bar{\alpha}_{t-1}}, \qquad k_t^2 (1 - \bar{\alpha}_t) + \sigma_t^2 = 1 - \bar{\alpha}_{t-1}
$$
There are three unknowns ($k_t$, $\lambda_t$, $\sigma_t$) but only two equations; therefore, we can treat $\sigma_t$ as a free variable and solve for the other two:

$$
k_t = \sqrt{\frac{1 - \bar{\alpha}_{t-1} - \sigma_t^2}{1 - \bar{\alpha}_t}}, \qquad \lambda_t = \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_t}\, \sqrt{\frac{1 - \bar{\alpha}_{t-1} - \sigma_t^2}{1 - \bar{\alpha}_t}}
$$
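For readers who want to double-check the algebra, here is a quick symbolic verification with SymPy; the symbol names (`ab_t` for $\bar{\alpha}_t$, `ab_prev` for $\bar{\alpha}_{t-1}$) are ours:

```python
import sympy as sp

ab_t, ab_prev, sigma = sp.symbols('ab_t ab_prev sigma', positive=True)
k, lam = sp.symbols('k lam')
sols = sp.solve(
    [sp.Eq(k * sp.sqrt(ab_t) + lam, sp.sqrt(ab_prev)),       # match the means
     sp.Eq(k**2 * (1 - ab_t) + sigma**2, 1 - ab_prev)],      # match the variances
    [k, lam], dict=True)
for s in sols:
    print(sp.simplify(s[k]), sp.simplify(s[lam]))
# Two roots, k = ±sqrt((1 - ab_prev - sigma**2) / (1 - ab_t)); the positive
# root reproduces the expressions for k_t and lambda_t above.
```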
Therefore, we obtain a family of inference distributions indexed by $\sigma_t$:

$$
q_\sigma(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\left( \sqrt{\bar{\alpha}_{t-1}}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\; \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}},\; \sigma_t^2 \mathbf{I} \right)
$$
As a result, in the sampling procedure, the update formula is:

$$
\mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\left( \frac{\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}} \right)}_{\text{predicted } \mathbf{x}_0} + \underbrace{\sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\; \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)}_{\text{direction pointing to } \mathbf{x}_t} + \sigma_t \boldsymbol{\epsilon}
$$
Compared with DDPM, this is a generalized form of the generative process. Since the marginal distribution remains the same, the loss function does not change and the training process is identical. This means we can use this new generative process with a diffusion model trained the DDPM way and, with different levels of $\sigma_t$, generate different images from the same initial noise. Among the different choices, $\sigma_t = 0$ is a special case in which the generation process is deterministic given the initial noise. This model is called a denoising diffusion implicit model because it is an implicit probabilistic model trained with the DDPM objective.
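A sketch of one generalized update step is below, continuing the earlier variables and reusing Song et al.'s parameterization of the free variable via an interpolation coefficient $\eta$ ($\eta = 0$ gives the deterministic DDIM update, $\eta = 1$ a DDPM-like stochastic step); the function name and index conventions are ours:

```python
@torch.no_grad()
def ddim_step(eps_model, x, t: int, t_prev: int, eta: float = 0.0):
    """One generalized reverse update from x_t to x_{t_prev} (1-based
    timesteps, with alpha_bar_0 := 1)."""
    ab_t = alpha_bars[t - 1]
    ab_prev = alpha_bars[t_prev - 1] if t_prev > 0 else torch.tensor(1.0)
    eps_hat = eps_model(x, torch.full((x.shape[0],), t - 1, dtype=torch.long))
    # sigma_t = eta * sqrt((1 - ab_prev) / (1 - ab_t)) * sqrt(1 - ab_t / ab_prev)
    sigma = eta * ((1 - ab_prev) / (1 - ab_t)).sqrt() * (1 - ab_t / ab_prev).sqrt()
    x0_pred = (x - (1 - ab_t).sqrt() * eps_hat) / ab_t.sqrt()   # predicted x_0
    return (ab_prev.sqrt() * x0_pred
            + (1 - ab_prev - sigma**2).sqrt() * eps_hat          # direction to x_t
            + sigma * torch.randn_like(x))                       # random noise
```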
### Accelerated Generation Processes
We should point out that, in our previous discussion, we never actually started from $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$; the sequence of marginals, i.e. $\{\bar{\alpha}_t\}$, determined the model. The key observation here is that the training process of DDPM, in its essence, contains the training over any subsequence $\tau = (\tau_1, \dots, \tau_S) \subseteq (1, \dots, T)$. This can be observed from the loss function: the training objective over the parameter subset $\{\bar{\alpha}_{\tau_i}\}$ is

$$
L_{\text{simple}}^{\tau} = \mathbb{E}_{i, \mathbf{x}_0, \boldsymbol{\epsilon}}\left[ \big\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar{\alpha}_{\tau_i}}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{\tau_i}}\, \boldsymbol{\epsilon},\, \tau_i\big) \big\|^2 \right]
$$
Therefore, a DDPM trained on $\{1, \dots, T\}$ has already incorporated the information needed to train a DDPM on any subsequence $\tau$. When the size of $\tau$ is much smaller than $T$, generating samples with the smaller parameter set is much faster.
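Putting the two pieces together, accelerated sampling just walks the reverse updates along a subsequence $\tau$, reusing `ddim_step` from above. The even spacing and step count here are illustrative choices (other spacings, e.g. quadratic, are also used in practice):

```python
@torch.no_grad()
def ddim_sample(eps_model, shape, n_steps: int = 50, eta: float = 0.0):
    """Accelerated sampling over an evenly spaced subsequence tau of {1, ..., T}."""
    taus = torch.linspace(1, T, n_steps).round().long().tolist()  # tau_1 < ... < tau_S
    x = torch.randn(shape)
    # Walk tau backwards: (tau_S -> tau_{S-1}), ..., (tau_1 -> 0).
    for t, t_prev in zip(reversed(taus), reversed([0] + taus[:-1])):
        x = ddim_step(eps_model, x, t, t_prev, eta)
    return x
```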
### Remarks on DDIM
- Why don't we just directly train on the subsequence $\tau$ and sample from that model?\ There might be two considerations for training on $T$ steps but sampling in fewer steps. Firstly, a diffusion model trained on the more fine-grained setup might generalize better. Secondly, using a subsequence to speed up sampling is only one acceleration method; there might be other acceleration methods enabled by this more sophisticated model.
- Can we just use DDPM and sample with a subset of the timesteps? What is the purpose of choosing this new family of conditional distributions?\ For the purpose of accelerating sample generation, one can certainly use DDPM and skip steps during generation. However, the newly proposed distribution family is clearly more flexible and has the potential to generate more diverse samples without any additional cost over DDPM. As a matter of fact, letting $\sigma_t^2 = \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t$, DDIM's generative process is equivalent to DDPM's sampling process.
- Are there additional benefits to DDIM?\ DDIM has a "consistency" property since its generative process is deterministic: multiple samples conditioned on the same latent variable should share similar high-level features. Due to this consistency, DDIM can do semantically meaningful interpolation in the latent variable.
## Result comparison between DDPM and DDIM
Experiments in the DDPM and DDIM papers have quantitatively and qualitatively examined the generated images. Here we review two aspects: sampling quality and interpolation results.
### Sampling quality
In DDIM's experiments, the authors compared sample quality across different numbers of diffusion steps and different noise levels $\eta$. The empirical result is that the lower the noise level $\eta$, the better the image quality generated with the accelerated diffusion process.
### Interpolation
Both the DDIM and DDPM papers examined performance on image interpolation. They use the forward process as a stochastic encoder to generate latent embeddings $\mathbf{x}_T$ and $\mathbf{x}_T'$, then decode the interpolated latent $\bar{\mathbf{x}}_T^{(\lambda)}$, where $\lambda$ is the interpolation parameter.
- In DDPM, the authors simply use linear interpolation, i.e. $\bar{\mathbf{x}}_T^{(\lambda)} = (1 - \lambda)\, \mathbf{x}_T + \lambda\, \mathbf{x}_T'$.
- In DDIM, the authors use spherical linear interpolation,\
  $\bar{\mathbf{x}}_T^{(\lambda)} = \dfrac{\sin\big((1 - \lambda)\theta\big)}{\sin\theta}\, \mathbf{x}_T + \dfrac{\sin(\lambda\theta)}{\sin\theta}\, \mathbf{x}_T'$\
  where $\theta = \arccos\left( \dfrac{\mathbf{x}_T^{\top} \mathbf{x}_T'}{\|\mathbf{x}_T\|\, \|\mathbf{x}_T'\|} \right)$.
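A small sketch of the spherical interpolation formula, in the same PyTorch setting as before; pairing it with the deterministic ($\eta = 0$) sampler above reproduces the DDIM-style interpolation experiment:

```python
def slerp(x1: torch.Tensor, x2: torch.Tensor, lam: float) -> torch.Tensor:
    """Spherical linear interpolation between two latents x_T and x_T'."""
    theta = torch.arccos(torch.dot(x1.flatten(), x2.flatten())
                         / (x1.norm() * x2.norm()))
    return (torch.sin((1 - lam) * theta) * x1
            + torch.sin(lam * theta) * x2) / torch.sin(theta)

# e.g. decode the midpoint of two noise latents deterministically:
# x_mid = slerp(xT_a, xT_b, 0.5), then run the eta = 0 reverse process on x_mid.
```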
## Interesting Reading
- Lilian Weng's post on diffusion models
- Spaces.Ac.cn's post on DDIM