Core Flow Matching Logic (The Essential Parts)¶

In [13]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

# Pick the fastest available backend: CUDA (NVIDIA), then MPS (Apple), else CPU.
# `torch.backends.mps.is_available()` is the documented availability check; the
# `getattr` guard keeps this cell working on PyTorch builds without an MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device('cuda')  # nvidia/cuda
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')  # apple
else:
    device = torch.device('cpu')  # no acceleration

print(device)
mps
In [3]:
def get_flow_target(x_data, t, x_noise):
    """
    Compute the interpolation x_t and the target velocity.

    Rectified Flow Interpolation:
        x_t = t * x_data + (1 - t) * x_noise

    Target Velocity (dx_t/dt):
        v = x_data - x_noise

    Note: t goes from 0 (noise) to 1 (data).

    Args:
        x_data: Data samples with the batch on the first axis, shape (B, ...).
        t: Time values, one per sample (any shape with B elements, e.g. (B,)
            or (B, 1)), or a single value shared by the whole batch.
        x_noise: Noise samples, same shape as ``x_data``.

    Returns:
        Tuple ``(x_t, target_v)``, each with the shape of ``x_data``.
    """
    # Reshape t to (B, 1, ..., 1) with one singleton axis per non-batch axis of
    # x_data, so it broadcasts for (B, D) points as well as (B, C, H, W) images.
    t = t.reshape(-1, *([1] * max(x_data.ndim - 1, 0)))

    # Linear interpolation between noise (t=0) and data (t=1)
    x_t = t * x_data + (1 - t) * x_noise

    # Target velocity: constant along each straight noise -> data path
    target_v = x_data - x_noise

    return x_t, target_v

def train_step(model, x_data):
    """Compute the flow-matching loss for one batch.

    Draws Gaussian noise and a uniform time per sample, builds the interpolated
    point and its target velocity with ``get_flow_target``, and returns the mean
    squared error between ``model(x_t, t)`` and that target.

    Args:
        model: Callable taking ``(x_t, t)`` and returning a velocity prediction.
        x_data: Batch of data samples; the first axis is the batch.

    Returns:
        Scalar loss tensor (the caller is responsible for backward/step).
    """
    n_in_batch = x_data.shape[0]

    # Noise endpoint first, then one time in [0, 1) per sample.
    noise = torch.randn_like(x_data)
    times = torch.rand(n_in_batch, device=x_data.device)

    # Point on the noise -> data path and the velocity the model should match.
    interpolated, velocity_target = get_flow_target(x_data, times, noise)

    velocity_pred = model(interpolated, times)
    return nn.functional.mse_loss(velocity_pred, velocity_target)

@torch.no_grad()
def sample_euler(model, n_samples, n_steps, device, dim=2):
    """Generate samples by solving the ODE from t=0 (noise) to t=1 (data).

    Uses ``n_steps`` explicit Euler steps of size ``1 / n_steps``. If ``model``
    is an ``nn.Module`` in training mode it is switched to eval mode for the
    duration of sampling (so layers such as dropout are inactive) and restored
    afterwards.

    Args:
        model: Callable ``model(x, t_batch)`` returning a velocity shaped like ``x``.
        n_samples: Number of samples to generate.
        n_steps: Number of Euler steps.
        device: Device on which the integration runs.
        dim: Dimensionality of each sample (default 2).

    Returns:
        Tuple ``(x, trajectory)``: the final samples as a tensor on ``device``
        and a list of ``n_steps + 1`` NumPy arrays (initial noise included).
    """
    # Sampling is inference: stochastic training-only layers must be disabled.
    restore_train = isinstance(model, torch.nn.Module) and model.training
    if restore_train:
        model.eval()

    try:
        # Start from pure noise (t=0)
        x = torch.randn(n_samples, dim).to(device)

        trajectory = [x.cpu().numpy()]
        dt = 1.0 / n_steps

        for i in range(n_steps):
            # Time at the start of this step
            t = i / n_steps
            t_batch = torch.full((n_samples,), t, device=device, dtype=torch.float32)

            # Predict velocity
            velocity = model(x, t_batch)

            # Euler step: x_{t+dt} = x_t + v(x_t, t) * dt
            x = x + velocity * dt

            trajectory.append(x.cpu().numpy())
    finally:
        if restore_train:
            model.train()

    return x, trajectory

MLP Flow Model¶

In [4]:
class ResidualBlock(nn.Module):
    """Residual MLP block computing ``x + Dropout(Linear(SiLU(Linear(x))))``.

    Args:
        hidden_dim: Width of the input, the hidden layer, and the output.
        dropout_prob: Dropout probability applied to the residual branch.
    """

    def __init__(self, hidden_dim, dropout_prob=0.1):
        super().__init__()
        # Two width-preserving linear layers with a SiLU in between.
        branch_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.split_fc = nn.Sequential(*branch_layers)
        # Registered as a submodule but not called in forward().
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        """Return the input plus the (dropout-regularised) branch output."""
        branch_out = self.split_fc(x)
        branch_out = self.dropout(branch_out)
        return x + branch_out

class VectorFieldModel(nn.Module):
    """MLP that predicts the velocity v(x, t) for continuous time t.

    The scalar time is embedded by a small learned MLP (a plain linear
    projection followed by SiLU and another linear layer; no sinusoidal or
    Fourier features are used). The embedding is added to a linear projection
    of ``x``, passed through a stack of residual blocks and a LayerNorm, and
    projected back to the data dimension.

    Args:
        dim: Dimensionality of the data points (and of the predicted velocity).
        hidden_dim: Width of the hidden representation.
        n_layers: Number of residual blocks (default 5).
    """

    def __init__(self, dim=2, hidden_dim=128, n_layers=5):
        super().__init__()

        # Time embedding: scalar t -> hidden_dim features
        self.time_mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Initial projection of the data point
        self.input_proj = nn.Linear(dim, hidden_dim)

        # Residual layers
        self.layers = nn.ModuleList(
            [ResidualBlock(hidden_dim) for _ in range(n_layers)]
        )

        # Output projection
        self.final_norm = nn.LayerNorm(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, dim)

    def forward(self, x, t):
        """Predict the velocity at points ``x`` and times ``t``.

        Args:
            x: Points of shape (B, dim).
            t: Times of shape (B,) or (B, 1), or a single value that is
                broadcast across the batch.

        Returns:
            Tensor of shape (B, dim).
        """
        # Normalise t to a column: (B,), (B, 1) -> (B, 1); scalar -> (1, 1),
        # which then broadcasts against the batch in the sum below.
        t = t.reshape(-1, 1)

        # Embed time: (B, H)
        t_emb = self.time_mlp(t)

        # Embed input: (B, H)
        x_emb = self.input_proj(x)

        # Combine by addition
        h = x_emb + t_emb

        for layer in self.layers:
            h = layer(h)

        h = self.final_norm(h)
        return self.output_proj(h)

Toy Dataset (2D Spirals)¶

In [5]:
def create_spiral_data(n_samples=100000):
    """Generate 2D spiral dataset for visualization.

    The angle runs evenly from 0 to 4*pi and the radius grows linearly with the
    angle, normalised so the outermost point lies at radius 1.

    Args:
        n_samples: Number of points along the spiral.

    Returns:
        Tensor of shape (n_samples, 2) holding the (x, y) coordinates.
    """
    max_angle = 4*np.pi
    theta = torch.linspace(0, max_angle, n_samples)
    # radius = theta / max_angle; x uses cos, y uses sin
    coords = [theta * trig(theta) / max_angle for trig in (torch.cos, torch.sin)]
    return torch.stack(coords, dim=1)

Training Loop¶

In [6]:
def train_flow_model(n_epochs=10000):
    """Train a ``VectorFieldModel`` on the spiral dataset with flow matching.

    Uses the module-level ``device``, 1000 spiral points as a single full
    batch, Adam at lr=1e-3, and a StepLR schedule that halves the learning
    rate every third of the run. Progress is printed about ten times.

    Args:
        n_epochs: Number of optimisation steps.

    Returns:
        Tuple ``(model, data, device)``. The model is returned in eval mode so
        that later inference does not run with dropout active.
    """
    # Create data
    data = create_spiral_data(1000).to(device)

    # Create Model
    model = VectorFieldModel().to(device)
    model.train()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # LR Scheduler; max(1, ...) keeps step_size valid when n_epochs < 3
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=max(1, n_epochs // 3), gamma=0.5
    )

    # Log roughly ten times; max(1, ...) avoids a modulo-by-zero when n_epochs < 10
    log_every = max(1, n_epochs // 10)

    print(f"Training on {device}...")
    print(f"Data shape: {data.shape}")

    # Training loop
    for epoch in range(n_epochs):
        # Data batch (we use full dataset here since it's small)
        loss = train_step(model, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % log_every == 0:
            print(f"Epoch {epoch:4d} | Loss: {loss.item():.6f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    # Training is done: disable dropout for sampling and visualisation.
    model.eval()
    return model, data, device

Visualization¶

In [7]:
def _style_panel(ax, title, limit):
    """Apply the title, symmetric axis limits, equal aspect, and grid shared by all panels."""
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)


def _draw_results_panels(model, data, device, n_steps):
    """Build the 2x3 summary figure; the caller is responsible for the model's mode."""
    fig = plt.figure(figsize=(9, 6))
    data_cpu = data.cpu()

    # 1. Original data
    ax1 = fig.add_subplot(2, 3, 1)
    ax1.scatter(data_cpu[:, 0], data_cpu[:, 1], s=2, alpha=0.5)
    _style_panel(ax1, 'Original Data (Target)', 1)

    # 2. Linear Interpolation (Ground Truth Flow)
    # x_t = t * x_data + (1-t) * x_noise
    ax2 = fig.add_subplot(2, 3, 2)
    sample_points = data_cpu[:200]
    noise_points = torch.randn_like(sample_points)

    times_to_show = [0.0, 0.25, 0.5, 0.75, 1.0]
    colors = plt.cm.viridis(np.linspace(0, 1, len(times_to_show)))

    for color, t in zip(colors, times_to_show):
        x_t = t * sample_points + (1 - t) * noise_points
        ax2.scatter(x_t[:, 0], x_t[:, 1], s=2, alpha=0.4,
                    color=color, label=f't={t}')

    _style_panel(ax2, 'Training Paths (Linear)', 2)
    ax2.legend(markerscale=3, fontsize=8)

    # 3. Generated samples
    ax3 = fig.add_subplot(2, 3, 3)
    generated, _ = sample_euler(model, 1000, n_steps, device)
    generated = generated.cpu()
    ax3.scatter(generated[:, 0], generated[:, 1], s=2, alpha=0.5, color='orange')
    _style_panel(ax3, 'Generated Samples (ODE)', 1)

    # 4. Vector Field Visualization
    ax4 = fig.add_subplot(2, 3, 4)
    grid_n = 20
    grid_axis = torch.linspace(-1.5, 1.5, grid_n)
    grid_x, grid_y = torch.meshgrid(grid_axis, grid_axis, indexing='xy')
    grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(device)

    # Visualize field at t=0.5
    t_val = 0.5
    t_batch = torch.full((grid_points.shape[0],), t_val, device=device)
    with torch.no_grad():
        v = model(grid_points, t_batch).cpu().numpy()

    ax4.quiver(grid_x.numpy(), grid_y.numpy(),
               v[:, 0].reshape(grid_n, grid_n), v[:, 1].reshape(grid_n, grid_n),
               scale=20, alpha=0.6, color='purple')
    _style_panel(ax4, f'Velocity Field at t={t_val}', 1.5)

    # 5. Denoising trajectory
    ax5 = fig.add_subplot(2, 3, 5)
    _, trajectory = sample_euler(model, 100, n_steps, device)

    # Plot about ten snapshots plus the final one; max(1, ...) avoids a
    # modulo-by-zero when n_steps < 10.
    stride = max(1, n_steps // 10)
    for i, points in enumerate(trajectory):
        if i % stride == 0 or i == n_steps:
            alpha_val = i / len(trajectory)
            ax5.scatter(points[:, 0], points[:, 1], s=1, alpha=0.3,
                        color=plt.cm.plasma(alpha_val))

    _style_panel(ax5, 'Sample Trajectories (t=0->1)', 2)

    # 6. Comparison: overlay real and generated points
    ax6 = fig.add_subplot(2, 3, 6)
    ax6.scatter(data_cpu[::10, 0], data_cpu[::10, 1],
                s=2, alpha=0.5, color='blue', label='Real')
    ax6.scatter(generated[:300, 0], generated[:300, 1],
                s=2, alpha=0.5, color='orange', label='Generated')
    _style_panel(ax6, 'Real vs Generated', 1)
    ax6.legend(markerscale=3)

    fig.tight_layout()
    return fig


def visualize_results(model, data, device, n_steps=50):
    """Create comprehensive visualization of the flow matching process.

    Draws six panels: the target data, linear training paths at several times,
    generated samples, the predicted velocity field at t=0.5, snapshots of a
    sampling trajectory, and an overlay of real vs generated points.

    If ``model`` is an ``nn.Module`` in training mode it is switched to eval
    mode while the figure is built (so dropout does not perturb the plotted
    predictions) and restored afterwards.

    Args:
        model: Callable ``model(x, t)`` returning velocities.
        data: Tensor of 2D data points, shape (N, 2).
        device: Device used for model evaluation and sampling.
        n_steps: Number of Euler steps used when sampling.

    Returns:
        The matplotlib Figure.
    """
    restore_train = isinstance(model, torch.nn.Module) and model.training
    if restore_train:
        model.eval()
    try:
        return _draw_results_panels(model, data, device, n_steps)
    finally:
        if restore_train:
            model.train()


def create_animation(model, data, device, n_steps=100, num_frames=100):
    """Create animation of the flow process from noise to data.

    Samples 200 trajectories with ``sample_euler``, keeps at most
    ``num_frames`` evenly spaced snapshots, and animates the points over the
    model's velocity field at the time of each snapshot.

    If ``model`` is an ``nn.Module`` in training mode it is evaluated in eval
    mode (dropout off) for sampling and for every frame, and its mode is
    restored after each call.

    Args:
        model: Callable ``model(x, t)`` returning velocities.
        data: Tensor of 2D data points drawn as a faint background reference.
        device: Device used for model evaluation and sampling.
        n_steps: Number of Euler steps used for sampling.
        num_frames: Maximum number of animation frames.

    Returns:
        Tuple ``(fig, anim)`` of the Figure and the ``FuncAnimation``.
    """
    def _run_in_eval(fn, *args):
        """Call ``fn(*args)`` without gradients and with the model in eval mode."""
        restore_train = isinstance(model, torch.nn.Module) and model.training
        if restore_train:
            model.eval()
        try:
            with torch.no_grad():
                return fn(*args)
        finally:
            if restore_train:
                model.train()

    _, trajectory = _run_in_eval(sample_euler, model, 200, n_steps, device)

    # Take the true endpoint before subsampling, so the colouring always
    # reflects the final samples even for very small num_frames.
    final_points = trajectory[-1]

    # Subsample trajectory to get num_frames
    frame_indices = np.arange(len(trajectory))
    if len(trajectory) > num_frames:
        frame_indices = np.linspace(0, len(trajectory) - 1, num_frames, dtype=int)
        trajectory = [trajectory[i] for i in frame_indices]

    # Snapshot i of the full trajectory was taken at t = i / n_steps.
    frame_times = [float(i) / n_steps for i in frame_indices]

    fig, ax = plt.subplots(figsize=(8, 8))

    # Colors based on radius of final points (spiral structure)
    radii = np.sqrt(np.sum(final_points**2, axis=1))
    norm = plt.Normalize(radii.min(), radii.max())
    colors = plt.cm.viridis(norm(radii))

    # Grid for vector field
    grid_n = 20
    grid_axis = torch.linspace(-2, 2, grid_n)
    grid_x, grid_y = torch.meshgrid(grid_axis, grid_axis, indexing='xy')
    grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(device)
    grid_x_np, grid_y_np = grid_x.numpy(), grid_y.numpy()

    # Frame-invariant background points, moved to CPU once.
    reference = data[::5].cpu()

    def update(frame):
        ax.clear()
        points = trajectory[frame]

        # Plot points
        ax.scatter(points[:, 0], points[:, 1], s=10, alpha=0.8, c=colors, zorder=3)

        # Background reference
        ax.scatter(reference[:, 0], reference[:, 1],
                   s=5, alpha=0.15, color='blue', zorder=1)

        # Time of the displayed snapshot
        t_val = frame_times[frame]

        # Vector field at current time
        t_batch = torch.full((grid_points.shape[0],), t_val, device=device)
        v_np = _run_in_eval(model, grid_points, t_batch).cpu().numpy()

        # Plot sparse vector field
        ax.quiver(grid_x_np, grid_y_np,
                  v_np[:, 0].reshape(grid_n, grid_n), v_np[:, 1].reshape(grid_n, grid_n),
                  scale=50, alpha=0.2, color='gray', zorder=2)

        ax.set_title(f'Flow Matching | t: {t_val:.2f} (Noise -> Data)',
                     fontsize=14, fontweight='bold')
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)

    anim = FuncAnimation(fig, update, frames=len(trajectory),
                         interval=50, repeat=False)
    return fig, anim

Main¶

Train model¶

In [8]:
# Seed PyTorch's global RNG so the noise/time draws during training (and the
# losses logged below) are repeatable on a fresh Restart & Run All, as far as
# the backend is deterministic.
SEED = 0
torch.manual_seed(SEED)

model, data, device = train_flow_model(n_epochs=10000)
Training on mps...
Data shape: torch.Size([1000, 2])
Epoch    0 | Loss: 1.267313 | LR: 1.00e-03
Epoch 1000 | Loss: 0.632325 | LR: 1.00e-03
Epoch 2000 | Loss: 0.595868 | LR: 1.00e-03
Epoch 3000 | Loss: 0.615737 | LR: 1.00e-03
Epoch 4000 | Loss: 0.591772 | LR: 5.00e-04
Epoch 5000 | Loss: 0.624182 | LR: 5.00e-04
Epoch 6000 | Loss: 0.587788 | LR: 5.00e-04
Epoch 7000 | Loss: 0.565417 | LR: 2.50e-04
Epoch 8000 | Loss: 0.576013 | LR: 2.50e-04
Epoch 9000 | Loss: 0.602354 | LR: 2.50e-04
In [9]:
%matplotlib ipympl

# Static visualization: build the 2x3 summary figure for the trained model,
# using 50 Euler steps wherever samples are generated.
fig = visualize_results(model, data, device, n_steps=50)
plt.show()
Figure
No description has been provided for this image

Animation¶

In [14]:
# Build the noise -> data animation from a 100-step Euler trajectory. The
# returned FuncAnimation is bound to `anim`, keeping a reference to it alive
# while the figure is shown.
fig_anim, anim = create_animation(model, data, device, n_steps=100)
plt.show()
Figure
No description has been provided for this image
In [ ]: