Recently, a series of concurrent works, Flow Matching (Lipman et al., 2023), Rectified Flow (Liu et al., 2023), and Stochastic Interpolants (Albergo et al., 2023), proposed a new class of generative models based on Continuous Normalizing Flows (CNFs), which we will refer to as Flow Matching (FM) based models. In principle, FM is a more flexible alternative to the current state-of-the-art Diffusion Models (DMs), and can be viewed as a generalization of DMs in two important ways:

  • Transport between arbitrary distributions:
    DMs require the base probability distribution to be Gaussian. FM allows the base distribution to be arbitrary (e.g., an image distribution), which enables applications such as image-to-image translation.
  • Arbitrary probability paths:
    The path followed by a DM is just one of the infinitely many probability paths attainable with FM models. This is important because it allows for flexible, application-specific designs of probability paths, such as the optimal transport path.

Note: The blog assumes some familiarity with the problem of generative modeling, probability theory, basic differential equations, and diffusion models.

Normalizing Flow

FM can be seen as a subclass of general flow-based generative models. A flow model aims to transport a base distribution $\rho_0$ to a target distribution $\rho_1$, both defined over $\mathbb{R}^d$, using a transport map $\Psi: \mathbb{R}^d \to \mathbb{R}^d$, i.e., if $x_0 \sim \rho_0$, then $\Psi(x_0) \sim \rho_1$. A popular framework, called Normalizing Flow, learns such a transport map $\Psi$ with a maximum likelihood objective to model a data distribution $\rho_1$, fixing the base distribution $\rho_0$ to be a simple distribution, e.g., a Gaussian, that is amenable to easy sampling and density evaluation. The objective of normalizing flow follows from the change of variables formula for probability density functions:

$$\max_{\Psi}\ \mathbb{E}_{x \sim \rho_1} \log p_{\Psi}(x), \quad \text{where} \quad \log p_{\Psi}(x) := \log \rho_0\left(\Psi^{-1}(x)\right) + \log \left| \det J_{\Psi^{-1}}(x) \right|.$$

As one can see, the above objective involves computing the determinant of the Jacobian of the inverse transport map $\Psi^{-1}$. For general functions $\Psi$, this is a computationally prohibitive operation, especially for high-dimensional data. To avoid this expensive computation, $\Psi$ is parameterized as a sequence of simple invertible transformations whose Jacobian determinants are easy to compute. This restriction limits the expressiveness of normalizing flow models, and consequently their performance compared to other generative models such as GANs.
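
To make the change-of-variables objective concrete, here is a minimal sketch in Python, assuming an elementwise affine map $\Psi(z) = a \odot z + b$ (a toy flow whose inverse and Jacobian determinant are trivial to compute); all names here are illustrative, not from the cited papers.

```python
# A minimal sketch of the change-of-variables log-likelihood, assuming an
# elementwise affine map Psi(z) = a * z + b with a standard Gaussian base.
import numpy as np

d = 2
a = np.array([2.0, 0.5])   # per-dimension scale (nonzero, so Psi is invertible)
b = np.array([1.0, -1.0])  # per-dimension shift

def log_rho0(z):
    # log density of the standard Gaussian base distribution rho_0
    return -0.5 * np.sum(z**2) - 0.5 * d * np.log(2 * np.pi)

def log_likelihood(x):
    # log p_Psi(x) = log rho_0(Psi^{-1}(x)) + log |det J_{Psi^{-1}}(x)|
    z = (x - b) / a                       # Psi^{-1}(x)
    log_det = -np.sum(np.log(np.abs(a)))  # Jacobian of Psi^{-1} is diag(1/a)
    return log_rho0(z) + log_det

x = np.array([1.5, -0.5])
print(log_likelihood(x))
```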

Continuous Normalizing Flow (CNF)

CNF uses a continuous-time perspective of the aforementioned transport process. Consider a continuous, time-dependent map $\Psi_t$ for $t \in [0,1]$ such that $[\Psi_1]_{\#}\rho_0 = \rho_1$ and $[\Psi_0]_{\#}\rho_0 = \rho_0$, a time-dependent velocity field $v_t: \mathbb{R}^d \to \mathbb{R}^d$, and a corresponding time-dependent probability path $p_t: \mathbb{R}^d \to \mathbb{R}_{>0}$. The vector field $v_t$ is related to the transport map $\Psi_t$ via the following ordinary differential equation (ODE):

$$\frac{d}{dt}\Psi_t(x) = v_t(\Psi_t(x)),$$

and the time-dependent probability path $p_t$ is related to the end distributions via the pushforward

$$p_t = [\Psi_t]_{\#}\rho_0.$$
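
To make this concrete, below is a minimal sketch of pushing base samples through a CNF by Euler-integrating the ODE; the velocity field `v` is a hypothetical placeholder standing in for a trained network.

```python
# Transporting base samples through a CNF via Euler discretization of
# d/dt Psi_t(x) = v_t(Psi_t(x)) from t=0 to t=1.
import numpy as np

def v(t, x):
    # placeholder velocity field (in practice, a learned neural network)
    return np.tanh(x) * (1 - t)

def transport(x0, n_steps=100):
    x, dt = x0.copy(), 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * v(i * dt, x)  # one Euler step along the velocity field
    return x  # approximately distributed as rho_1 if v_t is accurate

x0 = np.random.randn(1000, 2)  # samples from rho_0
x1 = transport(x0)             # pushed-forward samples
```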

Figure 1: Illustration of the CNF idea.

Mass Continuity Equation

A velocity field $v_t$ generates the probability path $p_t$ if and only if the pair satisfies the mass continuity equation:

$$\frac{\partial p_t}{\partial t} + \nabla \cdot (p_t v_t) = 0.$$

This equation follows from Gauss's divergence theorem by enforcing conservation of probability mass. One can use it to verify whether a vector field $v_t$ generates a given probability path $p_t$, or even to find such a vector field (which is what we will do). A quick numerical check of this identity is sketched below.
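
The following sketch numerically verifies the continuity equation for a simple 1D case chosen here for illustration: the translating Gaussian path $p_t = \mathcal{N}(tm, 1)$, which is generated by the constant velocity field $v_t(x) = m$.

```python
# Numerical check of dp/dt + div(p * v) = 0 for a translating 1D Gaussian.
import numpy as np

m = 2.0  # translation speed
p = lambda t, x: np.exp(-0.5 * (x - t * m)**2) / np.sqrt(2 * np.pi)
v = lambda t, x: m * np.ones_like(x)

t, x, eps = 0.3, np.linspace(-4.0, 6.0, 1001), 1e-5
dp_dt = (p(t + eps, x) - p(t - eps, x)) / (2 * eps)  # time derivative of p_t
flux = lambda xx: p(t, xx) * v(t, xx)
div = (flux(x + eps) - flux(x - eps)) / (2 * eps)    # divergence of p_t * v_t
print(np.max(np.abs(dp_dt + div)))  # ~0, up to finite-difference error
```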

Flow Matching (FM)

Given a probability density path $p_t$ with $p_0 = \rho_0$, $p_1 = \rho_1$, and a corresponding vector field $u_t$, the flow matching objective minimizes the following loss:

$$\mathcal{L}_{FM}(v_t) = \mathbb{E}_{t \sim U([0,1]),\, x \sim p_t} \left\| v_t(x) - u_t(x) \right\|^2.$$

$\mathcal{L}_{FM}$ is simple but intractable in practice, as it requires access to two quantities we have no prior knowledge of: (i) samples from $p_t$ for all $t$, and (ii) the vector field $u_t$. In the following sections, we discuss how to solve problems (i) and (ii) to make the flow matching objective practical.

Stochastic Interpolant

In (Albergo et al., 2023), the authors define a time-differentiable interpolant function $I_t: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}^d$, $t \in [0,1]$, such that

$$I_{t=0}(x_0, x_1) = x_0, \quad I_{t=1}(x_0, x_1) = x_1.$$

A typical example of $I_t$ is the linear interpolant: $I_t(x_0, x_1) = (1-t)x_0 + t x_1$. Next, a joint distribution (coupling) $\rho(x_0, x_1)$ is chosen such that

$$\int \rho(x_0, x_1)\, dx_0 = \rho_1(x_1), \quad \text{and} \quad \int \rho(x_0, x_1)\, dx_1 = \rho_0(x_0).$$

The independent coupling $\rho(x_0, x_1) = \rho_0(x_0)\,\rho_1(x_1)$ is one particular choice that satisfies the above conditions. The final tractable flow matching objective takes the following form:

$$\mathcal{L} = \mathbb{E}_{t,\, (x_0,x_1) \sim \rho(x_0,x_1),\, x = I_t(x_0,x_1)} \left\| v_t(x) - \partial_t I_t(x_0, x_1) \right\|^2$$

For the linear interpolant $I_t(x_0,x_1) = (1-t)x_0 + t x_1$ and the independent coupling, the above objective becomes:

$$\mathcal{L} = \mathbb{E}_{t,\, x_0 \sim \rho_0,\, x_1 \sim \rho_1,\, x = (1-t)x_0 + t x_1} \left\| v_t(x) - (x_1 - x_0) \right\|^2$$

The above objective is very simple to implement: one just needs to randomly sample $x_0$ and $x_1$ from the two distributions and a time stamp $t$, then regress the vector field at $I_t(x_0, x_1)$ onto the vector pointing from $x_0$ to $x_1$ (see the training sketch below). At first glance, it almost seems too simple and naive: how can matching the vector field with random directions recover the desired vector field $u_t$? The short answer is that, in expectation, these random directions average out to the desired vector field $u_t$. To understand why this is the case, the following section will guide you through the detailed proof.
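
Here is a minimal training sketch of this objective, assuming PyTorch, the linear interpolant, and the independent coupling; `velocity_net` and `sample_rho1` are hypothetical stand-ins for the model and the data source.

```python
# Minimal flow matching training loop (linear interpolant, independent coupling).
import torch
import torch.nn as nn

d = 2
velocity_net = nn.Sequential(nn.Linear(d + 1, 64), nn.SiLU(), nn.Linear(64, d))
opt = torch.optim.Adam(velocity_net.parameters(), lr=1e-3)

def sample_rho1(n):
    # stand-in for drawing a batch from the target data distribution rho_1
    return torch.randn(n, d) + 3.0

for step in range(1000):
    x0 = torch.randn(256, d)        # x0 ~ rho_0 (Gaussian base)
    x1 = sample_rho1(256)           # x1 ~ rho_1 (data)
    t = torch.rand(256, 1)          # t ~ U([0, 1])
    xt = (1 - t) * x0 + t * x1      # x = I_t(x0, x1), a sample from p_t
    target = x1 - x0                # d/dt I_t(x0, x1) for the linear interpolant
    pred = velocity_net(torch.cat([xt, t], dim=1))
    loss = ((pred - target)**2).sum(dim=1).mean()
    opt.zero_grad(); loss.backward(); opt.step()
```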

Proof:

In the previous section, we noted that the flow matching objective $\mathcal{L}_{FM}$ is intractable for two reasons: (i) sampling from $p_t$ and (ii) estimating $u_t$. In this section, we show how addressing these two issues yields the tractable and simple flow matching objective $\mathcal{L}$.

(i) Sampling from $p_t$

Given samples $(x_0, x_1) \sim \rho(x_0, x_1)$, $x_t = I_t(x_0, x_1)$ defines a stochastic process (hence the name stochastic interpolant). The probability path $p_t$ induced by $x_t$ is a valid time-dependent probability path for constructing a CNF, because when $(x_0, x_1) \sim \rho(x_0, x_1)$,

$$I_0(x_0, x_1) = x_0 \sim p_0 = \rho_0, \quad I_1(x_0, x_1) = x_1 \sim p_1 = \rho_1.$$

Therefore, the procedure for sampling from $p_t$ is straightforward (see the sketch after the list):
  1. Sample $(x_0, x_1) \sim \rho(x_0, x_1)$.
  2. Compute $x_t = I_t(x_0, x_1)$; $x_t$ is our required sample.
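
The same two steps in code, as a tiny sketch (assuming the linear interpolant and illustrative Gaussian choices for $\rho_0$ and $\rho_1$ under the independent coupling):

```python
# Sampling from p_t in two steps; the distributions here are illustrative.
import numpy as np

x0 = np.random.randn(512, 2)        # step 1: x0 ~ rho_0 ...
x1 = np.random.randn(512, 2) + 3.0  #         ... and x1 ~ rho_1 (independent coupling)
t = 0.7
xt = (1 - t) * x0 + t * x1          # step 2: x_t = I_t(x0, x1) is a sample from p_t
```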

(ii) Estimating $u_t$

Obtaining $u_t$ is a bit long, but not difficult to follow. First, note that we can express $p_t$ using the Dirac delta function as follows:

$$p_t(x) = \int \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0, x_1)\, dx_0\, dx_1,$$

where $\delta_{I_t(x_0,x_1)}$ is the Dirac delta function centered at $I_t(x_0, x_1)$. We know that our desired velocity field $u_t$ corresponding to $p_t$ should satisfy the mass continuity equation:

$$\frac{\partial p_t}{\partial t} = -\nabla \cdot (p_t u_t)$$
$$\Rightarrow \frac{\partial}{\partial t} \int \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1 = -\nabla \cdot (p_t(x)\, u_t(x))$$
$$\Rightarrow \int \left( \frac{\partial}{\partial t} \delta_{I_t(x_0,x_1)}(x) \right) \rho(x_0,x_1)\, dx_0\, dx_1 = -\nabla \cdot (p_t(x)\, u_t(x)) \tag{1}$$

One important fact that will help us obtain $u_t$ is that $\delta_{I_t(x_0,x_1)}$ itself defines a time-dependent probability path between $\delta_{x_0}$ and $\delta_{x_1}$. Further, $I_t(x_0, x_1)$ is the corresponding flow that achieves this continuous transport, i.e.,

$$[I_t(x_0,x_1)]_{\#}\, \delta_{x_0} = \delta_{I_t(x_0,x_1)}.$$

Let us define the conditional vector field $u_t(x | x_0, x_1)$ as

$$u_t(x | x_0, x_1) = \begin{cases} \partial_t I_t(x_0, x_1) & \text{if } x = I_t(x_0, x_1) \\ 0 & \text{otherwise} \end{cases}$$

By the definition of the velocity field of a flow, $u_t(x | x_0, x_1)$ is the velocity field that induces the probability path $\delta_{I_t(x_0,x_1)}$. Hence, the pair $(\delta_{I_t(x_0,x_1)}, u_t(x | x_0, x_1))$ satisfies the continuity equation. Therefore,

$$\frac{\partial}{\partial t} \delta_{I_t(x_0,x_1)}(x) = -\nabla \cdot \left( \delta_{I_t(x_0,x_1)}(x)\, u_t(x | x_0, x_1) \right)$$

Using the above in Eq. (1), we get

$$-\int \nabla \cdot \left( \delta_{I_t(x_0,x_1)}(x)\, u_t(x | x_0, x_1) \right) \rho(x_0,x_1)\, dx_0\, dx_1 = -\nabla \cdot (p_t(x)\, u_t(x))$$
$$\Rightarrow \nabla \cdot \left( \int u_t(x | x_0, x_1)\, \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1 \right) = \nabla \cdot (p_t(x)\, u_t(x))$$
$$\Rightarrow \nabla \cdot \left( p_t(x)\, \frac{\int u_t(x | x_0, x_1)\, \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1}{p_t(x)} \right) = \nabla \cdot (p_t(x)\, u_t(x))$$

The above equation implies that

$$u_t(x) = \frac{\int u_t(x | x_0, x_1)\, \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1}{p_t(x)}$$

is a valid velocity field that satisfies the mass continuity equation with respect to $p_t$. Now, let's use this expression for $u_t(x)$ in $\mathcal{L}_{FM}$ to obtain a practical objective for flow matching:

$$\arg\min_{v_t}\ \mathcal{L}_{FM}(v_t) = \arg\min_{v_t}\ \mathbb{E}_{t,\, x \sim p_t} \| v_t(x) - u_t(x) \|^2$$
$$= \arg\min_{v_t}\ \mathbb{E}_{t, p_t} \| v_t(x) \|^2 - 2\, \mathbb{E}_{t, p_t} \langle v_t(x), u_t(x) \rangle + \mathbb{E}_{t, p_t} \| u_t(x) \|^2$$
$$= \arg\min_{v_t}\ \mathbb{E}_{t, p_t} \| v_t(x) \|^2 - 2\, \mathbb{E}_{t, p_t} \langle v_t(x), u_t(x) \rangle \tag{2}$$

where the last step drops $\mathbb{E}_{t, p_t} \| u_t(x) \|^2$ because it does not depend on $v_t$. Consider the first term. We can express it as follows:

$$\mathbb{E}_{t, p_t} \| v_t(x) \|^2 = \mathbb{E}_t \int \| v_t(x) \|^2\, p_t(x)\, dx$$
$$= \mathbb{E}_t \int \| v_t(x) \|^2 \int \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1\, dx$$
$$= \mathbb{E}_t \int \int \| v_t(x) \|^2\, \delta_{I_t(x_0,x_1)}(x)\, dx\ \rho(x_0,x_1)\, dx_0\, dx_1$$
$$= \mathbb{E}_{t,\, (x_0,x_1) \sim \rho(x_0,x_1),\, x = I_t(x_0,x_1)} \| v_t(x) \|^2 \tag{3}$$

Now, consider the second term:

$$\mathbb{E}_{t, p_t} \langle v_t(x), u_t(x) \rangle = \mathbb{E}_{t, p_t} \left\langle v_t(x), \frac{\int u_t(x | x_0, x_1)\, \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1}{p_t(x)} \right\rangle$$
$$= \mathbb{E}_t \int \left\langle v_t(x), \frac{\int u_t(x | x_0, x_1)\, \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1}{p_t(x)} \right\rangle p_t(x)\, dx$$
$$= \mathbb{E}_t \int \left\langle v_t(x), \int u_t(x | x_0, x_1)\, \delta_{I_t(x_0,x_1)}(x)\, \rho(x_0,x_1)\, dx_0\, dx_1 \right\rangle dx$$
$$= \mathbb{E}_t \int \int \langle v_t(x), u_t(x | x_0, x_1) \rangle\, \delta_{I_t(x_0,x_1)}(x)\, dx\ \rho(x_0,x_1)\, dx_0\, dx_1$$
$$= \mathbb{E}_{t,\, (x_0,x_1) \sim \rho(x_0,x_1),\, x = I_t(x_0,x_1)} \langle v_t(x), u_t(x | x_0, x_1) \rangle$$
$$= \mathbb{E}_{t,\, (x_0,x_1) \sim \rho(x_0,x_1),\, x = I_t(x_0,x_1)} \langle v_t(x), \partial_t I_t(x_0, x_1) \rangle \tag{4}$$

Using Eqs. (3) and (4) in (2), we get

$$\arg\min_{v_t}\ \mathcal{L}_{FM}(v_t) = \arg\min_{v_t}\ \mathbb{E}_{t,\, (x_0,x_1) \sim \rho(x_0,x_1),\, x = I_t(x_0,x_1)} \left\| v_t(x) - \partial_t I_t(x_0, x_1) \right\|^2$$

which is exactly the tractable objective $\mathcal{L}$.
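
Note that the expression for $u_t(x)$ above is precisely the conditional expectation $\mathbb{E}[\partial_t I_t(x_0, x_1) \mid I_t(x_0, x_1) = x]$. As an optional numerical check, the sketch below estimates this conditional expectation by Monte Carlo and compares it against the closed form available for 1D Gaussians; the distributions, time, and query point are illustrative choices, not from the original papers.

```python
# Monte Carlo check that u_t(x) = E[x1 - x0 | I_t(x0, x1) = x] for the
# linear interpolant with rho_0 = N(0,1), rho_1 = N(3,1), independent coupling.
import numpy as np

rng = np.random.default_rng(0)
t, x_query = 0.4, 1.0
x0 = rng.normal(0.0, 1.0, 2_000_000)    # x0 ~ rho_0
x1 = rng.normal(3.0, 1.0, 2_000_000)    # x1 ~ rho_1
xt = (1 - t) * x0 + t * x1              # x_t = I_t(x0, x1)

mask = np.abs(xt - x_query) < 0.01      # condition on x_t ~ x_query
mc_estimate = np.mean((x1 - x0)[mask])  # E[x1 - x0 | x_t ~ x_query]

# closed form for this jointly Gaussian case: Var(x_t) = (1-t)^2 + t^2
var = (1 - t)**2 + t**2
closed_form = 3.0 + ((2*t - 1) / var) * (x_query - 3.0*t)
print(mc_estimate, closed_form)         # the two should closely agree
```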

Note on Diffusion Models

A diffusion model (stochastic, or its deterministic probability-flow ODE) can be viewed as a special case of flow matching. Consider the special instance of the interpolant function $I_t(x_0, x_1) = \alpha_t x_0 + \sigma_t x_1$, and let $\rho_0 = \mathcal{N}(0, I)$. Then the variance preserving (VP) SDE follows the probability path determined by $\sigma_t = \sqrt{1 - \alpha_t^2}$, where $\alpha_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s)\, ds\right)$ and $\beta$ is the noise schedule function. A small sketch of these coefficients is given below.
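
For concreteness, here is a sketch of the VP interpolant coefficients, assuming the commonly used linear noise schedule $\beta(s) = \beta_{\min} + s(\beta_{\max} - \beta_{\min})$; the schedule constants are illustrative.

```python
# VP interpolant coefficients for a linear noise schedule (illustrative values).
import numpy as np

beta_min, beta_max = 0.1, 20.0

def alpha(t):
    # alpha_t = exp(-1/2 * integral_0^t beta(s) ds), closed form for linear beta
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    return np.exp(-0.5 * integral)

def vp_interpolant(x0, x1, t):
    # I_t(x0, x1) = alpha_t * x0 + sigma_t * x1 with sigma_t = sqrt(1 - alpha_t^2)
    a = alpha(t)
    return a * x0 + np.sqrt(1 - a**2) * x1

t = np.linspace(0, 1, 5)
print(alpha(t))  # decays from 1 toward ~0, so I_0 = x0 (noise) and I_1 ~ x1 (data)
```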

Results

Figure 2: Quantitative comparison with score-based models and the diffusion path (Image Source: Lipman et al., 2023).
Figure 3: [Top] Samples generated by evaluating $I_t(x_0, x_1)$. [Bottom] Samples generated by integrating the learnt velocity field (Image Source: Albergo et al., 2023).
Figure 4: [Left] Samples from a score matching diffusion model. [Middle] Flow matching with the diffusion probability path. [Right] Flow matching with the linear interpolant $I_t(x_0, x_1) = (1-t)x_0 + t x_1$ (Image Source: Lipman et al., 2023).
Figure 5: Image-to-image translation results (Image Source: Liu et al., 2023).


References

[Lipman et al., 2023] Lipman, Y., Chen, R. T., Ben-Hamu, H., Nickel, M., & Le, M. (2023). Flow matching for generative modeling. arXiv preprint arXiv:2210.02747.

[Chen et al., 2018] Chen, R. T., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. Advances in neural information processing systems, 31.

[Albergo et al., 2023] Albergo, M. S., & Vanden-Eijnden, E. (2023). Building normalizing flows with stochastic interpolants. arXiv preprint arXiv:2209.15571.

[Liu et al., 2023] Liu, X., Gong, C., & Liu, Q. (2023). Flow straight and fast: Learning to generate and transfer data with rectified flow. arXiv preprint arXiv:2209.03003.