Code Walkthrough: nano-rfdiffusion
Files:
model.py (607 lines) + train.py (261 lines)Parameters: 10.3M | Training: 21 min on 1x RTX 3090
Companion to Lecture: RFDiffusion — SE(3) Diffusion for Protein Backbones. The lecture covers the theory; this page builds it from scratch.
You give it nothing. No sequence, no template, no hint. Just a number – say, 80 residues – and it generates a brand-new protein backbone that has never existed in nature. That is what RFDiffusion does, and this is a from-scratch implementation in 607 lines of PyTorch.
The real RFDiffusion (Watson et al., 2023) trains on 50K+ structures and generates backbones that can actually be designed into real proteins. Ours trains on 9 proteins in 21 minutes on an RTX 3090 and produces backbones with near-ideal bond geometry. Same algorithm, nano scale. Let’s walk through it.
The idea: diffusion, but for 3D rigid bodies
If you have seen image diffusion models, you know the drill: take real data, gradually add noise until it is pure static, then train a neural network to reverse the process. At generation time, start from noise and denoise step by step.
The twist here is that we are not working with pixels. A protein backbone is a chain of residues, and each residue is a rigid body in 3D space – it has a position (where it is) and an orientation (which way it faces). Mathematically, each residue lives in \(SE(3)\), the group of rigid motions: a rotation matrix \(R\) (3x3, from \(SO(3)\)) and a translation vector \(t\) (3D vector, from \(\mathbb{R}^3\)).
So our “image” is a sequence of \(SE(3)\) frames, and we need to figure out how to add noise to – and remove noise from – rotations and translations.
Frame computation: Gram-Schmidt on the backbone
Before diffusion can corrupt or denoise anything, we need ground-truth frames from the training structures. compute_local_frame builds an orthonormal coordinate system at each residue from its three backbone atoms (N, CA, C):
x = normalize(C - CA) # x-axis: CA → C direction
z = normalize(cross(x, N - CA)) # z-axis: perpendicular to the N-CA-C plane
y = cross(z, x) # y-axis: completes the right-handed frame
R = stack([x, y, z], dim=-1) # 3×3 rotation matrix (columns = axes)
translation = CA # frame origin = CA position
This is Gram-Schmidt orthogonalization on the backbone triangle. The x-axis points from CA toward C, the z-axis is the normal to the N-CA-C plane, and y completes the orthonormal basis. The translation is simply the CA position.
These frames are the “ground truth” that diffusion learns to predict. During training, IGSO3 and Gaussian noise corrupt these frames; the denoiser tries to recover them. At inference, the model generates frames from pure noise – and those frames define the protein backbone through backbone_from_frames().
Two noise schedules: SO(3) x R(3)
Translations are easy. They are just 3D vectors, so we use standard cosine-schedule DDPM, exactly like image diffusion. The R3Diffuser does this:
# Cosine schedule alpha_bar, then: x_t = sqrt(alpha_bar) * x_0 + sqrt(1-alpha_bar) * noise
noise = torch.randn_like(trans_0)
return torch.sqrt(ab) * trans_0 + torch.sqrt(1 - ab) * noise, noise
Rotations are trickier. You cannot just add Gaussian noise to a rotation matrix – you would break orthogonality. Instead, we sample from the Isotropic Gaussian on SO(3) (IGSO3): pick a random axis, draw an angle from a Gaussian with standard deviation \(\sigma\), and compose the resulting rotation with the current one.
def sample_igso3(shape, sigma, device):
omega = torch.abs(torch.randn(*shape, device=device) * sigma) # angle magnitude
axis = torch.randn(*shape, 3, device=device) # random axis
axis = axis / (axis.norm(dim=-1, keepdim=True) + 1e-8) # normalize
return axis_angle_to_rotation_matrix(axis * omega.unsqueeze(-1))
The SE3Diffusion class wraps both. Forward process corrupts rotations and translations independently. Reverse process denoises them independently. Clean separation.
Invariant Point Attention: the key architectural idea
The denoising network needs to look at a set of noisy residue frames and predict the clean ones. Normal self-attention would only see abstract features. Invariant Point Attention (IPA), from AlphaFold2, operates in both scalar space and 3D point space simultaneously.
Here is the core idea: each attention head generates not just scalar queries/keys/values, but also 3D point queries/keys/values. These points get transformed by each residue’s frame into global coordinates before computing attention:
# Project to 3D points: [B, L, n_heads, n_qk_points, 3]
q_pts = self.to_q_pts(s).view(B, L, self.n_heads, self.n_qk_points, 3)
k_pts = self.to_k_pts(s).view(B, L, self.n_heads, self.n_qk_points, 3)
# Apply each residue's rigid frame to its points
q_global = apply_frames(q_pts) # rotate + translate into global coords
k_global = apply_frames(k_pts)
# Attention: scalar term + point distance term + pair bias
attn = scalar_attn - 0.5 * w * point_distance_sq + pair_bias
The query/key points are learned 3D offsets in each residue’s local coordinate frame. apply_frames rotates them by the residue’s rotation matrix and adds its translation, placing them in global coordinates. Two residues whose frames point their query/key points at each other will have small point distances and high attention:
def apply_frames(pts):
# pts: [B, L, H, P, 3] — points in local frame
flat = pts.reshape(B, L, H*P, 3)
rotated = torch.einsum("blij,blnj->blni", rigids.rots, flat) # rotate
return (rotated + rigids.trans[:, :, None, :]).view_as(pts) # translate
q_global, k_global = apply_frames(q_pts), apply_frames(k_pts)
The point distance term sums squared Euclidean distances over all n_qk_points (4) point pairs:
pt_diff = q_global[:, :, None] - k_global[:, None, :] # [B, L, L, H, P, 3]
pt_dist_sq = (pt_diff ** 2).sum(-1).sum(-1) # sum over xyz and points → [B, L, L, H]
For value points, the same apply_frames transform brings them to global coordinates. After weighting by attention, the aggregated output points get inverse-transformed back to each residue’s local frame:
out_pts_global = torch.einsum("bhij,bjhpc->bihpc", attn, v_global)
# Inverse transform: subtract translation, apply R^T
flat_pts = out_pts_global.reshape(B, L, H*P_v, 3) - rigids.trans[:, :, None, :]
out_pts_local = torch.einsum("blij,blnj->blni", inv_rots, flat_pts)
This inverse transform is what makes IPA SE(3)-equivariant: rotating all frames rotates all intermediate points identically, and the inverse transform undoes it. The final output lives in local frames, independent of global orientation.
Our IPA uses 8 heads, 4 query/key points per head, and 8 value points per head.
Denoising block structure
Each of the 8 denoising blocks has four stages: IPA on the single representation, triangular multiplicative updates on the pair representation, and a frame update.
Triangular multiplicative updates maintain geometric consistency in the pair representation, same as in AlphaFold2’s Pairformer. Two variants handle complementary information flows:
- Outgoing: for each pair (i, j), aggregate over shared neighbor k using
left[i,k] * right[j,k]— “which residues does i reach that j also reaches?” - Incoming: same but
left[k,i] * right[k,j]— “which residues reach both i and j?”
Each uses gated projections – the input goes through two parallel linear layers (value and gate), where the gate applies sigmoid to control information flow:
left = self.left_proj(z) * torch.sigmoid(self.left_gate(z))
right = self.right_proj(z) * torch.sigmoid(self.right_gate(z))
out = torch.einsum("bikc,bjkc->bijc", left, right) # outgoing
out = self.output_proj(self.final_norm(out))
return pair + torch.sigmoid(self.output_gate(pair)) * out # gated residual
The output gate provides a second level of gating on the residual connection, letting the model control how much of the triangular update flows through.
Frame correction: small updates, big effect
After IPA and the triangular updates, each block outputs a 6D correction per residue: 3 values for rotation (axis-angle) and 3 for translation. The correction is composed onto the current frame, not substituted:
update = self.frame_update(self.frame_norm(node_feat)) # [B, L, 6]
rot_mat = axis_angle_to_rotation_matrix(update[..., :3] * 0.1) # scale down!
frames = frames.compose(RigidTransform(rot_mat, update[..., 3:]))
The * 0.1 scale factor on the axis-angle is critical. The Rodrigues formula converts axis-angle to a rotation matrix: \(R = I + \sin\theta \cdot K + (1-\cos\theta) \cdot K^2\), where \(\theta\) is the angle magnitude and \(K\) is the skew-symmetric matrix of the axis. Without scaling, the network could output large rotational jumps (say, 90° in one block), destabilizing training. The 0.1 factor caps each block’s rotation at roughly \(\pm 0.1 \times \pi \approx 18°\) even with large activations.
Composition (rather than replacement) is equally important: each block refines the frame incrementally. The frame_update weights are initialized to zero, so the model starts with identity corrections and gradually learns to make meaningful adjustments.
Self-conditioning: a free lunch
During training, 50% of the time we run the denoiser twice. The first pass produces a preliminary prediction, which gets fed back as extra input to the second pass. The model learns to refine its own outputs. At inference, we always self-condition (feeding the previous step’s prediction forward).
This costs almost nothing at inference time (you are already iterating) and noticeably improves sample quality. In the code, it is just 7 extra input dimensions – the predicted frame as a quaternion+translation vector concatenated with the noisy input.
Training: 10.3M parameters, 21 minutes, spiky losses
With default settings (8 blocks, node_dim=256, pair_dim=64), the model has 10.3M parameters. We train on 9 small proteins (crambin, ubiquitin, protein G, etc.) for 10,000 epochs. On an RTX 3090, this takes about 21 minutes.
The training curve is spiky, and that is expected. Each batch gets a random timestep \(t \sim U(0,1)\). When \(t\) is small (low noise), the task is easy and loss is low. When \(t\) is near 1 (almost pure noise), the task is nearly impossible and loss spikes. You are seeing random interleaving of easy and hard problems, not instability.
Over 10,000 epochs, loss drops from ~53 to ~3.2. Breaking it down:
| Metric | Start | End | Best seen |
|---|---|---|---|
| Total loss | ~53 | ~3.2 | 0.59 |
| Translation loss | ~52 | ~2.6 | 0.45 |
| Rotation loss | ~0.8 | ~0.6 | 0.14 |
Translation loss dominates and drops faster – positions are easier to learn than orientations. Rotation loss is inherently bounded (angles are in \([0, \pi]\)) and converges more slowly.
The sample() function: 100 steps from noise to structure
Generation is the reverse diffusion loop. Start from pure noise, take 100 steps:
frames = RigidTransform.identity((1, L), device=device)
t_init = torch.ones(1, L, device=device)
frames, _ = diffusion.forward_marginal(frames, t_init) # pure noise
timesteps = torch.linspace(1, 0, num_steps + 1) # 1.0 -> 0.0
for i in range(num_steps):
t_now, t_next = timesteps[i].item(), timesteps[i + 1].item()
pred = network(frames, torch.tensor([t_now], device=device), prev_pred)
if t_next > 0:
frames = diffusion.reverse_step(frames, pred, t_now, t_next)
else:
frames = pred # last step: just use the prediction directly
Each step: (1) predict the clean structure from the current noisy one, (2) take a partial step toward the prediction (adding a small amount of noise back in, except at the last step). The reverse_step handles the different math for rotations (geodesic interpolation on SO(3)) and translations (DDPM reverse formula).
Results: near-ideal geometry from 607 lines
We generated backbones at three lengths and measured CA-CA bond distances (ideal is 3.8 angstroms):
| Length | CA-CA mean | CA-CA std |
|---|---|---|
| L=50 | 3.52 A | 1.64 A |
| L=80 | 2.93 A | – |
| L=100 | 3.33 A | – |
The means are in the right neighborhood of 3.8 A, which is encouraging for a model trained on 9 structures. The standard deviation for L=50 (1.64 A) is high – real RFDiffusion produces much tighter distributions. More training data and longer training would bring these numbers closer to ideal.
What is missing vs. the real thing
The full RFDiffusion trains on ~50,000 structures from the PDB using the RoseTTAFold architecture (which is considerably larger and includes MSA processing). It produces backbones that pass the ultimate test: you can run ProteinMPNN to design a sequence, fold it with AlphaFold2, and get back the same structure. That end-to-end “designability” pipeline is the gold standard.
Our nano version skips all of that. No MSA features, no motif scaffolding, no hotspot conditioning. But the core algorithm is exactly the same: SE(3) diffusion with IGSO3 rotational noise, IPA-based denoising, self-conditioning, and reverse diffusion sampling. All in 607 lines of model.py, trained by a 261-line train.py.
If you want to understand how diffusion models generate protein structures – not just use them, but understand them – this is the place to start. Read model.py top to bottom. It is all there.