
<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.2.2">Jekyll</generator><link href="/feed.xml" rel="self" type="application/atom+xml" /><link href="/" rel="alternate" type="text/html" hreflang="en" /><updated>2026-05-20T13:52:03+00:00</updated><id>/feed.xml</id><title type="html">Kieran Didi</title><subtitle>PhD student interested in Protein Design, ML and bioinformatics.  Blogging about papers, tools and experiences in the CompBio space.
</subtitle><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><entry><title type="html">The unification of representation learning and generative modelling</title><link href="/blog/ml/2025-12-31-r4g/" rel="alternate" type="text/html" title="The unification of representation learning and generative modelling" /><published>2025-12-31T00:00:00+00:00</published><updated>2026-02-17T20:16:30+00:00</updated><id>/blog/ml/r4g</id><content type="html" xml:base="/blog/ml/2025-12-31-r4g/"><![CDATA[<ul id="markdown-toc">
  <li><a href="#introduction" id="markdown-toc-introduction">Introduction</a></li>
  <li><a href="#background" id="markdown-toc-background">Background</a>    <ul>
      <li><a href="#generative-modelling-latent-diffusion-models" id="markdown-toc-generative-modelling-latent-diffusion-models">Generative modelling: Latent Diffusion Models</a></li>
      <li><a href="#representation-learning-self-supervised-vision-foundation-models" id="markdown-toc-representation-learning-self-supervised-vision-foundation-models">Representation Learning: self-supervised vision foundation models</a>        <ul>
          <li><a href="#from-cross-modal-contrastive-learning-to-single-modal-contrastive-learning" id="markdown-toc-from-cross-modal-contrastive-learning-to-single-modal-contrastive-learning">From Cross-Modal Contrastive Learning to Single-Modal Contrastive Learning</a></li>
          <li><a href="#visionlanguage-models-and-sigmoid-contrastive-losses" id="markdown-toc-visionlanguage-models-and-sigmoid-contrastive-losses">Vision–Language Models and Sigmoid Contrastive Losses</a></li>
          <li><a href="#self-distillation-and-momentum-encoders" id="markdown-toc-self-distillation-and-momentum-encoders">Self-Distillation and Momentum Encoders</a></li>
          <li><a href="#masked-image-modeling-and-predictive-architectures" id="markdown-toc-masked-image-modeling-and-predictive-architectures">Masked Image Modeling and Predictive Architectures</a></li>
        </ul>
      </li>
      <li><a href="#toward-a-platonic-representation-and-implications-for-generative-models" id="markdown-toc-toward-a-platonic-representation-and-implications-for-generative-models">Toward a Platonic Representation and Implications for Generative Models</a></li>
    </ul>
  </li>
  <li><a href="#overview-of-the-four-phases" id="markdown-toc-overview-of-the-four-phases">Overview of the Four Phases</a></li>
  <li><a href="#phase-1-aligning-diffusion-features-to-vision-foundation-models" id="markdown-toc-phase-1-aligning-diffusion-features-to-vision-foundation-models">Phase 1: Aligning Diffusion Features to Vision Foundation Models</a></li>
  <li><a href="#phase-2-aligning-the-vae-latent-space-to-foundation-models" id="markdown-toc-phase-2-aligning-the-vae-latent-space-to-foundation-models">Phase 2: Aligning the VAE Latent Space to Foundation Models</a></li>
  <li><a href="#phase-3-operating-directly-in-vision-foundation-model-feature-spaces" id="markdown-toc-phase-3-operating-directly-in-vision-foundation-model-feature-spaces">Phase 3: Operating Directly in Vision Foundation Model Feature Spaces</a></li>
  <li><a href="#phase-4-questioning-the-need-for-pretrained-representations" id="markdown-toc-phase-4-questioning-the-need-for-pretrained-representations">Phase 4: Questioning the Need for Pretrained Representations</a></li>
  <li><a href="#the-other-direction-generative-models-as-representations" id="markdown-toc-the-other-direction-generative-models-as-representations">The Other Direction: Generative Models as Representations</a>    <ul>
      <li><a href="#from-pixel-prediction-to-embedding-prediction" id="markdown-toc-from-pixel-prediction-to-embedding-prediction">From Pixel Prediction to Embedding Prediction</a></li>
      <li><a href="#diffusion-models-learn-representations-too" id="markdown-toc-diffusion-models-learn-representations-too">Diffusion Models Learn Representations Too</a></li>
    </ul>
  </li>
  <li><a href="#representation-learning-and-alignment-in-molecular-machine-learning" id="markdown-toc-representation-learning-and-alignment-in-molecular-machine-learning">Representation Learning and Alignment in Molecular Machine Learning</a>    <ul>
      <li><a href="#molecular-embeddings-borrowing-from-nlp-and-computer-vision" id="markdown-toc-molecular-embeddings-borrowing-from-nlp-and-computer-vision">Molecular Embeddings: Borrowing from NLP and Computer Vision</a></li>
      <li><a href="#where-to-go-from-here" id="markdown-toc-where-to-go-from-here">Where to Go from Here?</a></li>
    </ul>
  </li>
  <li><a href="#conclusion" id="markdown-toc-conclusion">Conclusion</a></li>
  <li><a href="#credits" id="markdown-toc-credits">Credits</a></li>
  <li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>

<p><em>For the TLDR version of this post, see <a href="https://www.blopig.com/blog/2026/01/what-molecular-ml-can-learn-from-the-vision-communitys-representation-revolution/">this version on the OPIG blog</a>.</em></p>

<p>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</p>

<div class="language-bibtex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nc">@misc</span><span class="p">{</span><span class="nl">didi2025r4g</span><span class="p">,</span>
  <span class="na">author</span> <span class="p">=</span> <span class="s">{Didi, Kieran}</span><span class="p">,</span>
  <span class="na">title</span> <span class="p">=</span> <span class="s">{The unification of representation learning and generative modelling}</span><span class="p">,</span>
  <span class="na">url</span> <span class="p">=</span> <span class="s">{https://kdidi.netlify.app/blog/ml/2025-12-31-r4g/}</span><span class="p">,</span>
  <span class="na">year</span> <span class="p">=</span> <span class="s">{2025}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="introduction">Introduction</h2>

<p>Both generative modeling and representation learning have made impressive advances in recent years, particularly in computer vision. Diffusion <sup id="fnref:ho2020ddpm" role="doc-noteref"><a href="#fn:ho2020ddpm" class="footnote" rel="footnote">1</a></sup><sup id="fnref:song2020score" role="doc-noteref"><a href="#fn:song2020score" class="footnote" rel="footnote">2</a></sup><sup id="fnref:lai2025principles" role="doc-noteref"><a href="#fn:lai2025principles" class="footnote" rel="footnote">3</a></sup> and flow models <sup id="fnref:lipman2024flow" role="doc-noteref"><a href="#fn:lipman2024flow" class="footnote" rel="footnote">4</a></sup><sup id="fnref:albergo2023building" role="doc-noteref"><a href="#fn:albergo2023building" class="footnote" rel="footnote">5</a></sup> have achieved unprecedented generation quality, while self-supervised paradigms like CLIP <sup id="fnref:radford2021learning" role="doc-noteref"><a href="#fn:radford2021learning" class="footnote" rel="footnote">6</a></sup>, DINO <sup id="fnref:caron_emerging_2021" role="doc-noteref"><a href="#fn:caron_emerging_2021" class="footnote" rel="footnote">7</a></sup>, and MAE <sup id="fnref:he2022masked" role="doc-noteref"><a href="#fn:he2022masked" class="footnote" rel="footnote">8</a></sup> have enabled state-of-the-art performance on classification, detection, and depth estimation. Yet generation has remained separate from other vision tasks, raising a natural question: can we create unified representations useful for both discriminative and generative tasks?</p>

<p><img src="/assets/img/blog/r4g/timeline.png" alt="Evolution of representation learning and generative modeling (2019-2025)" />
<em>Fig 1. Three parallel timelines showing the independent evolution of representation learning methods, latent generative modeling architectures, and the recent convergence of these fields in R4G (Representation for Generation).</em></p>

<p>This field—sometimes termed Representation for Generation (R4G)—has evolved rapidly over the past year, with multiple groups independently converging on similar insights. The rapid development reveals fundamental questions about visual representations: Are features learned during generation inherently different from those learned discriminatively? Can we bridge these paradigms for more efficient systems? Recent evidence suggests diffusion models already learn semantically meaningful representations <sup id="fnref:kadkhodaie_unconditional_2025" role="doc-noteref"><a href="#fn:kadkhodaie_unconditional_2025" class="footnote" rel="footnote">9</a></sup><sup id="fnref:liang_how_2024" role="doc-noteref"><a href="#fn:liang_how_2024" class="footnote" rel="footnote">10</a></sup>, and that generative classifiers exhibit surprisingly human-like properties <sup id="fnref:jaini_intriguing_2024" role="doc-noteref"><a href="#fn:jaini_intriguing_2024" class="footnote" rel="footnote">11</a></sup>. Is there a way these two can benefit from each other more explicitly?</p>

<p><img src="/assets/img/blog/r4g/overview_figure.png" alt="Four phases of representation-guided generation" />
<em>Fig 2. Evolution from no alignment (Phase 0) through feature alignment (Phase 1), VAE alignment (Phase 2), and VAE-less direct embedding diffusion (Phase 3). Phase 4 (pixel-space diffusion without pretrained models) represents a parallel evolution that has been ongoing throughout, with complementary contributions to pure latent-space methods. Each block shows the architectural approach and which papers introduced key innovations.</em></p>

<p>When I started reading the literature, I was honestly quite overwhelmed by the sheer number of papers and approaches proposed this year alone, with new papers coming out every week. But after some reading, and after discussing several of these papers with their authors and with colleagues, some patterns started to emerge. In this blog post I try to organize recent developments into four phases reflecting my take on how the field developed during 2025: from initial alignment strategies to questioning whether pretrained representations are necessary at all. As part of this I also touch upon pixel-space versus latent diffusion models (again) and how the trend goes both ways, i.e. how we can use generative models for representation learning. Finally, because at heart I am a molecule guy, I share some of my thoughts on how these ideas are beginning to influence molecular machine learning, and exciting directions to pursue there.</p>

<p><img src="/assets/img/blog/r4g/paper_overview.png" alt="The Top 25 papers from this blog" />
<em>Fig 3. An explosion of papers in this field has made it hard to keep an overview; by the end of this post you should hopefully be able to read any of these papers, place it on your mental map in one of the phases we will discuss, and compare it to similar approaches.</em></p>

<h2 id="background">Background</h2>

<p>To talk about the unification of representation learning and generative modelling, it helps to briefly look at each of them separately and review what has happened recently. It has been known for quite a while that they are intimately related (for a recent take on this see <a href="https://www.youtube.com/watch?v=4VwXBrMoC0E">this excellent talk</a> by Kaiming He from a CVPR2025 workshop). However, in practice they still function quite differently, both in terms of the losses and training recipes employed and the neural network architectures used. What is the latest in both of these fields?</p>

<p><img src="/assets/img/blog/r4g/generation_vs_representation.png" alt="Generation vs Representation" />
<em>Fig 4. Generative Modelling and Representation Learning are intimately connected, two sides of the same coin: while representation learning tries to map from data to some semantic representation space (e.g. to allow for easier classification of objects in an image), generative modelling wants to map from abstract concepts like text prompts to actual data samples. Image from the <a href="https://visionbook.mit.edu/generative_modeling_and_rep_learning.html">“Foundations of Computer Vision” online book</a>.</em></p>

<h3 id="generative-modelling-latent-diffusion-models">Generative modelling: Latent Diffusion Models</h3>

<p><img src="/assets/img/blog/r4g/ldm_diagram.png" alt="LDM training pipeline" />
<em>Fig 5. Two-stage training: first, a VAE compresses images into latent space; second, diffusion operates in this latent space. This modular design accelerated adoption but separated tokenizer training from diffusion model training.</em></p>

<p>Diffusion models revolutionized generation by framing it as iterative denoising. While Sohl-Dickstein et al. <sup id="fnref:sohldickstein2015deep" role="doc-noteref"><a href="#fn:sohldickstein2015deep" class="footnote" rel="footnote">12</a></sup> first introduced the core idea of learning to reverse a diffusion process in 2015, the approach didn’t scale until Ho et al. <sup id="fnref:ho2020ddpm:1" role="doc-noteref"><a href="#fn:ho2020ddpm" class="footnote" rel="footnote">1</a></sup> introduced DDPMs that gradually add noise and learn to reverse it; Song et al. looked at it from a score-based perspective <sup id="fnref:song2020improved" role="doc-noteref"><a href="#fn:song2020improved" class="footnote" rel="footnote">13</a></sup>; and in 2021 the two views were unified through stochastic differential equations <sup id="fnref:song2020score:1" role="doc-noteref"><a href="#fn:song2020score" class="footnote" rel="footnote">2</a></sup>. This was all still in pixel space: every pixel was denoised in RGB space, which hindered both scalability and performance. The key breakthrough came with latent diffusion: Vahdat et al. built on their earlier NVAE work <sup id="fnref:vahdat2020nvae" role="doc-noteref"><a href="#fn:vahdat2020nvae" class="footnote" rel="footnote">14</a></sup> to propose LSGM <sup id="fnref:vahdat_score-based_2021" role="doc-noteref"><a href="#fn:vahdat_score-based_2021" class="footnote" rel="footnote">15</a></sup>, a theoretically principled framework for joint VAE+diffusion training with tractable score matching and proper variational bounds. However, despite superior theory, LSGM’s engineering complexity, including spectral regularization, careful hyperparameter tuning and variance reduction, limited practical adoption.</p>
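<p>To make "gradually add noise and learn to reverse it" concrete, here is a minimal NumPy sketch of the closed-form DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. The function names, the array shapes and the choice of a linear beta schedule with these endpoints are illustrative assumptions, not taken from any specific codebase.</p>

```python
import numpy as np

def make_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative signal fraction alpha_bar_t for a linear beta schedule (assumed values)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t from q(x_t | x_0) = N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) * I)."""
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return x_t, eps  # eps is what a noise-prediction network would regress

rng = np.random.default_rng(0)
alpha_bar = make_alpha_bar()
x0 = rng.standard_normal((4, 8, 8, 3))  # stand-in for a small batch of images
x_t, eps = q_sample(x0, t=500, alpha_bar=alpha_bar, rng=rng)
```

<p>A denoiser trained on pairs like <code>(x_t, eps)</code> at random timesteps is what gets run in reverse at sampling time.</p>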

<p>Rombach et al.’s Latent Diffusion Models (LDMs) <sup id="fnref:rombach_high-resolution_2022" role="doc-noteref"><a href="#fn:rombach_high-resolution_2022" class="footnote" rel="footnote">16</a></sup> simplified this dramatically. Rather than joint end-to-end training, LDMs adopted a two-stage design: first, a VAE compresses images into lower-dimensional latents (typically <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>256</mn><mo>×</mo><mn>256</mn><mo>×</mo><mn>3</mn><mo>→</mo><mn>32</mn><mo>×</mo><mn>32</mn><mo>×</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">256 \times 256 \times 3 \to 32 \times 32 \times 4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">256</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">256</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">32</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">32</span><span 
class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span>); second, diffusion operates in this latent space. A key insight of the LDM paper, and where it fundamentally differs from LSGM, is that during autoencoding the model only encodes <em>perceptually relevant</em> details in the latent space, not all of the fine-grained, ultra-high-entropy information (like texture details). This pixel-level detail is re-generated in the decoder. This behavior comes from training with a patch-based discriminator loss (which lets the encoder discard local details and the decoder regenerate them) as well as the LPIPS loss (which reconstructs in a feature space of perceptually relevant features) in addition to the regular MSE loss. The LDM paper calls this semantic vs. perceptual compression (see Fig. 2 in <sup id="fnref:rombach_high-resolution_2022:1" role="doc-noteref"><a href="#fn:rombach_high-resolution_2022" class="footnote" rel="footnote">16</a></sup>). This is drastically different from LSGM: in LSGM, all high-entropy local details are encoded in latent space (the same information that regular pixel-space DDPM models must model), which is perhaps why LSGM took longer to train and required bigger models in the latent space. LDM’s smart compression scheme for visual signals makes things far more scalable for image and video data. For a more in-depth discussion on latent diffusion models, see Sander Dieleman’s excellent blog <sup id="fnref:dieleman2025latents" role="doc-noteref"><a href="#fn:dieleman2025latents" class="footnote" rel="footnote">17</a></sup>.
This simplified approach produced better perceptual quality and more stable training, allowing extension of the approach to other modalities like video generation <sup id="fnref:blattmann2023align" role="doc-noteref"><a href="#fn:blattmann2023align" class="footnote" rel="footnote">18</a></sup><sup id="fnref:brooks2024sora" role="doc-noteref"><a href="#fn:brooks2024sora" class="footnote" rel="footnote">19</a></sup>.</p>
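<p>As a quick sanity check on how much work the first stage takes off the diffusion model, the arithmetic for the typical shapes quoted above can be written out. The variable names are just for illustration.</p>

```python
import math

pixel_shape = (256, 256, 3)   # RGB image
latent_shape = (32, 32, 4)    # typical LDM latent quoted above

num_pixel_values = math.prod(pixel_shape)     # 196608
num_latent_values = math.prod(latent_shape)   # 4096

spatial_downsampling = pixel_shape[0] // latent_shape[0]    # 8x per side
compression_ratio = num_pixel_values / num_latent_values    # 48.0
```

<p>So the diffusion model sees 48 times fewer values per sample than a pixel-space model at the same resolution.</p>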

<p>The modular two-stage approach provided significant advantages: VAEs pretrained once could be reused across different diffusion models, researchers could iterate independently on each component, and pretrained autoencoders from other work could be directly incorporated. This modularity accelerated research and deployment and enabled breakthroughs like Stable Diffusion XL <sup id="fnref:podell2023sdxl" role="doc-noteref"><a href="#fn:podell2023sdxl" class="footnote" rel="footnote">20</a></sup>. However, as subsequent sections discuss, this separation between tokenizer and generative model is now being reconsidered.</p>

<p>Peebles and Xie’s Diffusion Transformer (DiT) <sup id="fnref:peebles2023scalable" role="doc-noteref"><a href="#fn:peebles2023scalable" class="footnote" rel="footnote">21</a></sup> demonstrated that transformers could replace U-Nets, achieving state-of-the-art ImageNet generation with favorable scaling. DiT operates on latent patches, treating them as sequences like Vision Transformers. A key finding: model complexity correlates strongly with sample quality—increasing depth, width, or tokens consistently improves generation. The largest DiT-XL/2 model established transformers as scalable alternatives for diffusion, serving as the baseline against which subsequent alignment methods would be measured.</p>
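<p>Treating latent patches "as sequences" boils down to a reshape. Below is a minimal NumPy sketch of such a patchify step, assuming a channels-last latent of shape (batch, height, width, channels); the function name and layout are assumptions for illustration and not DiT's actual implementation.</p>

```python
import numpy as np

def patchify(latents, patch_size):
    """Split (B, H, W, C) latents into (B, num_tokens, patch_size * patch_size * C) tokens."""
    b, h, w, c = latents.shape
    gh, gw = h // patch_size, w // patch_size   # token grid height and width
    x = latents.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)           # (B, gh, gw, p, p, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

latents = np.arange(2 * 32 * 32 * 4, dtype=np.float64).reshape(2, 32, 32, 4)
tokens = patchify(latents, patch_size=2)
# tokens.shape is (2, 256, 16): a 32x32x4 latent becomes 256 tokens of dimension 16
```

<p>The "/2" in DiT-XL/2 is this patch size; halving it quadruples the number of tokens, which is one of the ways to add compute mentioned above.</p>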

<p>Recent developments have also explored alternative generative paradigms. Flow matching <sup id="fnref:lipman2024flow:1" role="doc-noteref"><a href="#fn:lipman2024flow" class="footnote" rel="footnote">4</a></sup> provides a simulation-free approach to training continuous normalizing flows with conceptually simpler and more flexible formulations than standard diffusion. The relationship between diffusion and flow matching has been clarified <sup id="fnref:gao2024diffusion" role="doc-noteref"><a href="#fn:gao2024diffusion" class="footnote" rel="footnote">22</a></sup>, showing they are fundamentally equivalent under certain conditions, differing primarily in parameterization and sampling schedules: a well-tuned diffusion model can work just as well, and flow matching can equally be set up with suboptimal paths. The popularity of flow matching likely stems more from its conceptual simplicity than from clear performance advantages. Rectified flow transformers, a flow matching variant, have been successfully scaled to production systems <sup id="fnref:esser2024sd3" role="doc-noteref"><a href="#fn:esser2024sd3" class="footnote" rel="footnote">23</a></sup>.</p>
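<p>The conceptual simplicity is easiest to see in code. The sketch below builds training targets for flow matching with a straight-line path between noise and data, which is one common choice (the one used by rectified flows); the function names and the convention that t = 0 is noise and t = 1 is data are assumptions made for this example.</p>

```python
import numpy as np

def flow_matching_pair(x_data, rng):
    """Sample a point on the straight path from noise to data and its target velocity."""
    noise = rng.standard_normal(x_data.shape)
    # One time per sample, broadcast over all remaining axes.
    t = rng.uniform(size=(x_data.shape[0],) + (1,) * (x_data.ndim - 1))
    x_t = (1.0 - t) * noise + t * x_data
    target_velocity = x_data - noise   # constant along a straight path
    return x_t, t, target_velocity

def flow_matching_loss(predicted_velocity, target_velocity):
    """Plain mean-squared error between predicted and target velocity."""
    return np.mean((predicted_velocity - target_velocity) ** 2)

rng = np.random.default_rng(0)
x_data = rng.standard_normal((4, 32, 32, 4))   # stand-in for a batch of latents
x_t, t, target_velocity = flow_matching_pair(x_data, rng)
```

<p>A network would be trained to predict <code>target_velocity</code> from <code>(x_t, t)</code>; no simulation of the flow is needed during training.</p>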

<h3 id="representation-learning-self-supervised-vision-foundation-models">Representation Learning: self-supervised vision foundation models</h3>

<p>Self-supervised learning aims to learn general-purpose visual representations without manual labels, enabling models to exploit vast unlabeled corpora <sup id="fnref:oord_representation_2019" role="doc-noteref"><a href="#fn:oord_representation_2019" class="footnote" rel="footnote">24</a></sup><sup id="fnref:chen_exploring_2020" role="doc-noteref"><a href="#fn:chen_exploring_2020" class="footnote" rel="footnote">25</a></sup><sup id="fnref:caron_deep_2019" role="doc-noteref"><a href="#fn:caron_deep_2019" class="footnote" rel="footnote">26</a></sup><sup id="fnref:caron_unsupervised_2021" role="doc-noteref"><a href="#fn:caron_unsupervised_2021" class="footnote" rel="footnote">27</a></sup><sup id="fnref:assran_self-supervised_2023" role="doc-noteref"><a href="#fn:assran_self-supervised_2023" class="footnote" rel="footnote">28</a></sup><sup id="fnref:huh_platonic_2024" role="doc-noteref"><a href="#fn:huh_platonic_2024" class="footnote" rel="footnote">29</a></sup>. Early approaches were largely contrastive: they defined positive and negative pairs and trained encoders so that positives map to nearby features while negatives are pushed apart <sup id="fnref:oord_representation_2019:1" role="doc-noteref"><a href="#fn:oord_representation_2019" class="footnote" rel="footnote">24</a></sup><sup id="fnref:chen_exploring_2020:1" role="doc-noteref"><a href="#fn:chen_exploring_2020" class="footnote" rel="footnote">25</a></sup><sup id="fnref:chen2020simclr" role="doc-noteref"><a href="#fn:chen2020simclr" class="footnote" rel="footnote">30</a></sup>. 
Subsequent work progressively weakened the dependence on labels, explicit negatives, and even pixel-level reconstruction, moving toward architectures that predict high-level, semantic representations <sup id="fnref:caron_emerging_2021:1" role="doc-noteref"><a href="#fn:caron_emerging_2021" class="footnote" rel="footnote">7</a></sup><sup id="fnref:oquab_dinov2_2024" role="doc-noteref"><a href="#fn:oquab_dinov2_2024" class="footnote" rel="footnote">31</a></sup><sup id="fnref:simeoni_dinov3_2025" role="doc-noteref"><a href="#fn:simeoni_dinov3_2025" class="footnote" rel="footnote">32</a></sup><sup id="fnref:assran_self-supervised_2023:1" role="doc-noteref"><a href="#fn:assran_self-supervised_2023" class="footnote" rel="footnote">28</a></sup><sup id="fnref:assran_v-jepa_2025" role="doc-noteref"><a href="#fn:assran_v-jepa_2025" class="footnote" rel="footnote">33</a></sup>.</p>

<h4 id="from-cross-modal-contrastive-learning-to-single-modal-contrastive-learning">From Cross-Modal Contrastive Learning to Single-Modal Contrastive Learning</h4>

<p>A natural starting point for self-supervised representation learning is cross-modal contrastive learning, where aligned pairs provide supervision “for free.” CLIP jointly trains image and text encoders so that the similarity between matching image–caption pairs is maximized and that between mismatched pairs is minimized, using a large-scale contrastive objective over Internet-scale image-text datasets <sup id="fnref:radford2021learning:1" role="doc-noteref"><a href="#fn:radford2021learning" class="footnote" rel="footnote">6</a></sup><sup id="fnref:brooks2024sora:1" role="doc-noteref"><a href="#fn:brooks2024sora" class="footnote" rel="footnote">19</a></sup><sup id="fnref:esser2024sd3:1" role="doc-noteref"><a href="#fn:esser2024sd3" class="footnote" rel="footnote">23</a></sup>. This removes the need for class labels but depends on enormous amounts of paired data, and on sufficiently many negatives in each batch to avoid trivial solutions where the model encodes only coarse semantics <sup id="fnref:zhai_sigmoid_2023" role="doc-noteref"><a href="#fn:zhai_sigmoid_2023" class="footnote" rel="footnote">34</a></sup>.</p>

<p>SimCLR showed that contrastive learning can work in a purely single-modal setting <sup id="fnref:chen2020simclr:1" role="doc-noteref"><a href="#fn:chen2020simclr" class="footnote" rel="footnote">30</a></sup>. Two heavily augmented views of the same image form a positive pair, and all other images in the batch serve as negatives. Combined with strong data augmentation, a temperature-scaled InfoNCE loss, and large encoder capacity, SimCLR achieves supervised-level performance on ImageNet, demonstrating that labels are not strictly necessary for high-quality features. However, as with most contrastive learning methods including the ones described before, this comes at the cost of extremely large batch sizes, which are needed to provide enough negative examples so that the contrastive loss encourages fine-grained, non-trivial representations rather than collapsing to coarse global features <sup id="fnref:oord_representation_2019:2" role="doc-noteref"><a href="#fn:oord_representation_2019" class="footnote" rel="footnote">24</a></sup>.</p>
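<p>To illustrate why the batch doubles as the pool of negatives, here is a simplified NumPy sketch of a temperature-scaled InfoNCE loss between two augmented views. It only contrasts across views, whereas SimCLR's full loss also uses other samples from the same view as negatives, so treat it as an approximation; the function names and the temperature value are assumptions.</p>

```python
import numpy as np

def log_softmax(logits, axis):
    """Numerically stable log-softmax along one axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def info_nce(z_a, z_b, temperature=0.1):
    """Symmetric InfoNCE: row i of z_a and row i of z_b are two views of the same image."""
    z_a = z_a / np.linalg.norm(z_a, axis=1, keepdims=True)
    z_b = z_b / np.linalg.norm(z_b, axis=1, keepdims=True)
    logits = z_a @ z_b.T / temperature          # (B, B) cosine similarities
    # Positives sit on the diagonal; every other entry in a row or column is a negative.
    loss_ab = -np.diag(log_softmax(logits, axis=1)).mean()
    loss_ba = -np.diag(log_softmax(logits, axis=0)).mean()
    return 0.5 * (loss_ab + loss_ba)
```

<p>With a batch of size B each positive competes against B - 1 negatives, which is why shrinking the batch directly weakens the training signal.</p>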

<p>SwAV improves on this regime by replacing explicit pairwise comparisons with online clustering <sup id="fnref:caron_deep_2019:1" role="doc-noteref"><a href="#fn:caron_deep_2019" class="footnote" rel="footnote">26</a></sup><sup id="fnref:caron_unsupervised_2021:1" role="doc-noteref"><a href="#fn:caron_unsupervised_2021" class="footnote" rel="footnote">27</a></sup>. Instead of contrasting features directly, SwAV assigns representations to prototype clusters and enforces consistency of these assignments across multiple augmentations of the same image. This “swapped prediction” mechanism preserves many advantages of contrastive learning while being more memory-efficient and less sensitive to batch size, making it easier to scale to large datasets and long training schedules.</p>

<h4 id="visionlanguage-models-and-sigmoid-contrastive-losses">Vision–Language Models and Sigmoid Contrastive Losses</h4>

<p>Vision-language pretraining extends contrastive learning to cross-modal settings at scale. CLIP demonstrated that large-scale image-text contrastive learning yields highly transferable visual representations and strong zero-shot performance across tasks <sup id="fnref:radford2021learning:2" role="doc-noteref"><a href="#fn:radford2021learning" class="footnote" rel="footnote">6</a></sup><sup id="fnref:brooks2024sora:2" role="doc-noteref"><a href="#fn:brooks2024sora" class="footnote" rel="footnote">19</a></sup>. However, CLIP’s softmax-based loss ties batch size directly to the number of effective negatives, which complicates scaling and makes training expensive.</p>

<p>SigLIP addresses this by replacing the softmax contrastive loss with a pairwise sigmoid loss over image-text similarities <sup id="fnref:zhai_sigmoid_2023:1" role="doc-noteref"><a href="#fn:zhai_sigmoid_2023" class="footnote" rel="footnote">34</a></sup>. This loss operates independently on each pair, enabling smaller batch sizes while still learning strong fine-grained alignments between images and text. SigLIP 2 further augments this recipe by combining contrastive training with captioning-style objectives, self-supervised losses, and improved data mixtures, leading to better semantic understanding, localization, and dense prediction performance <sup id="fnref:tschannen_siglip_2025" role="doc-noteref"><a href="#fn:tschannen_siglip_2025" class="footnote" rel="footnote">35</a></sup>.</p>
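<p>The pairwise sigmoid loss can be sketched in a few lines of NumPy. Every image-text pair is treated as an independent binary classification problem, with label +1 for matching pairs and -1 otherwise. The function name is hypothetical, and the scale and bias are fixed constants here for illustration, whereas in SigLIP they are learned parameters.</p>

```python
import numpy as np

def sigmoid_pairwise_loss(img_emb, txt_emb, scale=10.0, bias=-10.0):
    """Pairwise sigmoid loss over all image-text pairs in a batch."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = scale * (img @ txt.T) + bias            # (B, B)
    labels = 2.0 * np.eye(img.shape[0]) - 1.0        # +1 on the diagonal, -1 elsewhere
    # log(sigmoid(x)) = -log(1 + exp(-x)), computed stably via logaddexp.
    log_likelihood = -np.logaddexp(0.0, -labels * logits)
    return -log_likelihood.sum() / img.shape[0]
```

<p>Because there is no softmax normalization across the batch, each term can be computed without knowing the other pairs, which is what decouples the loss from the batch size.</p>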

<h4 id="self-distillation-and-momentum-encoders">Self-Distillation and Momentum Encoders</h4>

<p>A key limitation of contrastive methods is their reliance on negatives. BYOL and related methods showed that it is possible to dispense with explicit negatives by using a momentum-updated teacher network <sup id="fnref:grill_bootstrap_2020" role="doc-noteref"><a href="#fn:grill_bootstrap_2020" class="footnote" rel="footnote">36</a></sup><sup id="fnref:richemond_byol_2020" role="doc-noteref"><a href="#fn:richemond_byol_2020" class="footnote" rel="footnote">37</a></sup>. The student is trained to match the teacher’s representation of a differently augmented view of the same image; the teacher parameters are an exponential moving average of the student’s, which stabilizes training and prevents collapse in practice.</p>
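<p>The exponential moving average that defines the teacher is a one-liner per parameter. The sketch below assumes parameters are stored as a dictionary of NumPy arrays; the function name and the momentum value are illustrative assumptions.</p>

```python
import numpy as np

def ema_update(teacher_params, student_params, momentum=0.996):
    """Move each teacher parameter a small step towards the student's current value."""
    return {
        name: momentum * teacher_params[name] + (1.0 - momentum) * student_params[name]
        for name in teacher_params
    }

teacher = {"w": np.zeros((2, 2))}
student = {"w": np.ones((2, 2))}
teacher = ema_update(teacher, student)   # the teacher receives no gradients, only this update
```

<p>With a momentum close to 1 the teacher changes slowly, giving the student a stable target to match.</p>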

<p>DINO extends this self-distillation paradigm and reveals several surprising properties of the resulting representations <sup id="fnref:caron_emerging_2021:2" role="doc-noteref"><a href="#fn:caron_emerging_2021" class="footnote" rel="footnote">7</a></sup>. Without labels or negatives, DINO learns features whose attention maps correspond to object boundaries and support unsupervised semantic segmentation, indicating non-trivial semantic organization. In principle, such momentum-encoder methods require only a single image per batch, since supervision comes from matching teacher and student outputs rather than contrasting with other samples.</p>

<p>DINOv2 scales this recipe with larger Vision Transformers, improved optimization, and a carefully curated, diverse training set <sup id="fnref:oquab_dinov2_2024:1" role="doc-noteref"><a href="#fn:oquab_dinov2_2024" class="footnote" rel="footnote">31</a></sup>. The resulting models produce highly robust and transferable features that rival or surpass supervised pretraining across many benchmarks, as well as serving as strong vision foundation encoders for downstream tasks, including generative modeling <sup id="fnref:skorokhodov_improving_2025" role="doc-noteref"><a href="#fn:skorokhodov_improving_2025" class="footnote" rel="footnote">38</a></sup><sup id="fnref:chen_masked_2025" role="doc-noteref"><a href="#fn:chen_masked_2025" class="footnote" rel="footnote">39</a></sup>. However, prolonged self-distillation can gradually erode fine-grained spatial information, especially in dense feature maps used for pixel-level tasks.</p>

<p>To address this, DINOv3 introduces Gram anchoring, a regularization that stabilizes dense feature representations over long training schedules by constraining second-order statistics across patches and scales <sup id="fnref:simeoni_dinov3_2025:1" role="doc-noteref"><a href="#fn:simeoni_dinov3_2025" class="footnote" rel="footnote">32</a></sup>. This mitigates the tendency of self-distillation to over-smooth features, preserving detailed structure that is crucial for dense prediction and generative tokenization while maintaining the semantic strengths of the DINO family.</p>
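<p>To give a feel for what "constraining second-order statistics across patches" means, here is a rough NumPy sketch of a Gram-matrix matching loss between the patch features of a student and those of an anchor model. This illustrates the general idea only and is not the exact DINOv3 recipe; the function names and the use of a mean rather than a sum are assumptions.</p>

```python
import numpy as np

def gram_matrix(patch_features):
    """Pairwise cosine similarities between the patches of one image: (N, D) to (N, N)."""
    normed = patch_features / np.linalg.norm(patch_features, axis=1, keepdims=True)
    return normed @ normed.T

def gram_anchoring_loss(student_patches, anchor_patches):
    """Penalize changes in patch-to-patch similarity structure, not in the features themselves."""
    diff = gram_matrix(student_patches) - gram_matrix(anchor_patches)
    return np.mean(diff ** 2)
```

<p>Note that the loss only constrains how patches relate to each other: rotating the whole feature space leaves it unchanged, so the features themselves remain free to keep evolving.</p>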

<h4 id="masked-image-modeling-and-predictive-architectures">Masked Image Modeling and Predictive Architectures</h4>

<p>In parallel, masked image modeling treats images analogously to masked language modeling in NLP. MAE masks a large fraction of image patches (typically around 75%) and trains an asymmetric encoder-decoder architecture to reconstruct the missing pixels <sup id="fnref:he2022masked:1" role="doc-noteref"><a href="#fn:he2022masked" class="footnote" rel="footnote">8</a></sup>. This forces the encoder to focus on global structure rather than local texture, producing efficient representations that work well for many downstream tasks with modest finetuning.</p>
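<p>The masking step itself is simple. A minimal sketch, assuming the image has already been split into a sequence of patch embeddings (the function name and return values are hypothetical):</p>

```python
import numpy as np

def random_masking(patches, mask_ratio=0.75, rng=None):
    # patches: (num_patches, dim). Returns the visible subset that the
    # encoder would process, plus a boolean mask over all patches.
    rng = np.random.default_rng() if rng is None else rng
    num_patches = patches.shape[0]
    num_keep = int(round(num_patches * (1.0 - mask_ratio)))
    order = rng.permutation(num_patches)
    keep_idx = np.sort(order[:num_keep])
    mask = np.ones(num_patches, dtype=bool)
    mask[keep_idx] = False  # True marks a masked (hidden) patch
    return patches[keep_idx], mask, keep_idx
```

<p>With a 14 by 14 patch grid (196 patches) and a 75% mask ratio, the encoder only sees 49 patches, which is where the asymmetric design gets its efficiency: the heavy encoder runs on a quarter of the sequence and only the lightweight decoder sees the full length.</p>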

<p>iBOT combines masked prediction with self-distillation, using a teacher network as an online tokenizer that predicts semantic tokens for masked patches instead of raw pixels <sup id="fnref:zhou2021ibot" role="doc-noteref"><a href="#fn:zhou2021ibot" class="footnote" rel="footnote">40</a></sup>. This hybrid objective closes much of the gap between contrastive and masked modeling approaches, yielding representations that perform strongly on both image-level classification and dense prediction.</p>

<p>Joint-embedding predictive architectures such as I-JEPA take a more explicitly semantic view: instead of reconstructing pixels, they predict high-level latent representations of masked regions from visible context <sup id="fnref:assran_self-supervised_2023:2" role="doc-noteref"><a href="#fn:assran_self-supervised_2023" class="footnote" rel="footnote">28</a></sup>. By operating entirely in representation space, I-JEPA avoids over-emphasizing low-level details and focuses learning on abstract structure, leading to scalable training and strong transfer across tasks.</p>

<h3 id="toward-a-platonic-representation-and-implications-for-generative-models">Toward a Platonic Representation and Implications for Generative Models</h3>

<p>Recent work from Phillip Isola’s lab at MIT has provided empirical evidence for a remarkable phenomenon: representations learned by different models, architectures, and even modalities converge toward a shared structure as models scale and training data diversifies <sup id="fnref:huh_platonic_2024:1" role="doc-noteref"><a href="#fn:huh_platonic_2024" class="footnote" rel="footnote">29</a></sup>. This convergent behavior has motivated the Platonic Representation Hypothesis, which posits that as models grow in capacity and are trained on increasingly rich data, their internal representations converge toward a shared statistical model of reality: a “platonic” representation that is largely independent of any specific task or architecture <sup id="fnref:huh_platonic_2024:2" role="doc-noteref"><a href="#fn:huh_platonic_2024" class="footnote" rel="footnote">29</a></sup>.</p>

<p>The evidence for this convergence comes from multiple angles. The foundational work demonstrates that features from independently trained vision and language models become more aligned as scale and data diversity increase, and that different self-supervised objectives yield embeddings that occupy similar subspaces up to simple linear transformations <sup id="fnref:huh_platonic_2024:3" role="doc-noteref"><a href="#fn:huh_platonic_2024" class="footnote" rel="footnote">29</a></sup>. Subsequent research has shown that cross-modal training can benefit each modality individually: Gupta et al. demonstrate that leveraging unpaired multimodal data (e.g., text, audio, or images) consistently improves downstream performance in unimodal tasks, exploiting the assumption that different modalities are projections of a shared underlying reality <sup id="fnref:gupta_better_2025" role="doc-noteref"><a href="#fn:gupta_better_2025" class="footnote" rel="footnote">41</a></sup>. Perhaps most strikingly, Wang et al. show that when language models are prompted with sensory instructions (e.g., “see” or “hear”), their representations become more similar to those of specialist vision and audio encoders, revealing that text-only models implicitly encode multimodal structure that can be activated through appropriate prompting <sup id="fnref:wang_words_2025" role="doc-noteref"><a href="#fn:wang_words_2025" class="footnote" rel="footnote">42</a></sup>. This suggests that even purely text-trained language models converge toward representations similar to those of vision models, with the convergence becoming stronger as models scale <sup id="fnref:huh_platonic_2024:4" role="doc-noteref"><a href="#fn:huh_platonic_2024" class="footnote" rel="footnote">29</a></sup><sup id="fnref:wang_words_2025:1" role="doc-noteref"><a href="#fn:wang_words_2025" class="footnote" rel="footnote">42</a></sup>.</p>

<p>This hypothesis has direct implications for generative modeling. If discriminative vision foundation models such as DINOv2, DINOv3, and MAE converge toward an approximately optimal visual representation, then explicitly leveraging these encoders can accelerate the training and improve the quality of generative models that would otherwise have to discover similar structures from scratch <sup id="fnref:oquab_dinov2_2024:2" role="doc-noteref"><a href="#fn:oquab_dinov2_2024" class="footnote" rel="footnote">31</a></sup><sup id="fnref:skorokhodov_improving_2025:1" role="doc-noteref"><a href="#fn:skorokhodov_improving_2025" class="footnote" rel="footnote">38</a></sup><sup id="fnref:chen_masked_2025:1" role="doc-noteref"><a href="#fn:chen_masked_2025" class="footnote" rel="footnote">39</a></sup><sup id="fnref:bi_vision_2025" role="doc-noteref"><a href="#fn:bi_vision_2025" class="footnote" rel="footnote">43</a></sup>. Recent work on aligning diffusion models to pretrained visual encoders—through feature alignment, representation regularization, or joint training of tokenizers and generators—can thus be viewed as an attempt to steer generative models toward this platonic representation early in training <sup id="fnref:yu_representation_2025" role="doc-noteref"><a href="#fn:yu_representation_2025" class="footnote" rel="footnote">44</a></sup><sup id="fnref:wu_representation_2025" role="doc-noteref"><a href="#fn:wu_representation_2025" class="footnote" rel="footnote">45</a></sup><sup id="fnref:leng_repa-e_2025" role="doc-noteref"><a href="#fn:leng_repa-e_2025" class="footnote" rel="footnote">46</a></sup><sup id="fnref:yao_reconstruction_2025" role="doc-noteref"><a href="#fn:yao_reconstruction_2025" class="footnote" rel="footnote">47</a></sup><sup id="fnref:wang_repa_2025" role="doc-noteref"><a href="#fn:wang_repa_2025" class="footnote" rel="footnote">48</a></sup><sup id="fnref:wang_diffuse_2025" role="doc-noteref"><a href="#fn:wang_diffuse_2025" class="footnote" 
rel="footnote">49</a></sup>. This perspective sets the stage for our discussion of Phase 1 methods that explicitly align diffusion features to vision foundation models, and for later sections analyzing the emerging convergence between generative and discriminative representations.</p>

<h2 id="overview-of-the-four-phases">Overview of the Four Phases</h2>

<p>Before diving into the details, let me briefly outline the four phases and what to expect. <strong>Phase 1</strong> introduces representation alignment—regularizing diffusion features to match pretrained vision encoders like DINOv2. <strong>Phase 2</strong> takes this deeper by incorporating semantic structure into the VAE latent space itself. <strong>Phase 3</strong> questions whether we need VAE compression at all, proposing to diffuse directly in pretrained representation spaces. <strong>Phase 4</strong> represents a parallel evolution that has been ongoing throughout: improving pixel-space diffusion through architectural innovation, questioning whether pretrained representations are necessary at all.</p>

<p>A key insight that will emerge: spatial structure alignment matters more than global semantic information for generation quality. Methods that preserve local self-similarity patterns consistently outperform those optimizing for classification accuracy. Additionally, while these techniques are presented in the context of Latent Diffusion Models (where most scaling happens), many of the core ideas—representation alignment, semantic regularization—are general and apply equally to pixel-space methods.</p>

<p>It’s also worth noting that in its most general form, the “representation space” we generate from can be pure noise, as in standard pixel-space diffusion or GANs. Even in latent models, all generation ultimately originates from noise—the question is what structure we impose on the intermediate representations.</p>

<p>Beyond these four phases, we’ll also explore the reverse direction: how generative modeling itself can serve as a pretraining objective for learning discriminative representations. This bidirectional relationship—representations helping generation, and generation producing useful representations—suggests these paradigms may be more unified than historically assumed.</p>

<h2 id="phase-1-aligning-diffusion-features-to-vision-foundation-models">Phase 1: Aligning Diffusion Features to Vision Foundation Models</h2>

<p>The first wave at the end of 2024/start of 2025 recognized that diffusion models learn semantically meaningful representations during training, but more slowly and less effectively than specialized discriminative models <sup id="fnref:xiang_denoising_2023" role="doc-noteref"><a href="#fn:xiang_denoising_2023" class="footnote" rel="footnote">50</a></sup><sup id="fnref:chen2024deconstructing" role="doc-noteref"><a href="#fn:chen2024deconstructing" class="footnote" rel="footnote">51</a></sup>. The solution: align intermediate diffusion features with pretrained vision encoders to guide training.</p>

<p>REPA <sup id="fnref:yu_representation_2025:1" role="doc-noteref"><a href="#fn:yu_representation_2025" class="footnote" rel="footnote">44</a></sup> introduced this paradigm in October 2024 through straightforward regularization. The method extracts features from intermediate diffusion layers, projects them through small MLPs, and maximizes cosine similarity with frozen DINOv2 encoder features. This auxiliary loss complements standard denoising:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>total</mtext></msub><mo>=</mo><msub><mi mathvariant="script">L</mi><mtext>diffusion</mtext></msub><mo>+</mo><mi>λ</mi><msub><mi mathvariant="script">L</mi><mtext>align</mtext></msub></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{diffusion}} + \lambda \mathcal{L}_{\text{align}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">total</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord 
mtight">diffusion</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em;"></span><span class="mord mathnormal">λ</span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">align</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span></span>

<p>The paper builds upon the insight from earlier work that diffusion models learn discriminative representations during denoising, but it takes the critical step of showing that aligning these emerging representations with high-quality pretrained features accelerates convergence. Longer training alone slowly improves the weak alignment that emerges naturally, but the REPA loss strengthens this alignment from the start, leading to better representations <em>and</em> better generation. This dual benefit suggests a genuinely helpful inductive bias.</p>
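<p>A minimal NumPy sketch of the REPA-style objective described above: intermediate diffusion features are projected by a small MLP, compared patch by patch to frozen encoder features via cosine similarity, and the result is added to the denoising loss with weight λ. The two-layer ReLU projector, the parameter names and the value of <code>lam</code> are assumptions for illustration rather than the paper's exact configuration.</p>

```python
import numpy as np

def mlp_projector(hidden, w1, b1, w2, b2):
    # Small trainable MLP mapping diffusion features (num_patches, d_model)
    # to the width of the frozen encoder (num_patches, d_encoder).
    h = np.maximum(hidden @ w1 + b1, 0.0)  # ReLU chosen for simplicity
    return h @ w2 + b2

def alignment_loss(projected, target, eps=1e-8):
    # Negative mean patch-wise cosine similarity: minimizing this
    # maximizes similarity to the frozen encoder features.
    p = projected / (np.linalg.norm(projected, axis=-1, keepdims=True) + eps)
    t = target / (np.linalg.norm(target, axis=-1, keepdims=True) + eps)
    return float(-(p * t).sum(axis=-1).mean())

def total_loss(diffusion_loss, projected, target, lam=0.5):
    # L_total = L_diffusion + lambda * L_align
    return diffusion_loss + lam * alignment_loss(projected, target)
```

<p>Only the projector and the diffusion backbone receive gradients from the alignment term; the encoder that produces <code>target</code> stays frozen throughout.</p>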

<p><img src="/assets/img/blog/r4g/LDM_to_REPA.png" alt="REPA: Feature alignment during diffusion denoising" />
<em>Fig 6. Left shows baseline LDM architecture. Right shows REPA with alignment loss from intermediate diffusion features to frozen DINOv2 representations, speeding early training through semantic guidance.</em></p>

<p>REG <sup id="fnref:wu_representation_2025:1" role="doc-noteref"><a href="#fn:wu_representation_2025" class="footnote" rel="footnote">45</a></sup> extended this by entangling semantic class tokens with latent content during denoising. Rather than just aligning intermediate features, REG concatenates the [CLS] token from frozen DINOv2 with the noisy latents, training the diffusion model to jointly reconstruct the noise and the original [CLS] token. This addition has minimal overhead (a single token, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>&lt;</mo><mn>0.5</mn><mi mathvariant="normal">%</mi></mrow><annotation encoding="application/x-tex">&lt;0.5\%</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5782em;vertical-align:-0.0391em;"></span><span class="mrel">&lt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8056em;vertical-align:-0.0556em;"></span><span class="mord">0.5%</span></span></span></span> FLOPs increase) yet provides stronger guidance than feature alignment alone. Interestingly, class token concatenation helps substantially even without explicit REPA alignment, though combining both works best, suggesting that multiple mechanisms for incorporating semantic structure can be complementary.</p>

<p><img src="/assets/img/blog/r4g/REPA_to_REG.png" alt="REG: Entangling class tokens with latents" />
<em>Fig 7. REG concatenates the [CLS] token from frozen DINOv2 with noisy latents, enabling joint reconstruction of both image content and semantic class information directly from pure noise.</em></p>
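<p>The concatenation step can be sketched as follows. This assumes the [CLS] token has already been projected to the same width as the latent tokens, and it uses a sequence length of 256 purely as an example; both are assumptions made for illustration.</p>

```python
import numpy as np

def entangle_class_token(latent_tokens, cls_token):
    # latent_tokens: (num_tokens, dim); cls_token: (dim,).
    # The class token is prepended so that the denoiser processes image
    # content and global semantics as one joint sequence.
    return np.concatenate([cls_token[None, :], latent_tokens], axis=0)

def extra_token_fraction(num_tokens):
    # Relative growth in sequence length from adding a single token.
    return 1.0 / num_tokens
```

<p>For a sequence of 256 latent tokens, one extra token lengthens the sequence by about 0.4%, which is consistent in magnitude with the small FLOPs overhead quoted above.</p>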

<p>HASTE <sup id="fnref:wang_repa_2025:1" role="doc-noteref"><a href="#fn:wang_repa_2025" class="footnote" rel="footnote">48</a></sup> addressed a key REPA limitation: alignment helps dramatically early in training but can plateau or even degrade results later. Once the generative model begins modeling the full data distribution, the lower-dimensional discriminative teacher becomes a constraint rather than a guide. The discriminative encoder focuses on task-relevant semantics while discarding generative details; forcing continued alignment may prevent the model from learning the distribution’s full complexity. HASTE introduces two-phase training: Phase I simultaneously distills attention maps (relational priors) and feature projections like REPA (semantic anchors) for rapid initial convergence. Phase II terminates alignment at a predetermined iteration, freeing the model to exploit its generative capacity. This simple modification achieves dramatic acceleration, with the key insight that alignment is most valuable for establishing initial structure but counterproductive once basic semantic organization is learned.</p>

<p><img src="/assets/img/blog/r4g/REPA_to_HASTE.png" alt="HASTE: Early-stopped alignment with staged termination" />
<em>Fig 8. Phase I applies holistic alignment distilling both attention maps and features. Phase II terminates alignment one-shot at a fixed iteration, freeing the diffusion model to model the full distribution without the discriminative teacher constraint.</em></p>
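<p>The staged termination amounts to a step-dependent weight on the alignment terms. A minimal sketch, with hypothetical function names and an arbitrary default weight:</p>

```python
def alignment_weight(step, stop_step, weight=1.0):
    # Phase I (step before stop_step): alignment is active at full weight.
    # Phase II (step at or after stop_step): alignment is switched off
    # in one shot rather than decayed gradually.
    if step >= stop_step:
        return 0.0
    return weight

def haste_style_loss(diffusion_loss, feature_align_loss, attention_align_loss,
                     step, stop_step, weight=1.0):
    # Denoising loss plus the holistic alignment terms (features and
    # attention maps), gated by the two-phase schedule.
    w = alignment_weight(step, stop_step, weight)
    return diffusion_loss + w * (feature_align_loss + attention_align_loss)
```

<p>After <code>stop_step</code> the objective reduces to plain denoising, so the teacher forward pass can be skipped entirely in Phase II.</p>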

<p>Several puzzling trends emerged in representation alignment that defied conventional understanding. Larger model variants within the same encoder family often led to similar or even worse generation performance despite higher ImageNet-1K accuracy: DINOv2’s larger variants showed diminishing returns, while PE and C-RADIO exhibited this counterintuitive pattern even more starkly <sup id="fnref:yu_representation_2025:2" role="doc-noteref"><a href="#fn:yu_representation_2025" class="footnote" rel="footnote">44</a></sup>. More strikingly, representations with dramatically higher global semantic understanding often underperformed: PE-Core-G (82.8% ImageNet accuracy) generated worse images than PE-Spatial-B (53.1% accuracy), and SAM2-S achieved strong generation performance despite only 24.1% ImageNet accuracy, roughly 60 percentage points below many competing encoders. Perhaps most revealing, controlled experiments showed that explicitly injecting global information through CLS token mixing improved linear probing accuracy from 70.7% to 78.5% while simultaneously degrading generation quality, with FID worsening from 19.2 to 25.4.</p>

<p>iREPA’s analysis in December 2025 <sup id="fnref:singh_what_2025" role="doc-noteref"><a href="#fn:singh_what_2025" class="footnote" rel="footnote">52</a></sup> resolved these contradictions by demonstrating that spatial structure (the self-similarity patterns between patch tokens), not global semantics, drives the effectiveness of representation alignment. To quantify this insight, the authors measured spatial self-similarity structure <sup id="fnref:shechtman2007matching" role="doc-noteref"><a href="#fn:shechtman2007matching" class="footnote" rel="footnote">53</a></sup> across patch tokens and performed a large-scale correlation analysis across 27 vision encoders and three model sizes. Spatial structure metrics exhibited remarkably strong correlation with generation FID (Pearson |r| &gt; 0.852 for metrics like Local Distance Similarity, Short-Range Spatial Similarity, Cosine Distance Similarity, and Relative Mean Spatial Contrast), far exceeding ImageNet-1K accuracy’s predictive power (|r| = 0.26).
This explained SAM2’s paradoxical success: despite poor classification accuracy, it maintained strong spatial structure that proved ideal for generation. The authors then made two small modifications to the REPA recipe: replacing the standard MLP projection with convolutional layers that preserve local spatial relationships, and adding spatial normalization to accentuate the transfer of relational structure. With these changes, iREPA (implemented in fewer than 4 lines of code) consistently improves convergence speed across diverse encoders, model sizes, and training recipes including REPA, REPA-E (more on this later), and MeanFlow (a few-step training method). This aligns with HASTE’s emphasis on attention distillation: the success lies in teaching coherent spatial organization rather than transferring high-level semantic concepts.</p>
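<p>To give a feel for why normalizing across the spatial dimension emphasizes relational structure, here is a small sketch in which each feature channel is standardized over the patches of one image. The precise normalization used by iREPA may differ; this form and the function name are assumptions for illustration.</p>

```python
import numpy as np

def spatial_normalize(patch_features, eps=1e-6):
    # patch_features: (num_patches, dim). Statistics are computed ACROSS
    # patches (axis 0), separately for each channel. Subtracting the
    # per-image mean removes the component shared by all patches, which is
    # where global (image-level) information tends to live.
    mean = patch_features.mean(axis=0, keepdims=True)
    std = patch_features.std(axis=0, keepdims=True)
    return (patch_features - mean) / (std + eps)
```

<p>Adding the same vector to every patch, a crude stand-in for a global semantic offset, leaves the normalized features unchanged, so what remains for the diffusion model to match is how patches differ from one another.</p>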

<p>These findings are quite intuitive when you think about what generative models need to do: they must model all spatial structure in detail, which is expensive to discover from scratch. In contrast, global semantic understanding is more relevant for classification than for pixel-level generation. A model that knows “this is a dog” but doesn’t understand the spatial relationships between patches will generate poorly, while a model with strong spatial coherence but weaker global semantics can still produce coherent images.</p>

<p>Another angle to think about why alignment helps is that the diffusion training objective is inherently high variance: at each iteration, we present a noisy input <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">x</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\mathbf{x}_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5944em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and ask the model to predict <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">x</mi><mn>0</mn></msub></mrow><annotation encoding="application/x-tex">\mathbf{x}_0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5944em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>, but the <em>optimal</em> prediction is not any single <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="bold">x</mi><mn>0</mn></msub></mrow><annotation encoding="application/x-tex">\mathbf{x}_0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5944em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathbf">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> seen during training—it’s the <em>expectation</em> <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi mathvariant="double-struck">E</mi><mo stretchy="false">[</mo><msub><mi mathvariant="bold">x</mi><mn>0</mn></msub><mo>∣</mo><msub><mi mathvariant="bold">x</mi><mi>t</mi></msub><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathbb">E</span><span class="mopen">[</span><span class="mord"><span class="mord mathbf">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∣</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathbf">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span> over all possible clean images consistent with that noisy input <sup id="fnref:dieleman2023geometry" role="doc-noteref"><a href="#fn:dieleman2023geometry" class="footnote" rel="footnote">54</a></sup>. At high noise levels, this expectation corresponds roughly to the mean of the entire dataset. The model must learn to implicitly average over many possible reconstructions, but it only ever sees individual samples as supervision. This mismatch between what we supervise (samples) and what we want (expectations) makes the objective noisy and slows representation learning.</p>
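<p>The sample-versus-expectation gap can be made explicit on a toy example. Assuming a finite dataset with uniform prior and a Gaussian corruption of the form x_t = alpha * x_0 + sigma * eps (this parameterization and the function name are assumptions for illustration), the optimal prediction is a softmax-weighted average of the data points:</p>

```python
import numpy as np

def posterior_mean(x_t, dataset, alpha, sigma):
    # dataset: (num_points, dim), treated as the full data distribution.
    # Posterior weight of each candidate x_0 is proportional to the
    # Gaussian likelihood of x_t under x_t = alpha * x_0 + sigma * eps.
    sq_dist = ((x_t[None, :] - alpha * dataset) ** 2).sum(axis=-1)
    logits = -sq_dist / (2.0 * sigma ** 2)
    logits = logits - logits.max()  # numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum()
    # E[x_0 | x_t] is the weighted average of all candidates.
    return weights @ dataset, weights
```

<p>In the pure-noise limit (alpha = 0) every data point is equally likely and the target is exactly the dataset mean, while at very low noise the weights collapse onto a single point. During training the network only ever regresses onto one sampled x_0 per noisy input, so it has to recover this average implicitly over many updates.</p>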

<p>Representation alignment methods like REPA address this by providing a low-variance auxiliary signal. The pretrained vision encoder already captures useful spatial and semantic structure through objectives that don’t suffer from this sample-vs-expectation mismatch. By aligning to these representations, we essentially provide the denoiser with a shortcut to the internal structure it needs for effective denoising, bypassing the slow process of discovering this structure through the high-variance diffusion objective alone.</p>

<p>In summary, Phase 1 establishes a clear pattern: using pretrained vision foundation models through representation alignment dramatically accelerates diffusion training. However, these methods operate at the level of intermediate diffusion features, leaving the VAE latent space unchanged. Phase 2 takes the logical next step: incorporating semantic structure into the latent space itself.</p>

<h2 id="phase-2-aligning-the-vae-latent-space-to-foundation-models">Phase 2: Aligning the VAE Latent Space to Foundation Models</h2>

<p>While Phase 1 aligned intermediate diffusion features, Phase 2 recognized that the latent space itself—the compressed VAE representation—could incorporate semantic structure from vision foundation models. This deeper integration addresses the fundamental trade-off between reconstruction quality and learnability of the latent distribution.</p>

<p><img src="/assets/img/blog/r4g/good_latent_space.png" alt="The optimization dilemma in latent diffusion" />
<em>Fig 9. LSGM (2021) aims for smooth trajectories in latent space by normalizing distributions. LDM (2022) emphasizes highly compressed latents for computational efficiency. EQ-VAE and VA-VAE tackle the trade-off: improving encoder equivariance and aligning latent encodings with pretrained models to create learnable high-dimensional spaces. Image kindly adapted from Arash Vahdat.</em></p>

<p>Standard LDM treats VAE and diffusion training as independent, with the VAE optimized solely for pixel reconstruction (and perceptual quality by auxiliary losses relying on discriminators or metrics like LPIPS <sup id="fnref:dieleman2025latents:1" role="doc-noteref"><a href="#fn:dieleman2025latents" class="footnote" rel="footnote">17</a></sup>). This pixel-focused objective produces latents encoding low-level details effectively but lacking semantic structure. Increasing latent dimensionality improves reconstruction but creates higher-dimensional, more complex spaces for diffusion to learn—an “optimization dilemma” where better reconstruction leads to harder generation.</p>

<p><img src="/assets/img/blog/r4g/LDM_to_VAVAE.png" alt="VA-VAE: Aligning VAE latents to VFM during tokenizer training" />
<em>Fig 10. VAE encoder trains with both reconstruction loss and alignment loss to frozen VFM features, creating latents that are both reconstructive and semantically meaningful for efficient diffusion model training.</em></p>

<p>VA-VAE <sup id="fnref:yao_reconstruction_2025:1" role="doc-noteref"><a href="#fn:yao_reconstruction_2025" class="footnote" rel="footnote">47</a></sup> directly tackles this: rather than relying solely on pixel reconstruction, it aligns the VAE’s latent space with pretrained vision foundation models during tokenizer training through an additional vision foundation model alignment (VF) loss:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>VA-VAE</mtext></msub><mo>=</mo><msub><mi mathvariant="script">L</mi><mtext>recon</mtext></msub><mo>+</mo><mi>β</mi><mo>⋅</mo><mtext>KL</mtext><mo>+</mo><msub><mi>λ</mi><mtext>align</mtext></msub><msub><mi mathvariant="script">L</mi><mtext>VF</mtext></msub><mi mathvariant="normal">.</mi></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{VA-VAE}} = \mathcal{L}_{\text{recon}} + \beta \cdot \text{KL} + \lambda_{\text{align}} \mathcal{L}_{\text{VF}}.</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">VA-VAE</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">recon</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7667em;vertical-align:-0.0833em;"></span><span class="mord text"><span class="mord">KL</span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal">λ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">align</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">VF</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">.</span></span></span></span></span>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>VF</mtext></msub><mo>=</mo><msub><mi>w</mi><mtext>hyper</mtext></msub><mtext> </mtext><msub><mi>w</mi><mtext>adaptive</mtext></msub><mo fence="false" stretchy="true" minsize="1.8em" maxsize="1.8em">(</mo><msub><mi mathvariant="script">L</mi><mtext>mcos</mtext></msub><mo>+</mo><msub><mi mathvariant="script">L</mi><mtext>mdms</mtext></msub><mo fence="false" stretchy="true" minsize="1.8em" maxsize="1.8em">)</mo><mi mathvariant="normal">.</mi></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{VF}} = w_{\text{hyper}} \, w_{\text{adaptive}} \Big( \mathcal{L}_{\text{mcos}} + \mathcal{L}_{\text{mdms}} \Big).</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">VF</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.8em;vertical-align:-0.65em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="msupsub"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0269em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">hyper</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0269em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">adaptive</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">mcos</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span 
class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.8em;vertical-align:-0.65em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">mdms</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mord">.</span></span></span></span></span>

<p>with</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>mcos</mtext></msub><mo>=</mo><mfrac><mn>1</mn><mrow><mi>h</mi><mi>w</mi></mrow></mfrac><munderover><mo>∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>h</mi></munderover><munderover><mo>∑</mo><mrow><mi>j</mi><mo>=</mo><mn>1</mn></mrow><mi>w</mi></munderover><mi mathvariant="normal">ReLU</mi><mo>⁡</mo><mtext> ⁣</mtext><mrow><mo fence="true">(</mo><mn>1</mn><mo>−</mo><msub><mi>m</mi><mn>1</mn></msub><mo>−</mo><mfrac><mrow><msubsup><mi>z</mi><mrow><mi>i</mi><mi>j</mi></mrow><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msubsup><mo>⋅</mo><msub><mi>f</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mrow><mrow><mo stretchy="false">∥</mo><msubsup><mi>z</mi><mrow><mi>i</mi><mi>j</mi></mrow><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msubsup><mo stretchy="false">∥</mo><mtext> </mtext><mo stretchy="false">∥</mo><msub><mi>f</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo stretchy="false">∥</mo></mrow></mfrac><mo fence="true">)</mo></mrow><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{mcos}} = \frac{1}{h w} \sum_{i=1}^{h} \sum_{j=1}^{w} \operatorname{ReLU} \!\left( 1 - m_{1} - \frac{z&#x27;_{ij} \cdot f_{ij}} {\lVert z&#x27;_{ij}\rVert \,\lVert f_{ij}\rVert} \right),</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span 
class="mord text mtight"><span class="mord mtight">mcos</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:3.2499em;vertical-align:-1.4138em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3214em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">h</span><span class="mord mathnormal" style="margin-right:0.02691em;">w</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8361em;"><span style="top:-1.8723em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol 
large-op">∑</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">h</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.2777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6514em;"><span style="top:-1.8723em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02691em;">w</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.4138em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop"><span class="mord mathrm">ReLU</span></span><span class="mspace" style="margin-right:-0.1667em;"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size4">(</span></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" 
style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5367em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mopen">∥</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7337em;"><span style="top:-2.4231em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">ij</span></span></span></span><span style="top:-3.0448em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.413em;"><span></span></span></span></span></span></span><span class="mclose">∥</span><span class="mspace" 
style="margin-right:0.1667em;"></span><span class="mopen">∥</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">ij</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mclose">∥</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7848em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7519em;"><span style="top:-2.4413em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">ij</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3948em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span 
class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">ij</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.099em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size4">)</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>mdms</mtext></msub><mo>=</mo><mfrac><mn>1</mn><msup><mi>N</mi><mn>2</mn></msup></mfrac><munder><mo>∑</mo><mrow><mi>i</mi><mo separator="true">,</mo><mi>j</mi></mrow></munder><mi mathvariant="normal">ReLU</mi><mo>⁡</mo><mrow><mo fence="true">(</mo><mrow><mo fence="true">∣</mo><mfrac><mrow><msub><mi>z</mi><mi>i</mi></msub><mo>⋅</mo><msub><mi>z</mi><mi>j</mi></msub></mrow><mrow><mo stretchy="false">∥</mo><msub><mi>z</mi><mi>i</mi></msub><mo stretchy="false">∥</mo><mtext> </mtext><mo stretchy="false">∥</mo><msub><mi>z</mi><mi>j</mi></msub><mo stretchy="false">∥</mo></mrow></mfrac><mo>−</mo><mfrac><mrow><msub><mi>f</mi><mi>i</mi></msub><mo>⋅</mo><msub><mi>f</mi><mi>j</mi></msub></mrow><mrow><mo stretchy="false">∥</mo><msub><mi>f</mi><mi>i</mi></msub><mo stretchy="false">∥</mo><mtext> </mtext><mo stretchy="false">∥</mo><msub><mi>f</mi><mi>j</mi></msub><mo stretchy="false">∥</mo></mrow></mfrac><mo fence="true">∣</mo></mrow><mo>−</mo><msub><mi>m</mi><mn>2</mn></msub><mo fence="true">)</mo></mrow><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{mdms}} = \frac{1}{N^{2}} \sum_{i,j} \operatorname{ReLU} \left( \left| \frac{z_i \cdot z_j}{\lVert z_i\rVert \,\lVert z_j\rVert} - \frac{f_i \cdot f_j}{\lVert f_i\rVert \,\lVert f_j\rVert} \right| - m_{2} \right),</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing 
reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">mdms</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:2.8758em;vertical-align:-1.4138em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3214em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7401em;"><span style="top:-2.989em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.05em;"><span 
style="top:-1.8723em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.4138em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop"><span class="mord mathrm">ReLU</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size3">(</span></span><span class="minner"><span class="mopen"><span class="delimsizing mult"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.462em;"><span style="top:-2.266em;"><span class="pstrut" style="height:3.216em;"></span><span class="delimsizinginner delim-size1"><span>∣</span></span></span><span style="top:-2.864em;"><span class="pstrut" style="height:3.216em;"></span><span style="height:1.216em;width:0.3333em;"><svg xmlns="http://www.w3.org/2000/svg" width='0.3333em' height='1.216em' style='width:0.3333em' viewBox='0 0 333.33000000000004 1216' preserveAspectRatio='xMinYMin'><path d='M145 0 H188 V1216 H145z M145 0 H188 V1216 H145z'/></svg></span></span><span style="top:-4.072em;"><span class="pstrut" style="height:3.216em;"></span><span class="delimsizinginner delim-size1"><span>∣</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.95em;"><span></span></span></span></span></span></span><span class="mord"><span class="mopen nulldelimiter"></span><span 
class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.1215em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mopen">∥</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">∥</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">∥</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mclose">∥</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.9721em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3714em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mopen">∥</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">∥</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">∥</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mclose">∥</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.9721em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose"><span class="delimsizing mult"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.462em;"><span style="top:-2.266em;"><span class="pstrut" style="height:3.216em;"></span><span class="delimsizinginner delim-size1"><span>∣</span></span></span><span style="top:-2.864em;"><span class="pstrut" style="height:3.216em;"></span><span style="height:1.216em;width:0.3333em;"><svg xmlns="http://www.w3.org/2000/svg" width='0.3333em' height='1.216em' style='width:0.3333em' viewBox='0 0 333.33000000000004 1216' preserveAspectRatio='xMinYMin'><path d='M145 0 H188 V1216 H145z M145 0 H188 V1216 H145z'/></svg></span></span><span style="top:-4.072em;"><span class="pstrut" style="height:3.216em;"></span><span class="delimsizinginner delim-size1"><span>∣</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.95em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size3">)</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span>

<p>where <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span></span></span></span> denotes latent vectors from the VAE latent feature map and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>f</mi></mrow><annotation encoding="application/x-tex">f</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.10764em;">f</span></span></span></span> denotes vectors from the frozen vision foundation feature map for the same image; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>Z</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo>=</mo><mi>W</mi><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z&#x27; = WZ</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7519em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">Z</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7519em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span 
class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="mord mathnormal" style="margin-right:0.07153em;">Z</span></span></span></span> linearly projects VAE latents to the VFM feature dimension; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>z</mi><mrow><mi>i</mi><mi>j</mi></mrow><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msubsup></mrow><annotation encoding="application/x-tex">z&#x27;_{ij}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1467em;vertical-align:-0.3948em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7519em;"><span style="top:-2.4413em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">ij</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3948em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mrow><annotation 
encoding="application/x-tex">f_{ij}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">ij</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> are the projected latent and VFM feature at spatial position <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mi>i</mi><mo separator="true">,</mo><mi>j</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(i,j)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">i</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.05724em;">j</span><span class="mclose">)</span></span></span></span>; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>z</mi><mi>i</mi></msub><mo separator="true">,</mo><msub><mi>f</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">z_i, f_i</annotation></semantics></math></span><span class="katex-html" 
aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> are vectors at position <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span> after flattening the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>h</mi><mo>×</mo><mi>w</mi></mrow><annotation 
encoding="application/x-tex">h \times w</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">h</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.02691em;">w</span></span></span></span> grid into <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mi>h</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">N = h w</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">h</span><span class="mord mathnormal" style="margin-right:0.02691em;">w</span></span></span></span> tokens; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>m</mi><mn>1</mn></msub></mrow><annotation encoding="application/x-tex">m_1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>m</mi><mn>2</mn></msub></mrow><annotation encoding="application/x-tex">m_2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> are cosine-similarity margins; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>mcos</mtext></msub></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{mcos}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">mcos</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> enforces pointwise alignment, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>mdms</mtext></msub></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{mdms}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">mdms</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> aligns pairwise relational structure; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>w</mi><mtext>adaptive</mtext></msub><mo>=</mo><mo stretchy="false">∥</mo><mi mathvariant="normal">∇</mi><msub><mi mathvariant="script">L</mi><mtext>recon</mtext></msub><mo stretchy="false">∥</mo><mi mathvariant="normal">/</mi><mo stretchy="false">∥</mo><mi mathvariant="normal">∇</mi><msub><mi mathvariant="script">L</mi><mtext>VF,raw</mtext></msub><mo 
stretchy="false">∥</mo></mrow><annotation encoding="application/x-tex">w_{\text{adaptive}} = \lVert\nabla \mathcal{L}_{\text{recon}}\rVert / \lVert\nabla \mathcal{L}_{\text{VF,raw}}\rVert</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0269em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">adaptive</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em;"></span><span class="mopen">∥</span><span class="mord">∇</span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">recon</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">∥</span><span 
class="mord">/</span><span class="mopen">∥</span><span class="mord">∇</span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">VF,raw</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mclose">∥</span></span></span></span> rescales VF gradients to match the reconstruction loss, and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>w</mi><mtext>hyper</mtext></msub></mrow><annotation encoding="application/x-tex">w_{\text{hyper}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0269em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">hyper</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> is a user-set scalar (e.g., <span class="katex"><span class="katex-mathml"><math 
xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>0.1</mn></mrow><annotation encoding="application/x-tex">0.1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0.1</span></span></span></span>) controlling the overall VF strength.</p>
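<p>To make the two alignment terms concrete, here is a minimal NumPy sketch. It assumes hinge-style margins (a ReLU around each term), which is how the margin definitions above are read here; the function names and the small <code>eps</code> constant are illustrative and not taken from the original implementation.</p>

```python
import numpy as np


def cosine(a, b, eps=1e-8):
    """Cosine similarity along the last axis."""
    num = np.sum(a * b, axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + eps
    return num / den


def mcos_loss(z_proj, f, m1):
    """Pointwise term: penalise positions whose similarity falls below 1 - m1.

    z_proj, f: arrays of shape (h, w, d) holding z'_ij and f_ij.
    """
    sim = cosine(z_proj, f)  # shape (h, w)
    return np.mean(np.maximum(0.0, 1.0 - m1 - sim))


def mdms_loss(z, f, m2, eps=1e-8):
    """Relational term: match pairwise cosine-similarity matrices up to margin m2.

    z: (N, d_z) flattened latents, f: (N, d_f) flattened VFM features.
    The feature dimensions may differ because only similarities are compared.
    """
    zn = z / (np.linalg.norm(z, axis=-1, keepdims=True) + eps)
    fn = f / (np.linalg.norm(f, axis=-1, keepdims=True) + eps)
    gap = np.abs(zn @ zn.T - fn @ fn.T)  # shape (N, N)
    return np.mean(np.maximum(0.0, gap - m2))
```

<p>Note that the relational term never compares a latent with a VFM feature directly, so it needs no projection: only the two similarity matrices have to agree, up to the margin.</p>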

<p>The VF loss encourages both point-by-point alignment (individual latent vectors close to VFM features) and relative alignment (relationships between latents match relationships between features), using adaptive weighting similar in spirit to loss balancing in GANs <sup id="fnref:goodfellow2014generative" role="doc-noteref"><a href="#fn:goodfellow2014generative" class="footnote" rel="footnote">55</a></sup>. This yields semantically organized high-dimensional latent spaces that retain reconstruction quality while being more learnable for downstream generative models.</p>
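<p>The adaptive weighting can be sketched in a few lines. This is an illustration under assumptions: the gradients are passed in as plain arrays (which parameters they are taken with respect to is not specified here), the weight is treated as a constant when combining the losses, and <code>vf_loss_weight</code> and <code>total_loss</code> are hypothetical names.</p>

```python
import numpy as np


def vf_loss_weight(grad_recon, grad_vf, w_hyper=0.1, eps=1e-8):
    """Return w_hyper * w_adaptive, where w_adaptive = ||grad_recon|| / ||grad_vf||."""
    w_adaptive = np.linalg.norm(grad_recon) / (np.linalg.norm(grad_vf) + eps)
    return w_hyper * w_adaptive


def total_loss(l_recon, l_vf_raw, grad_recon, grad_vf, w_hyper=0.1):
    """Reconstruction loss plus the rescaled VF loss (weight treated as a constant)."""
    return l_recon + vf_loss_weight(grad_recon, grad_vf, w_hyper) * l_vf_raw
```

<p>With this choice the rescaled VF gradient has roughly <code>w_hyper</code> times the norm of the reconstruction gradient, so neither loss drowns out the other regardless of their raw scales.</p>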

<p>By semantically structuring the VAE latent space, this alignment reduces the diffusion model’s burden: the model can focus on learning the data distribution rather than also having to discover semantic organization. The latent space provides an appropriate inductive bias, with semantic structure “baked in” through VFM alignment while pixel details are captured through reconstruction.</p>

<p><img src="/assets/img/blog/r4g/REPA_to_REPAE.png" alt="REPA-E: End-to-end joint VAE+diffusion training" /></p>

<p><em>Fig 11. Left shows stage-wise training with frozen VAE. Right shows REPA-E with careful gradient flow: alignment loss flows to both components, diffusion loss uses stop-gradient on VAE encoder, and VAE receives alignment gradients through BatchNorm for latent normalization.</em></p>

<p>With diffusion-model alignment and VAE alignment each demonstrated separately, REPA-E <sup id="fnref:leng_repa-e_2025:1" role="doc-noteref"><a href="#fn:leng_repa-e_2025" class="footnote" rel="footnote">46</a></sup> takes integration further through joint VAE+diffusion training, challenging the convention that these components should be trained separately. It demonstrates that while naive end-to-end training with the diffusion loss alone is ineffective (causing latent space collapse), representation alignment provides the constraints needed for successful joint optimization. The key innovation is careful gradient control. The alignment loss flows to both the VAE and the diffusion model, but the diffusion loss uses a stop-gradient on the VAE encoder to prevent collapse (the VAE shouldn’t change to make diffusion easier at reconstruction’s cost). In addition, to keep the latent space normalised, the VAE receives alignment gradients only through a BatchNorm layer. This enables joint optimization: the VAE improves to produce latents that are both reconstructive <em>and</em> well-aligned, while the diffusion model learns in this evolving but stable space. Joint optimization improves the VAE itself, leading to better latent structure (higher VFM alignment, better class separation) and downstream performance. While LSGM <sup id="fnref:vahdat_score-based_2021:1" role="doc-noteref"><a href="#fn:vahdat_score-based_2021" class="footnote" rel="footnote">15</a></sup> required pretraining the VAE, true end-to-end training is now possible.</p>
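<p>The gradient routing can be illustrated with a deliberately tiny toy model in which the gradients are written out by hand. Everything here is an assumption made for illustration: the "VAE encoder" and the "diffusion model" are single scalar weights, the hidden state and the output of the diffusion model are collapsed into one quantity, and the reconstruction loss and BatchNorm layer are omitted. It is not the REPA-E implementation.</p>

```python
import numpy as np


def repa_e_toy_grads(x, f, target, enc_w, dit_w, lam=0.5):
    """Analytic gradients for a scalar toy version of the gradient routing.

    Encoder ("VAE"):            z = enc_w * x
    Denoiser ("diffusion"):     h = dit_w * z
    L_diff  = mean((dit_w * sg(z) - target)^2)   # sg = stop-gradient
    L_align = mean((dit_w * z - f)^2)
    """
    z = enc_w * x
    h = dit_w * z
    diff_err = h - target
    align_err = h - f

    # The diffusion model receives gradients from both losses.
    g_dit = np.mean(2.0 * diff_err * z) + lam * np.mean(2.0 * align_err * z)

    # The encoder only sees the alignment loss: sg(z) blocks the diffusion term.
    g_enc = lam * np.mean(2.0 * align_err * dit_w * x)

    # For comparison: the encoder gradient if the stop-gradient were removed.
    g_enc_naive = g_enc + np.mean(2.0 * diff_err * dit_w * x)
    return g_dit, g_enc, g_enc_naive
```

<p>When the features are already perfectly aligned, the encoder gradient is zero even if the diffusion loss is large, whereas the naive variant would keep pushing the encoder towards whatever makes denoising easiest.</p>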

<p>3-Stage-Aligner <sup id="fnref:chen_aligning_2025" role="doc-noteref"><a href="#fn:chen_aligning_2025" class="footnote" rel="footnote">56</a></sup> proposes an alternative strategy: rather than training the VAE from scratch with alignment, freeze a pretrained encoder (e.g., DINOv2), map its features into a low-dimensional space via an adapter block, and learn to align a decoder through three stages. Stage 1 (Latent Alignment) freezes the VFM encoder and trains the adapter plus decoder, establishing a semantic latent space with basic reconstruction capabilities. The resulting latents are semantically grounded but exhibit color shifts and missing fine-grained details, since the frozen encoder was not trained for reconstruction. Stage 2 (Perceptual Alignment) jointly optimizes the adapter and the now-unfrozen encoder with a semantic preservation loss that keeps the encoder aligned with the original frozen VFM features:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi mathvariant="script">L</mi><mtext>Stage 2</mtext></msub><mo>=</mo><msub><mi mathvariant="script">L</mi><mtext>recon</mtext></msub><mo>+</mo><msub><mi>λ</mi><mn>2</mn></msub><mo stretchy="false">∥</mo><mtext>Enc</mtext><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><msub><mtext>Enc</mtext><mtext>frozen</mtext></msub><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><msup><mo stretchy="false">∥</mo><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">\mathcal{L}_{\text{Stage 2}} = \mathcal{L}_{\text{recon}} + \lambda_2 \lVert\text{Enc}(x) - \text{Enc}_{\text{frozen}}(x) \rVert^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9694em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">Stage 2</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">recon</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">λ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">∥</span><span class="mord text"><span class="mord">Enc</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.1141em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord text"><span class="mord">Enc</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing 
reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">frozen</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose"><span class="mclose">∥</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8641em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span>

<p>The L2 loss prevents encoder drift from the pretrained semantic structure while allowing capture of fine-grained color and texture. Stage 3 (Decoder Refinement) freezes both encoder and adapter, allowing the decoder to better exploit the latent representation changed during Stage 2 without disturbing semantic structure.</p>
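<p>As a rough sketch, the Stage 2 objective can be written as follows. It assumes a plain mean-squared error as a stand-in for the reconstruction loss and averages the squared L2 drift over tokens; the function name, argument names and default weight are illustrative.</p>

```python
import numpy as np


def stage2_loss(x, x_recon, enc_feats, frozen_feats, lambda2=1.0):
    """L_recon + lambda2 * ||Enc(x) - Enc_frozen(x)||^2 for one batch.

    x, x_recon:   images and their reconstructions (any matching shape).
    enc_feats:    features of the trainable encoder, shape (..., d).
    frozen_feats: features of the frozen copy of the encoder, same shape.
    """
    # Assumption: plain MSE stands in for the reconstruction loss.
    recon = np.mean((x - x_recon) ** 2)
    # Squared L2 distance per token, averaged over all tokens.
    drift = np.mean(np.sum((enc_feats - frozen_feats) ** 2, axis=-1))
    return recon + lambda2 * drift
```

<p>The second term is zero exactly when the trainable encoder still reproduces the frozen features, so <code>lambda2</code> sets how much drift is tolerated in exchange for better reconstruction.</p>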

<p><img src="/assets/img/blog/r4g/VFMVAE_to_3stagealigner.png" alt="3stage-aligner: Three-stage frozen encoder alignment" />
<em>Fig 12. Stage 1 establishes semantic grounding with frozen encoder. Stage 2 allows encoder refinement with semantic preservation loss. Stage 3 optimizes decoder for reconstruction quality, carefully balancing semantic preservation with fine-grained detail capture.</em></p>

<p>This yields semantically rich tokenizers whose latent space inherits discriminative structure from the pretrained encoder. The three-stage process balances semantic preservation (maintaining VFM structure) with reconstruction quality (capturing fine-grained details), avoiding both the color shifts of purely frozen encoders and the semantic drift of fully unconstrained fine-tuning. Note the contrast with REPA-E’s end-to-end approach: 3-Stage-Aligner returns to explicit staged training rather than joint optimization. This reflects a broader pattern in Phase 2: there is a zoo of different methods (end-to-end vs. staged, joint vs. separate alignment), and it remains unclear which approach is best, as many of these concurrent works don’t compare against each other under identical conditions.</p>

<p>Phase 2 establishes that incorporating semantic structure directly into the VAE latent space, whether through an alignment loss during training (VA-VAE), end-to-end joint optimization (REPA-E), or staged adaptation of frozen encoders (3-Stage-Aligner), produces superior results compared to standard pixel-focused VAE training. These semantically structured spaces are easier to learn (faster convergence) and produce better final quality. However, they still rely on the two-stage VAE+diffusion pipeline, raising a natural question: do we need VAE compression at all?</p>

<h2 id="phase-3-operating-directly-in-vision-foundation-model-feature-spaces">Phase 3: Operating Directly in Vision Foundation Model Feature Spaces</h2>

<p>The third phase represents a more radical departure, questioning whether the VAE bottleneck is necessary at all. Instead of compressing images through a VAE and then aligning the latent space, these methods propose directly using pretrained vision foundation model features as the “latent space” for diffusion, or training autoencoders specifically to preserve discriminative information rather than minimize reconstruction error.</p>

<p>Based on the observation made in Perception Encoder <sup id="fnref:bolya_perception_2025" role="doc-noteref"><a href="#fn:bolya_perception_2025" class="footnote" rel="footnote">57</a></sup> that the best visual embeddings for downstream tasks are often not at the output of vision networks but rather in intermediate layers, VFM-VAE <sup id="fnref:bi_vision_2025:1" role="doc-noteref"><a href="#fn:bi_vision_2025" class="footnote" rel="footnote">43</a></sup> merges frozen VFM features from different parts of the network as latent representations. However, VFMs focus on semantic understanding, producing spatially coarse features (e.g., DINOv2 ViT-L outputs <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>16</mn><mo>×</mo><mn>16</mn></mrow><annotation encoding="application/x-tex">16 \times 16</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">16</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">16</span></span></span></span> for <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>256</mn><mo>×</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">256 \times 256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">256</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span 
class="mord">256</span></span></span></span> images), sacrificing pixel fidelity. To compensate, VFM-VAE redesigns the decoder with multi-scale latent fusion (combining features from multiple VFM layers, providing both semantic guidance from deep layers and spatial detail from shallow layers) and progressive resolution reconstruction (building up resolution gradually through decoder blocks, starting from coarse VFM features and progressively adding detail). In addition, the embedding dimensionality of VFMs is often too high for effective generative modelling; VFM-VAE circumvents this by mapping the different embeddings into a compressed latent space that is regularised via a KL divergence, so the model still contains a VAE, but one strongly initialised by a VFM.</p>
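<p>The compression step can be sketched as follows. This is a simplified illustration, not the VFM-VAE architecture: it assumes the features from different layers are fused by concatenation and mapped by two linear heads to the mean and log-variance of a diagonal Gaussian, which is then regularised towards a standard normal. All names are hypothetical.</p>

```python
import numpy as np


def fuse_and_compress(layer_feats, w_mu, w_logvar, rng):
    """Fuse multi-layer VFM features and map them to a KL-regularised latent.

    layer_feats: list of (N, d_l) arrays, one per VFM layer.
    w_mu, w_logvar: (sum of d_l, d_latent) projection matrices.
    Returns a sampled latent of shape (N, d_latent) and the mean KL term.
    """
    fused = np.concatenate(layer_feats, axis=-1)  # (N, sum of d_l)
    mu = fused @ w_mu
    logvar = fused @ w_logvar
    # Reparameterisation trick: z = mu + sigma * noise.
    z = mu + np.exp(0.5 * logvar) * rng.standard_normal(mu.shape)
    # KL(N(mu, sigma^2) || N(0, I)) per token, averaged over tokens.
    kl = 0.5 * np.mean(np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar, axis=-1))
    return z, kl
```

<p>Here <code>d_latent</code> would be chosen much smaller than the fused feature dimension, which is what makes the latent tractable for the downstream diffusion model.</p>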

<p><img src="/assets/img/blog/r4g/VAVAE_to_VFMVAE.png" alt="VFM-VAE: Leveraging multiple VFM encodings as compressed latents" /></p>

<p><em>Fig 13. In VFM-VAE, multiple VFM encodings are compressed into a single latent representation that is then projected out to pixel space via multi-scale decoders. The right side shows that this latent space is more robust to geometric perturbations and achieves strong reconstruction as well as generation.</em></p>

<p>This enables high-quality reconstruction from semantically rich but spatially compact representations. The work also introduces the SE-CKNNA metric for diagnosing representation dynamics during diffusion training. SE-CKNNA measures how well semantic structure in the latent space is preserved during noising, revealing that semantic structure degrades nonlinearly with noise level, with critical thresholds where class separability breaks down. Using these insights, the authors develop a joint tokenizer-diffusion alignment strategy that dramatically accelerates convergence. The frozen pretrained encoder keeps the latent space semantically aligned even under distribution shifts: Phase 2 methods that fine-tune encoders risk semantic drift, whereas VFM-VAE’s frozen encoder ensures consistent structure. However, this requires architectural innovations (multi-scale fusion, progressive reconstruction) to overcome the reconstruction challenges of coarse frozen features, which hinders easy adoption.</p>

<p>SVG <sup id="fnref:shi_latent_2025" role="doc-noteref"><a href="#fn:shi_latent_2025" class="footnote" rel="footnote">58</a></sup> tries to avoid these complex architectural modifications by first analyzing why VAE latent spaces are problematic: they lack clear semantic separation and strong discriminative structure. Standard VAE latents exhibit semantic entanglement (different classes overlap) and poor class compactness (same-class samples are widely dispersed). This makes the distribution difficult for diffusion to learn, as the model must simultaneously discover semantic structure and model fine-grained variation. To overcome this, SVG constructs latent representations from frozen DINO features, which provide semantically discriminative structure with clear class separation, augmented with a lightweight residual branch that captures fine-grained details:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>z</mi><mtext>final</mtext></msub><mo>=</mo><msub><mi>z</mi><mtext>DINO</mtext></msub><mo>+</mo><mi>α</mi><mo>⋅</mo><msub><mi>z</mi><mtext>residual</mtext></msub></mrow><annotation encoding="application/x-tex">z_{\text{final}} = z_{\text{DINO}} + \alpha \cdot z_{\text{residual}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">final</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord 
mtight">DINO</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.4445em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.044em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">residual</span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>

<p>where the frozen DINOv3 encoder provides semantics and a learned residual encoder captures color, texture, and other details that DINO discards. Whereas standard VAE latents are semantically entangled, building on VFM features yields clearer class separation and more compact classes, while the residual encoder proves important for fine-grained color details. No diffusion model tricks are needed because the latent space of the chosen VFM, DINOv3, is small enough (384-dimensional) to be modelled without compression. However, aligning the residual features with the frozen features is crucial: without this alignment loss, the decoder over-relies on the residual encoder, and numerical range differences between normalized frozen DINOv3 features and unnormalized learned residuals can distort semantic embeddings.</p>
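<p>To make the composition above concrete, here is a minimal NumPy sketch. It is not the SVG implementation: the token count (256), the value of α, and the per-channel standardisation of the residual are assumptions chosen only to illustrate why mismatched numerical ranges matter.</p>

```python
import numpy as np

def compose_latent(z_dino, z_residual, alpha=0.1, eps=1e-6):
    """Combine frozen semantic features with a learned residual: z_dino + alpha * z_residual.

    Both inputs have shape (num_tokens, dim). The residual is standardised
    per channel first so that its numerical range does not swamp the
    normalised semantic features (an illustrative choice, not SVG's recipe).
    """
    mean = z_residual.mean(axis=0, keepdims=True)
    std = z_residual.std(axis=0, keepdims=True)
    z_residual_norm = (z_residual - mean) / (std + eps)
    return z_dino + alpha * z_residual_norm

rng = np.random.default_rng(0)
z_dino = rng.standard_normal((256, 384))             # stand-in for frozen encoder tokens
z_residual = 50.0 * rng.standard_normal((256, 384))  # unnormalised residual with a large range

z_final = compose_latent(z_dino, z_residual, alpha=0.1)
# Relative size of the residual contribution compared with the semantic part
ratio = np.linalg.norm(z_final - z_dino) / np.linalg.norm(z_dino)
print(f"residual / semantic norm ratio: {ratio:.3f}")
```

<p>Without the standardisation step, the residual in this toy example would be roughly fifty times larger than the semantic features, which is the kind of range mismatch the paragraph above warns about.</p>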

<p><img src="/assets/img/blog/r4g/VAVAE_to_SVG.png" alt="SVG: Using frozen DINO with residual encoder" />
<em>Fig 14. Left shows VA-VAE with learned encoder aligned to VFM. Right shows SVG with frozen DINO encoder plus lightweight residual encoder capturing fine-grained details, enabling clearer semantic separation without VAE training.</em></p>

<p>While SVG emphasises the need for a modest embedding space dimensionality and the need for a residual encoder that makes up for missing pixel-level details in the VFM embeddings, RAE <sup id="fnref:zheng_diffusion_2025" role="doc-noteref"><a href="#fn:zheng_diffusion_2025" class="footnote" rel="footnote">59</a></sup> tries to replace the VAE solely with pretrained representation encoders paired with trained decoders, without additional compression or auxiliary encoders. The authors systematically explore encoders from diverse self-supervised methods (DINO, SigLIP, MAE) and analyze challenges of operating diffusion transformers in resulting high-dimensional spaces. While standard VAE latents are low-dimensional (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>32</mn><mo>×</mo><mn>32</mn><mo>×</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">32 \times 32 \times 4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">32</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">32</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span>, or 4K dimensions), representation encoder outputs are much higher (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>16</mn><mo>×</mo><mn>16</mn><mo>×</mo><mn>1024</mn></mrow><annotation encoding="application/x-tex">16 
\times 16 \times 1024</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">16</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">16</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">1024</span></span></span></span> for DINOv2 ViT-L, or 262K dimensions). This poses challenges for diffusion transformers that generally perform poorly in such high-dimensional spaces.</p>
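<p>The dimension counts quoted above are easy to verify. The short sketch below only reproduces the arithmetic from the text; the helper name is hypothetical.</p>

```python
vae_shape = (32, 32, 4)      # height, width, channels of a standard VAE latent
rae_shape = (16, 16, 1024)   # DINOv2 ViT-L token grid and channel width

def total_dims(shape):
    """Number of scalar values in a latent of the given shape."""
    h, w, c = shape
    return h * w * c

vae_dims = total_dims(vae_shape)   # 4096, the "4K" in the text
rae_dims = total_dims(rae_shape)   # 262144, the "262K" in the text
print(vae_dims, rae_dims, rae_dims // vae_dims)  # prints: 4096 262144 64
```

<p>The representation latent is 64 times larger overall, and each token is 256 times wider (1024 versus 4 channels), which is why token width rather than sequence length becomes the bottleneck.</p>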

<p>RAE identifies three sources of difficulty and addresses each with a theoretically motivated solution. First, a standard DiT processes all tokens at the same hidden width, so when input tokens have a higher dimensionality than that width, the model becomes an information bottleneck. RAE introduces a wide DDT head that maintains high-dimensional representations through a final shallow-but-wide layer while keeping the majority of the DiT blocks lower-dimensional. Second, standard noise schedules are designed around spatial dimensions and assume certain statistical properties of the latents. Representation encoder outputs have different characteristics (already normalized, different variance structure), so RAE makes the noise schedule depend on actual data statistics rather than assuming fixed properties. Third, since the decoder is trained separately from the frozen encoder, a mismatch can occur at inference: the diffusion model produces slightly imperfect samples, but the decoder was trained on clean representations. Following TarFlow <sup id="fnref:zhai2024normalizing" role="doc-noteref"><a href="#fn:zhai2024normalizing" class="footnote" rel="footnote">60</a></sup>, RAE adds noise augmentation during decoder training for robustness to imperfect samples.</p>
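<p>The third point, noise-augmented decoder training, can be sketched in a few lines. This is an illustration under stated assumptions rather than RAE's actual recipe: the uniform noise range, the <code>max_sigma</code> value, and the tensor shapes are hypothetical, and the decoder itself is omitted.</p>

```python
import numpy as np

def noise_augment(z, rng, max_sigma=0.5):
    """Perturb clean latents before they are passed to the decoder.

    z has shape (batch, tokens, dim). One noise level is drawn per sample,
    so the decoder sees a range of corruption strengths during training.
    The uniform range [0, max_sigma] is a hypothetical choice.
    """
    sigma = rng.uniform(0.0, max_sigma, size=(z.shape[0], 1, 1))
    return z + sigma * rng.standard_normal(z.shape), sigma

rng = np.random.default_rng(0)
z_clean = rng.standard_normal((4, 256, 1024))   # stand-in for frozen encoder outputs
z_noisy, sigma = noise_augment(z_clean, rng)

# A decoder would then be trained to map z_noisy back to the original image,
# e.g. loss = reconstruction_loss(decoder(z_noisy), image) with hypothetical names.
per_sample_rms = np.sqrt(((z_noisy - z_clean) ** 2).mean(axis=(1, 2)))
print(np.round(per_sample_rms / sigma[:, 0, 0], 2))  # each entry is close to 1
```

<p>Because the decoder never sees only perfectly clean latents, small errors in the diffusion model's samples at inference fall inside the range of inputs it was trained on.</p>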

<p>RAE demonstrated that high-quality reconstruction from frozen encoders with strong representations, such as DINO, is possible. Computational overhead is minimal since DiT cost depends mostly on sequence length rather than token dimension (the latter is handled by the wide head). The DiT adjustments are necessary, though: scaling width to the token dimension, making the noise schedule depend on the data rather than on spatial dimensions, and using noise-augmented decoding because the decoder is trained only on a discrete set of clean latents. An additional benefit of RAE is that high-resolution synthesis is enabled simply by swapping in decoders with different patch sizes, while the frozen encoder and trained diffusion model remain unchanged.</p>

<p><img src="/assets/img/blog/r4g/SVG_to_RAE.png" alt="RAE: Comprehensive framework for representation autoencoders" />
<em>Fig 15. RAE systematically explores different pretrained encoders (DINO, SigLIP, MAE) as frozen latent encoders, with DiT adjustments (wide DDT head, data-dependent noise schedule, noise-augmented decoding) enabling effective diffusion in high-dimensional representation spaces.</em></p>

<p>However, while RAE allowed the direct use of pretrained VFMs as encoders, it has two main limitations:</p>
<ol>
  <li>The modifications of the diffusion model required to make this work were substantial.</li>
  <li>There is no emphasis whatsoever on reconstruction, limiting editing capabilities of these models and making them potentially vulnerable to drifting off the data manifold.</li>
</ol>

<p>FAE <sup id="fnref:gao_one_2025" role="doc-noteref"><a href="#fn:gao_one_2025" class="footnote" rel="footnote">61</a></sup> tackles the first of these limitations with a simple adaptation: a single attention layer that allows the use of standard LightningDiT recipes. By then training with both an image reconstruction objective and a feature preservation objective, FAE creates a truly unified representation that serves as both a generative latent space and a discriminative feature space. The translation layer (a single attention layer between the frozen encoder features and the generative decoder) provides a minimal but effective transformation. This allows standard diffusion models to be used again without the RAE modifications, demonstrating that the right architectural intervention can eliminate the need for extensive model adjustments. The translation layer also preserves the spatial structure of the latent space, which agrees with the iREPA insight that spatial structure is the main determinant of how much alignment improves generation quality <sup id="fnref:singh_what_2025:1" role="doc-noteref"><a href="#fn:singh_what_2025" class="footnote" rel="footnote">52</a></sup>. While these tricks seem useful for avoiding RAE modifications such as noise-augmented decoding and the wide DDT head, recent work suggests that those modifications are no longer necessary once RAE models are scaled up to sizes where the DiT width already exceeds the VFM embedding dimension <sup id="fnref:tong2026scalerae" role="doc-noteref"><a href="#fn:tong2026scalerae" class="footnote" rel="footnote">62</a></sup>.</p>

<p><img src="/assets/img/blog/r4g/RAE_to_FAE.png" alt="FAE: Streamlined unification of feature and generative spaces" />
<em>Fig 16. Unlike RAE, FAE introduces a lightweight “translation layer” (a single attention block) to align frozen pretrained encoder features with the generative decoder. This minimal intervention preserves spatial structure and discriminative power.</em></p>

<p>An alternative perspective on RAE’s convergence difficulties argues that the problem is fundamentally <em>geometric</em>, not architectural. Standard flow matching uses linear interpolation between noise and data, creating probability paths that cut straight through Euclidean space. When data lies on a curved manifold, such as the hypersphere on which the features of representation encoders like DINOv2 lie, these straight-line paths pass through low-density regions in the interior of the sphere rather than following its surface. This “Geometric Interference” causes standard diffusion transformers to fail on representation spaces.</p>

<p>RJF (Riemannian Flow Matching with Jacobi Regularization) <sup id="fnref:kumar2026rjf" role="doc-noteref"><a href="#fn:kumar2026rjf" class="footnote" rel="footnote">63</a></sup> addresses this by explicitly modifying the probability paths: when the manifold geometry is known (a hypersphere for DINO features), it replaces linear interpolation with geodesic interpolation (SLERP), constraining paths to follow the manifold surface using Riemannian flow matching <sup id="fnref:mathieu2020riemannian" role="doc-noteref"><a href="#fn:mathieu2020riemannian" class="footnote" rel="footnote">64</a></sup>. It additionally corrects for curvature-induced error propagation via Jacobi regularization <sup id="fnref:zaghen2025variational" role="doc-noteref"><a href="#fn:zaghen2025variational" class="footnote" rel="footnote">65</a></sup>. This enables standard DiT architectures to converge without width scaling: a DiT-B (131M parameters) achieves FID 3.37 where prior methods fail entirely. RJF is in spirit similar to CDC-FM (Carré du champ flow matching) <sup id="fnref:bamberger2025cdcfm" role="doc-noteref"><a href="#fn:bamberger2025cdcfm" class="footnote" rel="footnote">66</a></sup>, which also modifies probability paths to respect data geometry; the key difference is that RJF requires explicit knowledge of the manifold (enabling geodesic paths), while CDC-FM estimates local curvature from data via geometry-aware noise covariances, making it more general but less precise when the manifold structure is known.</p>
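<p>The difference between linear and geodesic interpolation is easy to see numerically. The sketch below assumes both endpoints have been normalised to unit length and uses the textbook SLERP formula; it illustrates the geometry only and does not include RJF's training objective or the Jacobi regularization. The feature dimension (768) is arbitrary.</p>

```python
import numpy as np

def lerp(x0, x1, t):
    """Straight-line interpolation used by standard flow matching."""
    return (1.0 - t) * x0 + t * x1

def slerp(x0, x1, t):
    """Geodesic interpolation between two unit vectors on the hypersphere.

    Assumes the endpoints are not antipodal, where the geodesic is not unique.
    """
    cos_omega = np.clip(np.dot(x0, x1), -1.0, 1.0)
    omega = np.arccos(cos_omega)
    if np.isclose(omega, 0.0):
        return x0.copy()  # endpoints coincide, nothing to interpolate
    return (np.sin((1.0 - t) * omega) * x0 + np.sin(t * omega) * x1) / np.sin(omega)

rng = np.random.default_rng(0)
x0 = rng.standard_normal(768)
x0 = x0 / np.linalg.norm(x0)   # "noise" endpoint projected onto the unit sphere
x1 = rng.standard_normal(768)
x1 = x1 / np.linalg.norm(x1)   # "data" endpoint on the unit sphere

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    lerp_norm = float(np.linalg.norm(lerp(x0, x1, t)))
    slerp_norm = float(np.linalg.norm(slerp(x0, x1, t)))
    print(t, round(lerp_norm, 3), round(slerp_norm, 3))
```

<p>The SLERP norm stays at 1 for every <code>t</code>, whereas the linear path dips well below 1 at intermediate times: for two nearly orthogonal unit vectors the midpoint has norm close to 0.71, so it sits deep inside the sphere where no data lives.</p>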

<p><img src="/assets/img/blog/r4g/rjf_cdc_fm.png" alt="RJF: Riemannian Flow Matching with Jacobi Regularization" />
<em>Fig 17. Both CDC-FM (left) and RJF (right) modify the probability path structure to follow the data manifold structure, with CDC-FM using a spatially varying, anisotropic Gaussian noise whose covariance captures local manifold geometry and RJF using Riemannian flow matching and Jacobi regularization.</em></p>

<p>While FAE and RJF tackled the architectural adaptation problem of RAE, PS-VAE <sup id="fnref:zhang2025psvae" role="doc-noteref"><a href="#fn:zhang2025psvae" class="footnote" rel="footnote">67</a></sup> tackles the editing problem, which arises because RAE does not explicitly encourage the latent space to support reconstruction. By training representation and pixel decoders in sequence and fine-tuning the pretrained representation encoder with reconstruction losses, the authors find a good balance between reconstruction and representation capabilities and show that this balance enables superior generation and editing.</p>

<p>Most recently, UAE <sup id="fnref:fan2025harmonizing" role="doc-noteref"><a href="#fn:fan2025harmonizing" class="footnote" rel="footnote">68</a></sup> offers a theoretical unification through its “Prism Hypothesis,” which posits that semantic and pixel representations correspond to different frequency bands of a shared spectrum. Unlike SVG, which adds a separate residual encoder, or RAE, which relies on a heavy decoder, UAE initializes its encoder from DINOv2 and uses a frequency-band modulator to disentangle the latent space. It explicitly aligns the low-frequency band to the semantic teacher while dedicating the high-frequency bands to residual details, harmonizing semantic abstraction with pixel fidelity in a single compact latent space. For related work on frequency-band analysis in autoencoders, see also the work on spectral autoencoders <sup id="fnref:falck2025spectral" role="doc-noteref"><a href="#fn:falck2025spectral" class="footnote" rel="footnote">69</a></sup> and the associated <a href="https://www.fabianfalck.com/posts/spectralauto/">blog post</a>.</p>
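<p>As a rough intuition for what "different frequency bands of a shared spectrum" means, the sketch below splits a 2D signal into a low-frequency and a high-frequency part with a plain FFT mask. This is not UAE's learned frequency-band modulator; the function name, the radial mask, and the cutoff value are assumptions made purely for illustration.</p>

```python
import numpy as np

def split_frequency_bands(image, cutoff):
    """Split a 2D array into low- and high-frequency parts with a radial FFT mask."""
    h, w = image.shape
    fy = np.fft.fftfreq(h)[:, None]      # vertical frequencies in cycles per pixel
    fx = np.fft.fftfreq(w)[None, :]      # horizontal frequencies in cycles per pixel
    radius = np.sqrt(fx ** 2 + fy ** 2)
    low_mask = np.less_equal(radius, cutoff)
    spectrum = np.fft.fft2(image)
    low = np.fft.ifft2(spectrum * low_mask).real
    high = np.fft.ifft2(spectrum * ~low_mask).real
    return low, high

rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:64, 0:64]
smooth = np.sin(2 * np.pi * xx / 64)            # one cycle across the image: coarse structure
detail = 0.1 * rng.standard_normal((64, 64))    # broadband fine detail
low, high = split_frequency_bands(smooth + detail, cutoff=0.05)
print(np.allclose(low + high, smooth + detail))  # prints: True
```

<p>The two bands sum back to the original signal, and the coarse structure ends up almost entirely in the low band. In the Prism Hypothesis picture, the low band plays the role of the semantic content and the high band that of the pixel-level residual.</p>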

<p>Phase 3 methods establish that VAE compression is not fundamental to high-quality latent diffusion. By directly using pretrained vision foundation model features as latent representations (with appropriate architectural modifications handling high-dimensionality, spatial coarseness, and reconstruction challenges), we achieve generation quality comparable to or exceeding VAE-based methods while maintaining discriminative power of the original pretrained encoder. However, all Phase 3 methods still rely on pretrained vision foundation models. Phase 4 takes the final step: questioning whether we need pretrained representations at all.</p>

<h2 id="phase-4-questioning-the-need-for-pretrained-representations">Phase 4: Questioning the Need for Pretrained Representations</h2>

<p>After three phases focused on progressively sophisticated ways to leverage pretrained models, Phase 4 represents a countertrend: can we achieve similar benefits by training from scratch with better objectives and architectures? This phase questions whether dependency on external pretrained models is fundamental or merely a workaround for suboptimal training procedures.</p>

<p>USP <sup id="fnref:chu_usp_2025" role="doc-noteref"><a href="#fn:chu_usp_2025" class="footnote" rel="footnote">70</a></sup> embodies this philosophy through fully end-to-end training jointly optimized for both generative and discriminative objectives. Rather than initializing from external representations, it employs a multi-task loss combining generation and discrimination such as contrastive learning, masked prediction, or classification. Generative and discriminative objectives complement one another: generative learning encourages modeling the full data distribution, while discriminative tasks promote the discovery of semantically meaningful structure. Joint optimization thus produces representations that are simultaneously generative (capable of synthesis) and discriminative (useful downstream), reducing the reliance on separate pretraining stages. This raises a critical question: does representation alignment solve deep architectural deficiencies, or does it merely accelerate learning? If the latter, the necessity of pretrained models could wane as compute, data, and training recipes continue to scale.</p>

<p>A similar spirit underlies large-scale systems such as FLUX2-VAE <sup id="fnref:noauthor_black_nodate" role="doc-noteref"><a href="#fn:noauthor_black_nodate" class="footnote" rel="footnote">71</a></sup>, which demonstrates that sophisticated tokenizers can be learned directly through end-to-end training rather than depending on pretrained vision foundation features. Although little is publicly known about its technical details, FLUX2-VAE’s production success suggests that with sufficient scale and engineering, high-quality tokenizers and representations can emerge organically from task training alone. Yet, “without pretrained representations” does not necessarily mean “cheap to train”: the total computational cost may rival or even exceed that of conventional pretraining pipelines. Whether the elegance of end-to-end architectures outweighs the modularity, interpretability, and reusability of pretrained components remains an open question.</p>

<p>The same shift is visible in the recent renaissance of pixel-space diffusion models, which challenge the long-held assumption that latent diffusion is a prerequisite for high-resolution, high-quality generation. Methods such as JiT (Just image Transformer) <sup id="fnref:li_back_2025" role="doc-noteref"><a href="#fn:li_back_2025" class="footnote" rel="footnote">72</a></sup>, PixelDiT <sup id="fnref:yu_pixeldit_2025" role="doc-noteref"><a href="#fn:yu_pixeldit_2025" class="footnote" rel="footnote">73</a></sup>, DeCo (frequency-DeCoupled diffusion) <sup id="fnref:ma_deco_2025" role="doc-noteref"><a href="#fn:ma_deco_2025" class="footnote" rel="footnote">74</a></sup>, DiP (Diffusion in Pixel space) <sup id="fnref:chen_dip_2025" role="doc-noteref"><a href="#fn:chen_dip_2025" class="footnote" rel="footnote">75</a></sup>, and SiD2 (Simpler Diffusion v2) <sup id="fnref:hoogeboomsimpler2025" role="doc-noteref"><a href="#fn:hoogeboomsimpler2025" class="footnote" rel="footnote">76</a></sup> illustrate a broader trend: architectural innovation can substitute for latent-space compression. By employing patch-based Transformers, efficient multi-scale attention, or frequency-aware loss designs, these models demonstrate that the efficiency, quality, and stability advantages traditionally attributed to latent spaces can also be achieved through direct pixel-space training.</p>

<p>EPG (End-to-End Pixel-Space Generative Pretraining) <sup id="fnref:lei_advancing_2025" role="doc-noteref"><a href="#fn:lei_advancing_2025" class="footnote" rel="footnote">77</a></sup> pushes this idea further by integrating representation learning into pixel-space diffusion itself. Rather than discarding the notion of learned structure, it reimagines representation pretraining as part of the diffusion process. EPG pretrains encoders through self-supervised objectives along deterministic diffusion trajectories, learning temporally consistent and semantically distinct features directly in pixel space. This pretraining endows the encoder with structured initialization analogous to pretrained vision models, but derived natively from the diffusion task. The result is a model that successfully trains consistency and diffusion systems from scratch, reportedly the first to achieve stable training of high-resolution consistency models without any pretrained VAEs or diffusion models. EPG leverages the dispersive loss <sup id="fnref:wang_diffuse_2025:1" role="doc-noteref"><a href="#fn:wang_diffuse_2025" class="footnote" rel="footnote">49</a></sup>, a simple plug-and-play regularizer that encourages diffusion model representations to disperse in the model’s intermediate feature space (analogous to contrastive learning) without requiring positive pairs, improving generation quality without interfering with the sampling process.</p>
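<p>The dispersive loss mentioned above can be written compactly. The sketch below follows my understanding of one InfoNCE-style variant (the log of the mean of exponentiated negative squared distances over all pairs in a batch); the temperature value and the inclusion of self-pairs are assumptions, so treat it as an illustration rather than the reference implementation.</p>

```python
import numpy as np

def dispersive_loss(z, tau=0.5):
    """InfoNCE-style repulsion without positive pairs.

    z has shape (batch, dim). The loss is the log of the mean of
    exp(-squared_distance / tau) over all pairs in the batch, so it
    decreases as the representations spread apart.
    """
    z = z.reshape(z.shape[0], -1)
    sq_norms = (z ** 2).sum(axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (z @ z.T)
    sq_dists = np.maximum(sq_dists, 0.0)  # guard against tiny negative values from rounding
    return np.log(np.mean(np.exp(-sq_dists / tau)))

rng = np.random.default_rng(0)
collapsed = 0.01 * rng.standard_normal((32, 16))   # nearly identical features
dispersed = 1.00 * rng.standard_normal((32, 16))   # well spread features
print(dispersive_loss(collapsed), dispersive_loss(dispersed))  # the first value is larger
```

<p>In training, a term like this would be added to the usual diffusion loss with a small weight and applied to an intermediate layer's activations. Since it only touches training, sampling is unchanged.</p>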

<p>However, it’s important to note that these pixel-space methods mostly tackle ImageNet-scale generation. At true production level (think FLUX, Sora, Veo, and similar systems), these are, to the best of my knowledge, all latent models. The field is moving toward video generation, which is so computationally expensive that compression remains essential. Scaling pixel-space methods to high-resolution, text-driven image or video generation at production quality remains to be demonstrated. Additionally, at production level, efficiency is critically important: serving models to millions of users requires fast generation and manageable inference costs, especially as video and world models become the next frontier. For some discussion of the latent vs pixel-space trade-off, see <a href="https://x.com/sedielem/status/1993826764674490734">the replies to this tweet by Sander Dieleman</a>.</p>

<h2 id="the-other-direction-generative-models-as-representations">The Other Direction: Generative Models as Representations</h2>

<p>The four phases above focused on one direction of unification: using pretrained representations to improve generation. But the relationship is bidirectional—generative modeling itself can serve as a powerful pretraining objective for learning representations useful in discriminative tasks. This “other direction” has gained significant momentum, suggesting that generation and representation learning may be two views of the same underlying process.</p>

<h3 id="from-pixel-prediction-to-embedding-prediction">From Pixel Prediction to Embedding Prediction</h3>

<p>MAE <sup id="fnref:he2022masked:2" role="doc-noteref"><a href="#fn:he2022masked" class="footnote" rel="footnote">8</a></sup> pioneered masked pixel reconstruction for vision, demonstrating that predicting masked image patches creates strong representations. However, the pixel-level reconstruction objective tends to focus on low-level details rather than high-level semantics. Could predicting embeddings instead of pixels yield better representations?</p>

<p>AIM v1 (Autoregressive Image Models) <sup id="fnref:el2024scalable" role="doc-noteref"><a href="#fn:el2024scalable" class="footnote" rel="footnote">78</a></sup> revisits autoregressive modeling for vision with modern architectures and large-scale data. Unlike earlier work such as iGPT <sup id="fnref:chen2020generative" role="doc-noteref"><a href="#fn:chen2020generative" class="footnote" rel="footnote">79</a></sup> or D-iGPT <sup id="fnref:ren2023rejuvenating" role="doc-noteref"><a href="#fn:ren2023rejuvenating" class="footnote" rel="footnote">80</a></sup>, AIM uses Vision Transformers and is trained on billions of images. The work demonstrates two key findings: (1) visual feature performance scales with both model capacity and data quantity, exhibiting scaling behaviour similar to that of large language models, and (2) the value of the autoregressive pretraining objective correlates with downstream performance, providing a meaningful training signal. AIM-7B achieves 84.0% top-1 accuracy on ImageNet-1k with a frozen trunk and shows particularly strong performance when trained on diverse, uncurated web data.</p>

<p>AIM v2 <sup id="fnref:fini2025multimodal" role="doc-noteref"><a href="#fn:fini2025multimodal" class="footnote" rel="footnote">81</a></sup> extends this to multimodal autoregressive models, demonstrating that the same autoregressive paradigm can be applied across images and text, creating unified representations that span modalities. NEPA (Next-Embedding Prediction) <sup id="fnref:xu_next-embedding_2025" role="doc-noteref"><a href="#fn:xu_next-embedding_2025" class="footnote" rel="footnote">82</a></sup> takes this further by predicting embeddings from pretrained models rather than raw pixels. By operating in a semantic embedding space, NEPA focuses on high-level features rather than low-level details, bridging generative objectives with the representation-focused methods discussed in earlier sections.</p>

<h3 id="diffusion-models-learn-representations-too">Diffusion Models Learn Representations Too</h3>

<p>The broader pattern is that generative objectives—whether autoregressive, masked, or diffusion-based—can serve dual purposes: they enable sampling of new examples and, as a byproduct, learn representations useful for discriminative tasks. Recent work on improving diffusion autoencoders <sup id="fnref:skorokhodov_improving_2025:2" role="doc-noteref"><a href="#fn:skorokhodov_improving_2025" class="footnote" rel="footnote">38</a></sup> and using masked autoencoders as tokenizers <sup id="fnref:chen_masked_2025:2" role="doc-noteref"><a href="#fn:chen_masked_2025" class="footnote" rel="footnote">39</a></sup> further blurs this line.</p>

<p>Several methods explicitly bridge generative and discriminative training within diffusion models. Robust representation consistency models <sup id="fnref:lei_robust_2025" role="doc-noteref"><a href="#fn:lei_robust_2025" class="footnote" rel="footnote">83</a></sup> use contrastive denoising to learn consistent representations along diffusion trajectories, improving both robustness and downstream performance. EPG <sup id="fnref:lei_advancing_2025:1" role="doc-noteref"><a href="#fn:lei_advancing_2025" class="footnote" rel="footnote">77</a></sup>, discussed in Phase 4, exemplifies this approach by pretraining encoders along diffusion trajectories to learn structured representations natively from the generative task.</p>

<p>These developments suggest that the distinction between “representation learning” and “generative modeling” may be more historical than fundamental—both aim to learn useful structure from data, just with different downstream applications in mind.</p>

<h2 id="representation-learning-and-alignment-in-molecular-machine-learning">Representation Learning and Alignment in Molecular Machine Learning</h2>

<p>The ideas from visual representation learning and generative modeling are beginning to influence molecular and protein modeling, suggesting broader applicability of these concepts beyond computer vision. Neural network potentials (NNPs), particularly MACE (Message Passing Atomic Cluster Expansion) <sup id="fnref:batatia2025foundation" role="doc-noteref"><a href="#fn:batatia2025foundation" class="footnote" rel="footnote">84</a></sup><sup id="fnref:bernstein2024gap" role="doc-noteref"><a href="#fn:bernstein2024gap" class="footnote" rel="footnote">85</a></sup>, have emerged as foundation models for atomistic chemistry that exhibit striking parallels to vision foundation models like DINO:</p>

<ol>
  <li>
    <p><strong>Embeddings transfer across tasks</strong>: MACE’s internal representations, learned for predicting quantum-mechanical energies and forces, generalise remarkably well to diverse downstream tasks. These embeddings can predict molecular properties far beyond the original training objective, allowing accurate property predictions not only in materials (the original domain) but also in small molecules <sup id="fnref:wedig2025rem3di" role="doc-noteref"><a href="#fn:wedig2025rem3di" class="footnote" rel="footnote">86</a></sup> and proteins <sup id="fnref:bojan2025representing" role="doc-noteref"><a href="#fn:bojan2025representing" class="footnote" rel="footnote">87</a></sup>.</p>
  </li>
  <li>
    <p><strong>Platonic convergence with scale</strong>: Just as vision models trained with different objectives converge toward similar representations as they scale, independently trained molecular models exhibit the same phenomenon. Work from MIT demonstrates that ostensibly different molecular models can be mapped into a common latent space with minimal performance loss <sup id="fnref:edamadaka2025universally" role="doc-noteref"><a href="#fn:edamadaka2025universally" class="footnote" rel="footnote">88</a></sup>, while complementary work from London shows that NNPs trained on large, diverse datasets discover comparable latent organizations <sup id="fnref:li2025platonic" role="doc-noteref"><a href="#fn:li2025platonic" class="footnote" rel="footnote">89</a></sup>—a molecular analogue of the Platonic Representation Hypothesis.</p>
  </li>
  <li>
    <p><strong>Representation alignment benefits generative models</strong>: MACE-REPA directly applies the Phase 1 alignment paradigm to molecular force fields <sup id="fnref:pinede2025unifying" role="doc-noteref"><a href="#fn:pinede2025unifying" class="footnote" rel="footnote">90</a></sup>. Instead of aligning diffusion features to DINO, it aligns force-field encoder representations to frozen MACE features using auxiliary losses. This demonstrates that the core insight—leveraging structured pretrained representations to accelerate training—transfers robustly from image diffusion to atomistic simulations.</p>
  </li>
</ol>
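<p>The alignment idea in the third item can be sketched as an auxiliary loss. The version below uses a negative mean cosine similarity between linearly projected student features and frozen teacher features, which is the general REPA-style recipe; the shapes, the linear projection, and all names are hypothetical and not taken from MACE-REPA.</p>

```python
import numpy as np

def alignment_loss(student, teacher, projection, eps=1e-8):
    """Negative mean cosine similarity between projected student and frozen teacher features.

    student: (num_atoms, d_student), teacher: (num_atoms, d_teacher),
    projection: (d_student, d_teacher) trainable matrix.
    """
    projected = student @ projection
    projected = projected / (np.linalg.norm(projected, axis=1, keepdims=True) + eps)
    target = teacher / (np.linalg.norm(teacher, axis=1, keepdims=True) + eps)
    return -np.mean(np.sum(projected * target, axis=1))

rng = np.random.default_rng(0)
teacher = rng.standard_normal((20, 32))        # stand-in for frozen foundation-model descriptors
projection = rng.standard_normal((48, 32))     # stand-in for a trainable projection head

student_random = rng.standard_normal((20, 48))            # unrelated features
student_aligned = teacher @ np.linalg.pinv(projection)    # features that project exactly onto the teacher

print(alignment_loss(student_random, teacher, projection))   # near 0
print(alignment_loss(student_aligned, teacher, projection))  # near -1
```

<p>During training, this term would be added to the main objective with a weighting coefficient, so the student is nudged toward the teacher's feature geometry while still being optimised for its own task.</p>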

<h3 id="molecular-embeddings-borrowing-from-nlp-and-computer-vision">Molecular Embeddings: Borrowing from NLP and Computer Vision</h3>

<p>The pattern of using foundation model embeddings for downstream prediction is not unique to vision or molecules—it directly parallels the NLP paradigm where LLM embeddings are aggregated (e.g., via mean pooling) and fed to prediction heads. In computer vision, DINO embeddings (particularly the [CLS] token) serve the same role. For molecules and proteins, MACE produces atomic descriptors that must be aggregated into molecular or residue-level representations before downstream prediction.</p>
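<p>The aggregation step described above can be as simple as a mean over the atoms that belong to each molecule or residue. The following sketch uses toy two-dimensional descriptors and hypothetical names; learned aggregators such as REM3DI replace the plain mean with something more expressive.</p>

```python
import numpy as np

def segment_mean_pool(atom_features, segment_ids, num_segments):
    """Average per-atom descriptors within each segment (a molecule or a residue)."""
    dim = atom_features.shape[1]
    sums = np.zeros((num_segments, dim))
    np.add.at(sums, segment_ids, atom_features)            # scatter-add atoms into their segment
    counts = np.bincount(segment_ids, minlength=num_segments)
    return sums / np.maximum(counts, 1)[:, None]           # avoid division by zero for empty segments

atom_features = np.array([[1.0, 0.0],
                          [3.0, 2.0],
                          [0.0, 4.0],
                          [2.0, 2.0],
                          [4.0, 6.0]])
residue_ids = np.array([0, 0, 1, 1, 1])   # atoms 0-1 belong to residue 0, atoms 2-4 to residue 1

pooled = segment_mean_pool(atom_features, residue_ids, num_segments=2)
print(pooled)  # residue 0 -> [2. 1.], residue 1 -> [2. 4.]
```

<p>The pooled vectors then play the same role as a mean-pooled LLM embedding or a DINO [CLS] token: a fixed-size input to a task-specific prediction head.</p>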

<p><img src="/assets/img/blog/r4g/molecular_embeddings.png" alt="Foundation model embeddings across domains" />
<em>Fig 17. The embedding paradigm across domains: foundation models (LLMs, DINO, MACE) produce local embeddings that are aggregated and fed to task-specific prediction heads. In molecules, MACE atomic descriptors are pooled via learned aggregators like REM3DI; in proteins, they can be pooled to residue-level descriptors via GNN pooling.</em></p>

<p>This pattern has been successfully instantiated for both molecules and proteins. REM3DI <sup id="fnref:wedig2025rem3di:1" role="doc-noteref"><a href="#fn:wedig2025rem3di" class="footnote" rel="footnote">86</a></sup> learns to aggregate MACE atomic descriptors into smooth, rotation-invariant molecular representations that achieve state-of-the-art performance on property prediction benchmarks. For proteins, similar approaches pool MACE atomic descriptors to the residue level, enabling prediction of per-residue properties like NMR chemical shifts or pKa values<sup id="fnref:bojan2025representing:1" role="doc-noteref"><a href="#fn:bojan2025representing" class="footnote" rel="footnote">87</a></sup>.</p>

<p><img src="/assets/img/blog/r4g/rem3di_protein_embeddings.png" alt="Molecular and protein embeddings from MACE" /></p>

<p><em>Fig 18. Left: REM3DI aggregates MACE atomic descriptors into molecular descriptors for property prediction. Right: MACE atomic descriptors can be pooled to residue-level representations for protein property prediction, extracting canonical environment descriptors from local atomic neighborhoods.</em></p>
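
<p>Pooling atomic descriptors to the residue level can be sketched as a grouped average over the atoms that belong to each residue. This is a simplified stand-in: a plain mean replaces the learned aggregation used in practice, and the function name and toy inputs are assumptions made for illustration.</p>

```python
from collections import defaultdict


def pool_atoms_to_residues(atom_descriptors, residue_index):
    """Mean-pool per-atom descriptor vectors within each residue.

    atom_descriptors: list of equal-length vectors, one per atom.
    residue_index: residue_index[i] is the residue id of atom i.
    Returns a dict mapping residue id to its pooled descriptor.
    """
    if len(atom_descriptors) != len(residue_index):
        raise ValueError("need exactly one residue id per atom")
    groups = defaultdict(list)
    for vec, res in zip(atom_descriptors, residue_index):
        groups[res].append(vec)
    pooled = {}
    for res, vecs in groups.items():
        dim = len(vecs[0])
        pooled[res] = [sum(v[i] for v in vecs) / len(vecs) for i in range(dim)]
    return pooled


# Two atoms in residue 0 and one atom in residue 1, descriptor dimension 2.
residue_descriptors = pool_atoms_to_residues(
    [[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]],
    [0, 0, 1],
)
```

<p>Each pooled vector can then be fed to a per-residue prediction head, for example for chemical shifts or pKa values.</p>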

<h3 id="where-to-go-from-here">Where to Go from Here?</h3>

<p><img src="/assets/img/blog/r4g/how_to_use_afms.png" alt="Three approaches to using atomistic foundation models" />
<em>Fig 19. Three complementary approaches to leveraging atomistic foundation models: (1) using pretrained embeddings directly for downstream tasks, (2) aligning generative model representations to foundation model features, and (3) drawing architectural inspiration from what makes foundation models generalise well.</em></p>

<p>Using pretrained embeddings remains powerful for property prediction, but naively incorporating them does not seem to help structure prediction: as we explored in the RF3 paper <sup id="fnref:corley2025atomworks" role="doc-noteref"><a href="#fn:corley2025atomworks" class="footnote" rel="footnote">91</a></sup>, simply conditioning on MACE embeddings did not improve accuracy. Representation alignment (as in MACE-REPA) is beginning to be explored for molecules but remains in its infancy compared to the sophisticated alignment methods developed for images.</p>

<p>A third, perhaps underexplored angle is <strong>architectural inspiration</strong>: why does MACE work so well and generalise so broadly? Three factors likely contribute: (1) training on large-scale DFT data, (2) physics-grounded objectives (energies and forces), and (3) strong locality bias—MACE operates on strictly local atomic environments rather than global molecular graphs. SLAE <sup id="fnref:chen2025slae" role="doc-noteref"><a href="#fn:chen2025slae" class="footnote" rel="footnote">92</a></sup> takes exactly this approach for proteins: it adopts the physics-grounded objective by predicting Rosetta energy terms (hydrogen bonding, solvation, electrostatics) and embraces strict locality by encoding all-atom environments rather than full protein graphs. This places SLAE conceptually close to aligned VAEs in vision—reconstruction ensures geometric fidelity while auxiliary physics heads encourage the latent space to align with physically meaningful axes.</p>
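
<p>The combination of a reconstruction term with auxiliary physics heads can be written as a weighted sum of two losses. The sketch below is only a schematic under assumed choices (mean squared error for both terms, a hypothetical <code>physics_weight</code> of 0.1, flat lists as inputs); it is not the objective actually used by SLAE.</p>

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists of numbers."""
    if not pred or len(pred) != len(target):
        raise ValueError("pred and target must be non-empty and equal length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def autoencoder_physics_loss(recon_coords, true_coords,
                             pred_energy_terms, ref_energy_terms,
                             physics_weight=0.1):
    """Reconstruction loss plus a weighted auxiliary physics loss.

    recon_coords / true_coords: flattened decoded and reference coordinates.
    pred_energy_terms / ref_energy_terms: predicted and reference energy terms
        (for example hydrogen bonding, solvation, electrostatics).
    physics_weight: illustrative trade-off between the two terms.
    """
    reconstruction = mse(recon_coords, true_coords)
    physics = mse(pred_energy_terms, ref_energy_terms)
    return reconstruction + physics_weight * physics
```

<p>The reconstruction term keeps the latent geometrically faithful, while the weight on the physics term controls how strongly the latent space is pushed toward the energy-related targets.</p>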

<h2 id="conclusion">Conclusion</h2>

<p>The field has undergone rapid evolution, progressing through four distinct phases: (1) aligning diffusion features with pretrained representations, (2) incorporating semantic structure into VAE latent spaces, (3) directly using pretrained representations as latent spaces, and (4) questioning whether pretrained representations are necessary at all. Parallel developments in pixel-space diffusion and generative representation learning have further enriched the landscape.</p>

<p>Several clear patterns emerge: representation alignment dramatically accelerates training, spatial structure may be more important than global semantics, VAE compression is not fundamental, and principles transfer beyond vision to molecular modelling. However, fundamental questions remain: What makes representations learnable? What is the optimal compression rate? How do we unify multiple modalities? Should we train jointly or in stages?</p>

<p>The answers likely depend on scale, application requirements, and computational constraints. At research scale, leveraging pretrained models provides clear advantages for rapid iteration and exploration. At production scale, systems like FLUX, Veo, and Sora show that multi-stage latent approaches, not necessarily end-to-end training, can reach state-of-the-art quality. The modularity of staged training, where VAEs are pretrained and reused, therefore appears to offer both efficiency and quality benefits at scale. While pixel-space methods continue to advance and may avoid certain reconstruction artifacts, latent-space methods currently dominate production deployments because of their computational efficiency, which is critical when serving millions of users or generating expensive video content.</p>

<p>Looking forward, I am very excited to see how these advances in vision will translate to the molecular world. I myself have worked quite a bit on generative modelling for proteins <sup id="fnref:geffner2025proteina" role="doc-noteref"><a href="#fn:geffner2025proteina" class="footnote" rel="footnote">93</a></sup> and small molecules <sup id="fnref:schneuing2024structure" role="doc-noteref"><a href="#fn:schneuing2024structure" class="footnote" rel="footnote">94</a></sup>, recently also leveraging latent diffusion <sup id="fnref:geffner2025laproteina" role="doc-noteref"><a href="#fn:geffner2025laproteina" class="footnote" rel="footnote">95</a></sup>, so I will follow this space with great interest!</p>

<h2 id="credits">Credits</h2>

<p>Thanks to everyone who gave feedback on this blog post, especially Arash Vahdat, Karsten Kreis and the rest of the GenAIR team, as well as members of the Baker lab for interesting discussions about this topic. Thanks also to Phillip Isola and the rest of the “Foundations of Computer Vision” team for making their great text <a href="https://visionbook.mit.edu/">openly accessible</a>; I took the title image for this blog post, as well as the representation vs generation figure, from it.</p>

<h2 id="references">References</h2>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:ho2020ddpm" role="doc-endnote">
      <p>Ho, J., et al. (2020). Denoising Diffusion Probabilistic Models. <em>NeurIPS</em>. <a href="https://arxiv.org/abs/2006.11239">https://arxiv.org/abs/2006.11239</a> <a href="#fnref:ho2020ddpm" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:ho2020ddpm:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:song2020score" role="doc-endnote">
      <p>Song, Y., et al. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. <em>ICLR</em>. <a href="https://arxiv.org/abs/2011.13456">https://arxiv.org/abs/2011.13456</a> <a href="#fnref:song2020score" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:song2020score:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:lai2025principles" role="doc-endnote">
      <p>Lai, C.-H., Song, Y., Kim, D., Mitsufuji, Y., &amp; Ermon, S. (2025). The principles of diffusion models. <em>arXiv preprint arXiv:2510.21890</em>. <a href="https://arxiv.org/abs/2510.21890">https://arxiv.org/abs/2510.21890</a> <a href="#fnref:lai2025principles" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:lipman2024flow" role="doc-endnote">
      <p>Lipman, Y., et al. (2023). Flow Matching for Generative Modeling. <em>ICLR</em>. <a href="https://arxiv.org/abs/2210.02747">https://arxiv.org/abs/2210.02747</a> <a href="#fnref:lipman2024flow" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:lipman2024flow:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:albergo2023building" role="doc-endnote">
      <p>Albergo, M. S., &amp; Vanden-Eijnden, E. (2023). Building Normalizing Flows with Stochastic Interpolants. <em>ICLR</em>. <a href="https://arxiv.org/abs/2209.15571">https://arxiv.org/abs/2209.15571</a> <a href="#fnref:albergo2023building" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:radford2021learning" role="doc-endnote">
      <p>Radford, A., et al. (2021). Learning Transferable Visual Models From Natural Language Supervision. <em>ICML</em>. <a href="https://arxiv.org/abs/2103.00020">https://arxiv.org/abs/2103.00020</a> <a href="#fnref:radford2021learning" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:radford2021learning:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:radford2021learning:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:caron_emerging_2021" role="doc-endnote">
      <p>Caron, M., et al. (2021). Emerging Properties in Self-Supervised Vision Transformers. <em>ICCV</em>. <a href="https://arxiv.org/abs/2104.14294">https://arxiv.org/abs/2104.14294</a> <a href="#fnref:caron_emerging_2021" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:caron_emerging_2021:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:caron_emerging_2021:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:he2022masked" role="doc-endnote">
      <p>He, K., et al. (2022). Masked Autoencoders Are Scalable Vision Learners. <em>CVPR</em>. <a href="https://arxiv.org/abs/2111.06377">https://arxiv.org/abs/2111.06377</a> <a href="#fnref:he2022masked" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:he2022masked:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:he2022masked:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:kadkhodaie_unconditional_2025" role="doc-endnote">
      <p>Kadkhodaie, Z., Mallat, S., &amp; Simoncelli, E. (2025). Unconditional CNN denoisers contain sparse semantic representation of images. <em>arXiv preprint arXiv:2506.01912</em>. <a href="https://arxiv.org/abs/2506.01912">https://arxiv.org/abs/2506.01912</a> <a href="#fnref:kadkhodaie_unconditional_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:liang_how_2024" role="doc-endnote">
      <p>Liang, Q., Liu, Z., Ostrow, M., &amp; Fiete, I. (2024). How Diffusion Models Learn to Factorize and Compose. <em>arXiv preprint arXiv:2408.13256</em>. <a href="https://arxiv.org/abs/2408.13256">https://arxiv.org/abs/2408.13256</a> <a href="#fnref:liang_how_2024" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:jaini_intriguing_2024" role="doc-endnote">
      <p>Jaini, P., Clark, K., &amp; Geirhos, R. (2024). Intriguing properties of generative classifiers. <em>arXiv preprint arXiv:2309.16779</em>. <a href="https://arxiv.org/abs/2309.16779">https://arxiv.org/abs/2309.16779</a> <a href="#fnref:jaini_intriguing_2024" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:sohldickstein2015deep" role="doc-endnote">
      <p>Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., &amp; Ganguli, S. (2015). Deep unsupervised learning using nonequilibrium thermodynamics. <em>International Conference on Machine Learning</em>, 2256–2265. <a href="https://arxiv.org/abs/1503.03585">https://arxiv.org/abs/1503.03585</a> <a href="#fnref:sohldickstein2015deep" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:song2020improved" role="doc-endnote">
      <p>Song, Y., &amp; Ermon, S. (2020). Improved techniques for training score-based generative models. <em>Advances in neural information processing systems</em>, 33, 12438–12448. <a href="#fnref:song2020improved" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:vahdat2020nvae" role="doc-endnote">
      <p>Vahdat, A., &amp; Kautz, J. (2020). NVAE: A Deep Hierarchical Variational Autoencoder. <em>NeurIPS</em>. <a href="https://arxiv.org/abs/2007.03898">https://arxiv.org/abs/2007.03898</a> <a href="#fnref:vahdat2020nvae" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:vahdat_score-based_2021" role="doc-endnote">
      <p>Vahdat, A., et al. (2021). Score-based generative modeling in latent space. <em>NeurIPS</em>. <a href="https://arxiv.org/abs/2106.05931">https://arxiv.org/abs/2106.05931</a> <a href="#fnref:vahdat_score-based_2021" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:vahdat_score-based_2021:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:rombach_high-resolution_2022" role="doc-endnote">
      <p>Rombach, R., et al. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. <em>CVPR</em>. <a href="https://arxiv.org/abs/2112.10752">https://arxiv.org/abs/2112.10752</a> <a href="#fnref:rombach_high-resolution_2022" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:rombach_high-resolution_2022:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:dieleman2025latents" role="doc-endnote">
      <p>Dieleman, S. (2025). Generative modelling in latent space. <a href="https://sander.ai/2025/04/15/latents.html">https://sander.ai/2025/04/15/latents.html</a> <a href="#fnref:dieleman2025latents" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:dieleman2025latents:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:blattmann2023align" role="doc-endnote">
      <p>Blattmann, A., et al. (2023). Align your Latents: High-Resolution Video Synthesis with Latent Diffusion Models. <em>CVPR</em>. <a href="https://arxiv.org/abs/2304.08818">https://arxiv.org/abs/2304.08818</a> <a href="#fnref:blattmann2023align" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:brooks2024sora" role="doc-endnote">
      <p>Brooks, T., et al. (2024). Video generation models as world simulators. <em>OpenAI Blog</em>. <a href="https://openai.com/index/video-generation-models-as-world-simulators/">https://openai.com/index/video-generation-models-as-world-simulators/</a> <a href="#fnref:brooks2024sora" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:brooks2024sora:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:brooks2024sora:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:podell2023sdxl" role="doc-endnote">
      <p>Podell, D., et al. (2023). SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis. <em>ICLR</em>. <a href="https://arxiv.org/abs/2307.01952">https://arxiv.org/abs/2307.01952</a> <a href="#fnref:podell2023sdxl" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:peebles2023scalable" role="doc-endnote">
      <p>Peebles, W., &amp; Xie, S. (2023). Scalable Diffusion Models with Transformers. <em>ICCV</em>. <a href="https://arxiv.org/abs/2212.09748">https://arxiv.org/abs/2212.09748</a> <a href="#fnref:peebles2023scalable" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:gao2024diffusion" role="doc-endnote">
      <p>Gao, R., Hoogeboom, E., Heek, J., De Bortoli, V., Murphy, K. P., &amp; Salimans, T. (2024). Diffusion meets flow matching: Two sides of the same coin. <a href="https://diffusionflow.github.io/">https://diffusionflow.github.io/</a> <a href="#fnref:gao2024diffusion" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:esser2024sd3" role="doc-endnote">
      <p>Esser, P., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. <em>ICML</em>. <a href="https://arxiv.org/abs/2403.03206">https://arxiv.org/abs/2403.03206</a> <a href="#fnref:esser2024sd3" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:esser2024sd3:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:oord_representation_2019" role="doc-endnote">
      <p>Oord, A. v. d., Li, Y., &amp; Vinyals, O. (2019). Representation Learning with Contrastive Predictive Coding. <em>arXiv preprint arXiv:1807.03748</em>. <a href="https://arxiv.org/abs/1807.03748">https://arxiv.org/abs/1807.03748</a> <a href="#fnref:oord_representation_2019" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:oord_representation_2019:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:oord_representation_2019:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:chen_exploring_2020" role="doc-endnote">
      <p>Chen, X., &amp; He, K. (2020). Exploring Simple Siamese Representation Learning. <em>CVPR</em>. <a href="https://arxiv.org/abs/2011.10566">https://arxiv.org/abs/2011.10566</a> <a href="#fnref:chen_exploring_2020" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:chen_exploring_2020:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:caron_deep_2019" role="doc-endnote">
      <p>Caron, M., et al. (2019). Deep Clustering for Unsupervised Learning of Visual Features. <em>ECCV</em>. <a href="https://arxiv.org/abs/1807.05520">https://arxiv.org/abs/1807.05520</a> <a href="#fnref:caron_deep_2019" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:caron_deep_2019:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:caron_unsupervised_2021" role="doc-endnote">
      <p>Caron, M., et al. (2021). Unsupervised Learning of Visual Features by Contrasting Cluster Assignments. <em>NeurIPS</em>. <a href="https://arxiv.org/abs/2006.09882">https://arxiv.org/abs/2006.09882</a> <a href="#fnref:caron_unsupervised_2021" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:caron_unsupervised_2021:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:assran_self-supervised_2023" role="doc-endnote">
      <p>Assran, M., et al. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. <em>CVPR</em>. <a href="https://arxiv.org/abs/2301.08243">https://arxiv.org/abs/2301.08243</a> <a href="#fnref:assran_self-supervised_2023" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:assran_self-supervised_2023:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:assran_self-supervised_2023:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:huh_platonic_2024" role="doc-endnote">
      <p>Huh, M., Cheung, B., Wang, T., &amp; Isola, P. (2024). The Platonic Representation Hypothesis. <em>arXiv preprint arXiv:2405.07987</em>. <a href="https://arxiv.org/abs/2405.07987">https://arxiv.org/abs/2405.07987</a> <a href="#fnref:huh_platonic_2024" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:huh_platonic_2024:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:huh_platonic_2024:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a> <a href="#fnref:huh_platonic_2024:3" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>4</sup></a> <a href="#fnref:huh_platonic_2024:4" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>5</sup></a></p>
    </li>
    <li id="fn:chen2020simclr" role="doc-endnote">
      <p>Chen, T., et al. (2020). A Simple Framework for Contrastive Learning of Visual Representations. <em>ICML</em>. <a href="https://arxiv.org/abs/2002.05709">https://arxiv.org/abs/2002.05709</a> <a href="#fnref:chen2020simclr" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:chen2020simclr:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:oquab_dinov2_2024" role="doc-endnote">
      <p>Oquab, M., et al. (2024). DINOv2: Learning Robust Visual Features without Supervision. <em>arXiv preprint arXiv:2304.07193</em>. <a href="https://arxiv.org/abs/2304.07193">https://arxiv.org/abs/2304.07193</a> <a href="#fnref:oquab_dinov2_2024" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:oquab_dinov2_2024:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:oquab_dinov2_2024:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:simeoni_dinov3_2025" role="doc-endnote">
      <p>Simeoni, O., et al. (2025). DINOv3. <em>arXiv preprint arXiv:2508.10104</em>. <a href="https://arxiv.org/abs/2508.10104">https://arxiv.org/abs/2508.10104</a> <a href="#fnref:simeoni_dinov3_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:simeoni_dinov3_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:assran_v-jepa_2025" role="doc-endnote">
      <p>Assran, M., et al. (2025). V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning. <em>arXiv preprint arXiv:2506.09985</em>. <a href="https://arxiv.org/abs/2506.09985">https://arxiv.org/abs/2506.09985</a> <a href="#fnref:assran_v-jepa_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:zhai_sigmoid_2023" role="doc-endnote">
      <p>Zhai, X., et al. (2023). Sigmoid Loss for Language Image Pre-Training. <em>ICCV</em>. <a href="https://arxiv.org/abs/2303.15343">https://arxiv.org/abs/2303.15343</a> <a href="#fnref:zhai_sigmoid_2023" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:zhai_sigmoid_2023:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:tschannen_siglip_2025" role="doc-endnote">
      <p>Tschannen, M., et al. (2025). SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features. <em>arXiv preprint arXiv:2502.14786</em>. <a href="https://arxiv.org/abs/2502.14786">https://arxiv.org/abs/2502.14786</a> <a href="#fnref:tschannen_siglip_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:grill_bootstrap_2020" role="doc-endnote">
      <p>Grill, J.-B., et al. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. <em>NeurIPS</em>. <a href="https://arxiv.org/abs/2006.07733">https://arxiv.org/abs/2006.07733</a> <a href="#fnref:grill_bootstrap_2020" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:richemond_byol_2020" role="doc-endnote">
      <p>Richemond, P. H., et al. (2020). BYOL works even without batch statistics. <em>arXiv preprint arXiv:2010.10241</em>. <a href="https://arxiv.org/abs/2010.10241">https://arxiv.org/abs/2010.10241</a> <a href="#fnref:richemond_byol_2020" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:skorokhodov_improving_2025" role="doc-endnote">
      <p>Skorokhodov, I., et al. (2025). Improving the Diffusability of Autoencoders. <em>arXiv preprint arXiv:2502.14831</em>. <a href="https://arxiv.org/abs/2502.14831">https://arxiv.org/abs/2502.14831</a> <a href="#fnref:skorokhodov_improving_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:skorokhodov_improving_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:skorokhodov_improving_2025:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:chen_masked_2025" role="doc-endnote">
      <p>Chen, H., et al. (2025). Masked Autoencoders Are Effective Tokenizers for Diffusion Models. <em>arXiv preprint arXiv:2502.03444</em>. <a href="https://arxiv.org/abs/2502.03444">https://arxiv.org/abs/2502.03444</a> <a href="#fnref:chen_masked_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:chen_masked_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:chen_masked_2025:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:zhou2021ibot" role="doc-endnote">
      <p>Zhou, J., et al. (2021). iBOT: Image BERT Pre-Training with Online Tokenizer. <em>ICLR</em>. <a href="https://arxiv.org/abs/2111.07832">https://arxiv.org/abs/2111.07832</a> <a href="#fnref:zhou2021ibot" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:gupta_better_2025" role="doc-endnote">
      <p>Gupta, S., Sundaram, S., Wang, C., Jegelka, S., &amp; Isola, P. (2025). Better Together: Leveraging Unpaired Multimodal Data for Stronger Unimodal Models. <em>arXiv preprint arXiv:2510.08492</em>. <a href="https://arxiv.org/abs/2510.08492">https://arxiv.org/abs/2510.08492</a> <a href="#fnref:gupta_better_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:wang_words_2025" role="doc-endnote">
      <p>Wang, S. L., Isola, P., &amp; Cheung, B. (2025). Words That Make Language Models Perceive. <em>arXiv preprint arXiv:2510.02425</em>. <a href="https://arxiv.org/abs/2510.02425">https://arxiv.org/abs/2510.02425</a> <a href="#fnref:wang_words_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:wang_words_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:bi_vision_2025" role="doc-endnote">
      <p>Bi, T., Zhang, X., Lu, Y., &amp; Zheng, N. (2025). Vision Foundation Models Can Be Good Tokenizers for Latent Diffusion Models. <em>arXiv preprint arXiv:2510.18457</em>. <a href="https://arxiv.org/abs/2510.18457">https://arxiv.org/abs/2510.18457</a> <a href="#fnref:bi_vision_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:bi_vision_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:yu_representation_2025" role="doc-endnote">
      <p>Yu, S., et al. (2025). Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think. <em>arXiv preprint arXiv:2410.06940</em>. <a href="https://arxiv.org/abs/2410.06940">https://arxiv.org/abs/2410.06940</a> <a href="#fnref:yu_representation_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:yu_representation_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a> <a href="#fnref:yu_representation_2025:2" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>3</sup></a></p>
    </li>
    <li id="fn:wu_representation_2025" role="doc-endnote">
      <p>Wu, G., et al. (2025). Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think. <em>arXiv preprint arXiv:2507.01467</em>. <a href="https://arxiv.org/abs/2507.01467">https://arxiv.org/abs/2507.01467</a> <a href="#fnref:wu_representation_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:wu_representation_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:leng_repa-e_2025" role="doc-endnote">
      <p>Leng, X., Singh, J., Hou, Y., Xing, Z., Xie, S., &amp; Zheng, L. (2025). REPA-E: Unlocking VAE for End-to-End Tuning with Latent Diffusion Transformers. <em>arXiv preprint arXiv:2504.10483</em>. <a href="https://arxiv.org/abs/2504.10483">https://arxiv.org/abs/2504.10483</a> <a href="#fnref:leng_repa-e_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:leng_repa-e_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:yao_reconstruction_2025" role="doc-endnote">
      <p>Yao, J., Yang, B., &amp; Wang, X. (2025). Reconstruction vs. Generation: Taming Optimization Dilemma in Latent Diffusion Models. <em>arXiv preprint arXiv:2501.01423</em>. <a href="https://arxiv.org/abs/2501.01423">https://arxiv.org/abs/2501.01423</a> <a href="#fnref:yao_reconstruction_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:yao_reconstruction_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:wang_repa_2025" role="doc-endnote">
      <p>Wang, Z., et al. (2025). REPA Works Until It Doesn’t: Early-Stopped, Holistic Alignment Supercharges Diffusion Training. <em>arXiv preprint arXiv:2505.16792</em>. <a href="https://arxiv.org/abs/2505.16792">https://arxiv.org/abs/2505.16792</a> <a href="#fnref:wang_repa_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:wang_repa_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:wang_diffuse_2025" role="doc-endnote">
      <p>Wang, R., &amp; He, K. (2025). Diffuse and Disperse: Image Generation with Representation Regularization. <em>arXiv preprint arXiv:2506.09027</em>. <a href="https://arxiv.org/abs/2506.09027">https://arxiv.org/abs/2506.09027</a> <a href="#fnref:wang_diffuse_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:wang_diffuse_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:xiang_denoising_2023" role="doc-endnote">
      <p>Xiang, W., et al. (2023). Denoising Diffusion Autoencoders are Unified Self-Supervised Learners. <em>ICCV</em>. <a href="https://arxiv.org/abs/2303.09769">https://arxiv.org/abs/2303.09769</a> <a href="#fnref:xiang_denoising_2023" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:chen2024deconstructing" role="doc-endnote">
      <p>Chen, X., Liu, Z., Xie, S., &amp; He, K. (2024). Deconstructing denoising diffusion models for self-supervised learning. <em>arXiv preprint arXiv:2401.14404</em>. <a href="https://arxiv.org/abs/2401.14404">https://arxiv.org/abs/2401.14404</a> <a href="#fnref:chen2024deconstructing" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:singh_what_2025" role="doc-endnote">
      <p>Singh, J., Leng, X., Wu, Z., Zheng, L., Zhang, R., Shechtman, E., &amp; Xie, S. (2025). What matters for Representation Alignment: Global Information or Spatial Structure? <em>arXiv preprint arXiv:2512.10794</em>. <a href="https://arxiv.org/abs/2512.10794">https://arxiv.org/abs/2512.10794</a> <a href="#fnref:singh_what_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:singh_what_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:shechtman2007matching" role="doc-endnote">
      <p>Shechtman, E., &amp; Irani, M. (2007). Matching local self-similarities across images and videos. <em>2007 IEEE Conference on Computer Vision and Pattern Recognition</em> (pp. 1–8). IEEE. <a href="https://ieeexplore.ieee.org/document/4270170">https://ieeexplore.ieee.org/document/4270170</a> <a href="#fnref:shechtman2007matching" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:dieleman2023geometry" role="doc-endnote">
      <p>Dieleman, S. (2023). The geometry of diffusion guidance. <a href="https://sander.ai/2023/08/28/geometry.html">https://sander.ai/2023/08/28/geometry.html</a> <a href="#fnref:dieleman2023geometry" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:goodfellow2014generative" role="doc-endnote">
      <p>Goodfellow, I., et al. (2014). Generative Adversarial Networks. <em>NeurIPS</em>. <a href="https://arxiv.org/abs/1406.2661">https://arxiv.org/abs/1406.2661</a> <a href="#fnref:goodfellow2014generative" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:chen_aligning_2025" role="doc-endnote">
      <p>Chen, B., et al. (2025). Aligning Visual Foundation Encoders to Tokenizers for Diffusion Models. <em>arXiv preprint arXiv:2509.25162</em>. <a href="https://arxiv.org/abs/2509.25162">https://arxiv.org/abs/2509.25162</a> <a href="#fnref:chen_aligning_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:bolya_perception_2025" role="doc-endnote">
      <p>Bolya, D., et al. (2025). Perception Encoder: The best visual embeddings are not at the output of the network. <em>arXiv preprint arXiv:2504.13181</em>. <a href="https://arxiv.org/abs/2504.13181">https://arxiv.org/abs/2504.13181</a> <a href="#fnref:bolya_perception_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:shi_latent_2025" role="doc-endnote">
      <p>Shi, M., et al. (2025). Latent Diffusion Model without Variational Autoencoder. <em>arXiv preprint arXiv:2510.15301</em>. <a href="https://arxiv.org/abs/2510.15301">https://arxiv.org/abs/2510.15301</a> <a href="#fnref:shi_latent_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:zheng_diffusion_2025" role="doc-endnote">
      <p>Zheng, B., Ma, N., Tong, S., &amp; Xie, S. (2025). Diffusion Transformers with Representation Autoencoders. <em>arXiv preprint arXiv:2510.11690</em>. <a href="https://arxiv.org/abs/2510.11690">https://arxiv.org/abs/2510.11690</a> <a href="#fnref:zheng_diffusion_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:zhai2024normalizing" role="doc-endnote">
      <p>Zhai, S., Zhang, R., Nakkiran, P., Berthelot, D., Gu, J., Zheng, H., Chen, T., Bautista, M. A., Jaitly, N., &amp; Susskind, J. (2024). Normalizing flows are capable generative models. <em>arXiv preprint arXiv:2412.06329</em>. <a href="https://arxiv.org/abs/2412.06329">https://arxiv.org/abs/2412.06329</a> <a href="#fnref:zhai2024normalizing" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:gao_one_2025" role="doc-endnote">
      <p>Gao, Y., Chen, C., Chen, T., &amp; Gu, J. (2025). One Layer Is Enough: Adapting Pretrained Visual Encoders for Image Generation. <em>arXiv preprint arXiv:2512.07829</em>. <a href="https://arxiv.org/abs/2512.07829">https://arxiv.org/abs/2512.07829</a> <a href="#fnref:gao_one_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:tong2026scalerae" role="doc-endnote">
      <p>Tong, S., Zheng, B., Wang, Z., Tang, B., Ma, N., Brown, E., Yang, J., Fergus, R., LeCun, Y., &amp; Xie, S. (2026). Scaling Text-to-Image Diffusion Transformers with Representation Autoencoders. <em>arXiv preprint arXiv:2601.16208</em>. <a href="https://arxiv.org/abs/2601.16208">https://arxiv.org/abs/2601.16208</a> <a href="#fnref:tong2026scalerae" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:kumar2026rjf" role="doc-endnote">
      <p>Kumar, A., &amp; Patel, V. M. (2026). Learning on the Manifold: Unlocking Standard Diffusion Transformers with Representation Encoders. <em>arXiv preprint arXiv:2602.10099</em>. <a href="https://arxiv.org/abs/2602.10099">https://arxiv.org/abs/2602.10099</a> <a href="#fnref:kumar2026rjf" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:mathieu2020riemannian" role="doc-endnote">
      <p>Mathieu, E., &amp; Nickel, M. (2020). Riemannian Continuous Normalizing Flows. <em>Advances in Neural Information Processing Systems</em>, 33, 2503-2515. <a href="https://arxiv.org/abs/2006.10605">https://arxiv.org/abs/2006.10605</a> <a href="#fnref:mathieu2020riemannian" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:zaghen2025variational" role="doc-endnote">
      <p>Zaghen, O., Eijkelboom, F., Pouplin, A., &amp; Bekkers, E. J. (2025). Towards Variational Flow Matching on General Geometries. <em>arXiv preprint arXiv:2502.12981</em>. <a href="https://arxiv.org/abs/2502.12981">https://arxiv.org/abs/2502.12981</a> <a href="#fnref:zaghen2025variational" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:bamberger2025cdcfm" role="doc-endnote">
      <p>Bamberger, J., Jones, I., Duncan, D., Bronstein, M. M., Vandergheynst, P., &amp; Gosztolai, A. (2025). Carré du champ flow matching: better quality-generalisation tradeoff in generative models. <em>arXiv preprint arXiv:2510.05930</em>. <a href="https://arxiv.org/abs/2510.05930">https://arxiv.org/abs/2510.05930</a> <a href="#fnref:bamberger2025cdcfm" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:zhang2025psvae" role="doc-endnote">
      <p>Zhang, S. (2025). Both Semantics and Reconstruction Matter: Making Representation Encoders Ready for Text-to-Image Generation and Editing. <em>arXiv preprint arXiv:2509.25162</em>. <a href="https://arxiv.org/abs/2509.25162">https://arxiv.org/abs/2509.25162</a> <a href="#fnref:zhang2025psvae" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:fan2025harmonizing" role="doc-endnote">
      <p>Fan, W., Diao, H., Wang, Q., Lin, D., &amp; Liu, Z. (2025). The Prism Hypothesis: Harmonizing Semantic and Pixel Representations via Unified Autoencoding. <em>arXiv preprint arXiv:2512.19693</em>. <a href="https://arxiv.org/abs/2512.19693">https://arxiv.org/abs/2512.19693</a> <a href="#fnref:fan2025harmonizing" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:falck2025spectral" role="doc-endnote">
      <p>Falck, F., et al. (2025). Spectral Autoencoders. <em>arXiv preprint arXiv:2505.11278</em>. <a href="https://arxiv.org/abs/2505.11278">https://arxiv.org/abs/2505.11278</a> <a href="#fnref:falck2025spectral" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:chu_usp_2025" role="doc-endnote">
      <p>Chu, X., Li, R., &amp; Wang, Y. (2025). USP: Unified Self-Supervised Pretraining for Image Generation and Understanding. <em>arXiv preprint arXiv:2503.06132</em>. <a href="https://arxiv.org/abs/2503.06132">https://arxiv.org/abs/2503.06132</a> <a href="#fnref:chu_usp_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:noauthor_black_nodate" role="doc-endnote">
      <p>Black Forest Labs. (n.d.). FLUX. <a href="https://bfl.ai/research/representation-comparison">https://bfl.ai/research/representation-comparison</a> <a href="#fnref:noauthor_black_nodate" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:li_back_2025" role="doc-endnote">
      <p>Li, T., &amp; He, K. (2025). Back to Basics: Let Denoising Generative Models Denoise. <em>arXiv preprint arXiv:2511.13720</em>. <a href="https://arxiv.org/abs/2511.13720">https://arxiv.org/abs/2511.13720</a> <a href="#fnref:li_back_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:yu_pixeldit_2025" role="doc-endnote">
      <p>Yu, Y., Xiong, W., Nie, W., Sheng, Y., Liu, S., &amp; Luo, J. (2025). PixelDiT: Pixel Diffusion Transformers for Image Generation. <em>arXiv preprint arXiv:2511.20645</em>. <a href="https://arxiv.org/abs/2511.20645">https://arxiv.org/abs/2511.20645</a> <a href="#fnref:yu_pixeldit_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:ma_deco_2025" role="doc-endnote">
      <p>Ma, Z., Wei, L., Wang, S., Zhang, S., &amp; Tian, Q. (2025). DeCo: Frequency-Decoupled Pixel Diffusion for End-to-End Image Generation. <em>arXiv preprint arXiv:2511.19365</em>. <a href="https://arxiv.org/abs/2511.19365">https://arxiv.org/abs/2511.19365</a> <a href="#fnref:ma_deco_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:chen_dip_2025" role="doc-endnote">
      <p>Chen, Z., et al. (2025). DiP: Taming Diffusion Models in Pixel Space. <em>arXiv preprint arXiv:2511.18822</em>. <a href="https://arxiv.org/abs/2511.18822">https://arxiv.org/abs/2511.18822</a> <a href="#fnref:chen_dip_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:hoogeboomsimpler2025" role="doc-endnote">
      <p>Hoogeboom, E., Mensink, T., Heek, J., Lamerigts, K., Gao, R., &amp; Salimans, T. (2025). Simpler Diffusion (SiD2): 1.5 FID on ImageNet512 with pixel-space diffusion. <em>arXiv preprint arXiv:2410.19324</em>. <a href="https://arxiv.org/abs/2410.19324">https://arxiv.org/abs/2410.19324</a> <a href="#fnref:hoogeboomsimpler2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:lei_advancing_2025" role="doc-endnote">
      <p>Lei, J., Liu, K., Berner, J., Yu, H., Zheng, H., Wu, J., &amp; Chu, X. (2025). Advancing End-to-End Pixel Space Generative Modeling via Self-supervised Pre-training. <em>arXiv preprint arXiv:2510.12586</em>. <a href="https://arxiv.org/abs/2510.12586">https://arxiv.org/abs/2510.12586</a> <a href="#fnref:lei_advancing_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:lei_advancing_2025:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:el2024scalable" role="doc-endnote">
      <p>El-Nouby, A., Klein, M., Zhai, S., Bautista, M. A., Toshev, A., Shankar, V., Susskind, J. M., &amp; Joulin, A. (2024). Scalable pre-training of large autoregressive image models. <em>arXiv preprint arXiv:2401.08541</em>. <a href="https://arxiv.org/abs/2401.08541">https://arxiv.org/abs/2401.08541</a> <a href="#fnref:el2024scalable" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:chen2020generative" role="doc-endnote">
      <p>Chen, M., et al. (2020). Generative Pretraining from Pixels. <em>ICML</em>. <a href="https://arxiv.org/abs/2009.14794">https://arxiv.org/abs/2009.14794</a> <a href="#fnref:chen2020generative" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:ren2023rejuvenating" role="doc-endnote">
      <p>Ren, S., Wang, Z., Zhu, H., Xiao, J., Yuille, A., &amp; Xie, C. (2023). Rejuvenating image-gpt as strong visual representation learners. <em>arXiv preprint arXiv:2312.02147</em>. <a href="https://arxiv.org/abs/2312.02147">https://arxiv.org/abs/2312.02147</a> <a href="#fnref:ren2023rejuvenating" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:fini2025multimodal" role="doc-endnote">
      <p>Fini, E., Shukor, M., Li, X., Dufter, P., Klein, M., Haldimann, D., Aitharaju, S., da Costa, V. G. T., Béthune, L., Gan, Z., et al. (2025). Multimodal autoregressive pre-training of large vision encoders. <em>Proceedings of the Computer Vision and Pattern Recognition Conference</em>, 9641-9654. <a href="https://arxiv.org/abs/2502.14786">https://arxiv.org/abs/2502.14786</a> <a href="#fnref:fini2025multimodal" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:xu_next-embedding_2025" role="doc-endnote">
      <p>Xu, S., Ma, Z., Chai, W., Chen, X., Jin, W., Chai, J., Xie, S., &amp; Yu, S. X. (2025). Next-Embedding Prediction Makes Strong Vision Learners. <em>arXiv preprint arXiv:2512.16922</em>. <a href="https://arxiv.org/abs/2512.16922">https://arxiv.org/abs/2512.16922</a> <a href="#fnref:xu_next-embedding_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:lei_robust_2025" role="doc-endnote">
      <p>Lei, J., Berner, J., Wang, J., Chen, Z., Ba, Z., Ren, K., Zhu, J., &amp; Anandkumar, A. (2025). Robust Representation Consistency Model via Contrastive Denoising. <em>arXiv preprint arXiv:2501.13094</em>. <a href="https://arxiv.org/abs/2501.13094">https://arxiv.org/abs/2501.13094</a> <a href="#fnref:lei_robust_2025" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:batatia2025foundation" role="doc-endnote">
      <p>Batatia, I., et al. (2025). A foundation model for atomistic materials chemistry. <em>The Journal of Chemical Physics</em>, 163(18). <a href="https://pubs.aip.org/aip/jcp/article/163/18/184110/3372267/A-foundation-model-for-atomistic-materials">https://pubs.aip.org/aip/jcp/article/163/18/184110/3372267/A-foundation-model-for-atomistic-materials</a> <a href="#fnref:batatia2025foundation" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:bernstein2024gap" role="doc-endnote">
      <p>Bernstein, N. (2024). From GAP to ACE to MACE. <em>arXiv preprint arXiv:2410.06354</em>. <a href="https://arxiv.org/abs/2410.06354">https://arxiv.org/abs/2410.06354</a> <a href="#fnref:bernstein2024gap" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:wedig2025rem3di" role="doc-endnote">
      <p>Wedig, S., Elijošius, R., Schran, C., &amp; Schaaf, L. L. (2025). REM3DI: Learning smooth, chiral 3D molecular representations from equivariant atomistic foundation models. <em>NeurIPS 2025 Workshop on Symmetry and Geometry in Neural Representations</em>. <a href="https://openreview.net/forum?id=jOmZsvXoK5">https://openreview.net/forum?id=jOmZsvXoK5</a> <a href="#fnref:wedig2025rem3di" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:wedig2025rem3di:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:bojan2025representing" role="doc-endnote">
      <p>Bojan, M., Vedula, S., Maddipatla, A., Sellam, N. B., Napoli, F., Schanda, P., &amp; Bronstein, A. M. (2025). Representing local protein environments with atomistic foundation models. <em>arXiv preprint arXiv:2505.23354</em>. <a href="https://arxiv.org/abs/2505.23354">https://arxiv.org/abs/2505.23354</a> <a href="#fnref:bojan2025representing" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a> <a href="#fnref:bojan2025representing:1" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;<sup>2</sup></a></p>
    </li>
    <li id="fn:edamadaka2025universally" role="doc-endnote">
      <p>Edamadaka, S., Yang, S., Li, J., &amp; Gómez-Bombarelli, R. (2025). Universally Converging Representations of Matter Across Scientific Foundation Models. <em>arXiv preprint arXiv:2512.03750</em>. <a href="https://arxiv.org/abs/2512.03750">https://arxiv.org/abs/2512.03750</a> <a href="#fnref:edamadaka2025universally" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:li2025platonic" role="doc-endnote">
      <p>Li, Z., &amp; Walsh, A. (2025). Platonic representation of foundation machine learning interatomic potentials. <em>arXiv preprint arXiv:2512.05349</em>. <a href="https://arxiv.org/abs/2512.05349">https://arxiv.org/abs/2512.05349</a> <a href="#fnref:li2025platonic" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:pinede2025unifying" role="doc-endnote">
      <p>Pinede, L., Yang, S., Nam, J., &amp; Gomez-Bombarelli, R. (2025). Unifying Force Prediction and Molecular Conformation Generation Through Representation Alignment. <em>ICML 2025 Generative AI and Biology (GenBio) Workshop</em>. <a href="https://openreview.net/pdf?id=yzkHGHvC74">https://openreview.net/pdf?id=yzkHGHvC74</a> <a href="#fnref:pinede2025unifying" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:corley2025atomworks" role="doc-endnote">
      <p>Corley, N., Mathis, S., Krishna, R., Bauer, M. S., Thompson, T. R., Ahern, W., …, Didi, K., …, Baker, D., &amp; DiMaio, F. (2025). Accelerating Biomolecular Modeling with AtomWorks and RF3. <em>bioRxiv</em>. <a href="https://www.biorxiv.org/content/10.1101/2025.08.14.670328v2">https://www.biorxiv.org/content/10.1101/2025.08.14.670328v2</a> <a href="#fnref:corley2025atomworks" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:chen2025slae" role="doc-endnote">
      <p>Chen, Y., Lu, T., Zhao, C., Wayment-Steele, H., &amp; Huang, P. (2025). SLAE: Strictly Local All-atom Environment for Protein Representation. <em>bioRxiv</em>. <a href="https://www.biorxiv.org/content/10.1101/2025.10.03.680398v1">https://www.biorxiv.org/content/10.1101/2025.10.03.680398v1</a> <a href="#fnref:chen2025slae" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:geffner2025proteina" role="doc-endnote">
      <p>Geffner, T., Didi, K., Zhang, Z., Reidenbach, D., Cao, Z., Yim, J., Geiger, M., Dallago, C., Kucukbenli, E., Vahdat, A., et al. (2025). Proteina: Scaling flow-based protein structure generative models. <em>arXiv preprint arXiv:2503.00710</em>. <a href="https://arxiv.org/abs/2503.00710">https://arxiv.org/abs/2503.00710</a> <a href="#fnref:geffner2025proteina" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:schneuing2024structure" role="doc-endnote">
      <p>Schneuing, A., Harris, C., Du, Y., Didi, K., Jamasb, A., Igashov, I., Du, W., Gomes, C., Blundell, T. L., Lio, P., et al. (2024). Structure-based drug design with equivariant diffusion models. <em>Nature Computational Science</em>, 4(12), 899–909. <a href="https://www.nature.com/articles/s43588-024-00737-x">https://www.nature.com/articles/s43588-024-00737-x</a> <a href="#fnref:schneuing2024structure" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
    <li id="fn:geffner2025laproteina" role="doc-endnote">
      <p>Geffner, T., Didi, K., Cao, Z., Reidenbach, D., Zhang, Z., Dallago, C., Kucukbenli, E., Kreis, K., &amp; Vahdat, A. (2025). La-proteina: Atomistic protein generation via partially latent flow matching. <em>arXiv preprint arXiv:2507.09466</em>. <a href="https://arxiv.org/abs/2507.09466">https://arxiv.org/abs/2507.09466</a> <a href="#fnref:geffner2025laproteina" class="reversefootnote" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="ml" /><summary type="html"><![CDATA[A deep dive into the convergence of discriminative and generative AI, covering 4 phases of evolution from REPA to RAE and beyond.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/r4g/titler4g.png" /><media:content medium="image" url="/assets/img/blog/r4g/titler4g.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Dealing with the flood of protein structures</title><link href="/blog/proteins/2024-04-07-protein-structure-universe/" rel="alternate" type="text/html" title="Dealing with the flood of protein structures" /><published>2024-04-07T00:00:00+00:00</published><updated>2026-02-13T05:13:14+00:00</updated><id>/blog/proteins/protein-structure-universe</id><content type="html" xml:base="/blog/proteins/2024-04-07-protein-structure-universe/"><![CDATA[<p>With the explosion of protein structure prediction and the sheer number of predicted protein structures available in databases nowadays, we can ask exciting new questions that would have been unanswerable only a few years ago. However, we need new tools in order to answer these questions and deal with the flood of structural data. In this post, I describe a few of these new tools and the reasoning behind them.</p>

<ol id="markdown-toc">
  <li><a href="#how-protein-structure-prediction-changed-the-game" id="markdown-toc-how-protein-structure-prediction-changed-the-game">How protein structure prediction changed the game</a></li>
  <li><a href="#foldcomp-compressing-protein-structures-to-managable-sizes" id="markdown-toc-foldcomp-compressing-protein-structures-to-managable-sizes">FoldComp: compressing protein structures to manageable sizes</a>    <ol>
      <li><a href="#the-trouble-with-compression" id="markdown-toc-the-trouble-with-compression">The trouble with compression</a></li>
      <li><a href="#the-foldcomp-compression-scheme" id="markdown-toc-the-foldcomp-compression-scheme">The FoldComp compression scheme</a></li>
      <li><a href="#nerf-and-the-lever-arm-effect" id="markdown-toc-nerf-and-the-lever-arm-effect">NeRF and the lever-arm effect</a></li>
      <li><a href="#the-lever-arm-solution-bidirectional-nerf-and-anchoring" id="markdown-toc-the-lever-arm-solution-bidirectional-nerf-and-anchoring">The lever-arm solution: bidirectional NeRF and anchoring</a></li>
    </ol>
  </li>
  <li><a href="#mmseqs2-sequence-alignment-in-speed-mode" id="markdown-toc-mmseqs2-sequence-alignment-in-speed-mode">MMseqs2: sequence alignment in speed-mode</a>    <ol>
      <li><a href="#why-do-we-need-fast-sequence-alignment" id="markdown-toc-why-do-we-need-fast-sequence-alignment">Why do we need fast sequence alignment?</a></li>
      <li><a href="#prefiltering-is-key" id="markdown-toc-prefiltering-is-key">Prefiltering is key</a></li>
      <li><a href="#use-the-prefilter-for-clustering" id="markdown-toc-use-the-prefilter-for-clustering">Use the prefilter for clustering</a></li>
    </ol>
  </li>
  <li><a href="#foldseek-structural-clustering-of-the-protein-universe" id="markdown-toc-foldseek-structural-clustering-of-the-protein-universe">FoldSeek: structural clustering of the protein universe</a>    <ol>
      <li><a href="#structure-to-sequence-the-3di-alphabet" id="markdown-toc-structure-to-sequence-the-3di-alphabet">Structure to Sequence: the 3Di alphabet</a></li>
      <li><a href="#virtual-centers-optimise-conservation-of-interactions-and-tertiary-vs-local-interactions" id="markdown-toc-virtual-centers-optimise-conservation-of-interactions-and-tertiary-vs-local-interactions">Virtual centers optimise conservation of interactions and tertiary vs. local interactions</a></li>
      <li><a href="#learning-the-3di-alphabet-via-a-vq-vae" id="markdown-toc-learning-the-3di-alphabet-via-a-vq-vae">Learning the 3Di alphabet via a VQ-VAE</a></li>
      <li><a href="#speeding-things-up-by-building-on-mmseqs2" id="markdown-toc-speeding-things-up-by-building-on-mmseqs2">Speeding things up by building on mmseqs2</a></li>
    </ol>
  </li>
  <li><a href="#applications-clustering-the-protein-universe" id="markdown-toc-applications-clustering-the-protein-universe">Applications: clustering the protein universe</a></li>
</ol>

<p>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</p>

<div class="language-bibtex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nc">@misc</span><span class="p">{</span><span class="nl">didi2024proteinstructureuniverse</span><span class="p">,</span>
  <span class="na">author</span> <span class="p">=</span> <span class="s">{Didi, Kieran}</span><span class="p">,</span>
  <span class="na">title</span> <span class="p">=</span> <span class="s">{Dealing with the flood of protein structures}</span><span class="p">,</span>
  <span class="na">url</span> <span class="p">=</span> <span class="s">{https://kdidi.netlify.app/blog/proteins/2024-04-07-protein-structure-universe/}</span><span class="p">,</span>
  <span class="na">year</span> <span class="p">=</span> <span class="s">{2024}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="how-protein-structure-prediction-changed-the-game">How protein structure prediction changed the game</h2>

<p>The PDB as a database of experimental protein structures keeps growing, currently standing at <a href="https://www.rcsb.org/">nearly 218k</a> entries. However, it seems small compared to the <a href="https://academic.oup.com/nar/article/50/D1/D439/6430488">AlphaFoldDB (&gt;200m)</a> and <a href="https://esmatlas.com/">ESMAtlas (772m structures)</a>, powered by the recent advances in protein structure prediction via methods like <a href="https://www.nature.com/articles/s41586-021-03819-2">AlphaFold2</a> and <a href="https://www.science.org/doi/10.1126/science.ade2574">ESMFold</a>.</p>

<p>This development changed the game in protein biology. While until recently the <a href="https://moalquraishi.wordpress.com/2019/04/01/the-future-of-protein-science-will-not-be-supervised/">gap between available protein sequences and structures widened further and further</a>, we suddenly have a wealth of structural information that was unimaginable a decade ago. This quote from Mohammed AlQuraishi (Columbia University) sums up this paradigm shift well:</p>

<blockquote class="lead">
  <p>Everything we did with protein sequences we can now do with protein structures</p>
</blockquote>

<p>While that is theoretically true and a very exciting prospect, there is one big problem: we do not have tools to deal with such amounts of structural data. Here is a visual comparison between the size of the PDB and the AFDB:</p>

<p align="center">
  <img src="/assets/img/blog/prot_representation/afdb_size.png" width="50%" height="50%" />
</p>

<p class="figcaption">Visual comparison of the size of the PDB vs the AFDB. Source: <a href="https://www.youtube.com/watch?v=IJtWTxhuunk">YouTube</a></p>

<p>You can see that we are dealing with a different order of magnitude of data here. This brings up a plethora of issues, ranging from pure storage requirements (the AFDB takes up 23 TB) to questions of how we move and process these enormous amounts of data.</p>

<p>Many groups have developed tools in recent years to tackle this issue. The <a href="https://steineggerlab.com/en/">Steinegger lab</a> in particular has produced some fantastic tools in that space, three of which I want to present in this blogpost: Foldcomp for structure compression, Foldseek for structure clustering and MMseqs2 for sequence clustering (also very important in that context for generating both input MSAs and training splits).</p>

<p><img src="/assets/img/blog/prot_representation/steinegger_tools.png" alt="steinegger_tools" /></p>

<p class="figcaption">Tools from the Steinegger Lab. Source: <a href="https://www.youtube.com/watch?v=IJtWTxhuunk">YouTube</a></p>

<h2 id="foldcomp-compressing-protein-structures-to-managable-sizes">FoldComp: compressing protein structures to manageable sizes</h2>
<ul>
  <li><a href="https://academic.oup.com/bioinformatics/article/39/4/btad153/7085592">Paper</a></li>
  <li><a href="https://www.youtube.com/watch?v=aFtqH0VqE7w">Talk</a></li>
</ul>

<h3 id="the-trouble-with-compression">The trouble with compression</h3>

<p>A perfect compression format satisfies all three of these conditions:</p>
<ol>
  <li>The compressed files are small.</li>
  <li>The compression and decompression algorithms are fast.</li>
  <li>The reconstruction is either lossless or (if lossy) has minimal reconstruction error.</li>
</ol>

<p>Fulfilling all of these at the same time is hard, so one always has to think about how to balance between them.</p>

<p>As described in the first section of this post, there have been efforts to develop compressed protein structure formats such as MMTF or binaryCIF. However, given the sheer number of predicted protein structures, the authors decided that more efficient algorithms are needed.</p>

<p>People have tried this in the past by taking inspiration from image compression algorithms such as <a href="https://www.youtube.com/watch?v=EFUYNoFRHQI">PNG</a> and <a href="https://www.youtube.com/watch?v=Kv1Hiv3ox8I">JPEG</a>, as in the example of the <a href="https://link.springer.com/article/10.1186/s12859-023-05570-z#Sec7">PIC algorithm</a>. Lossless formats are great since they reconstruct your data perfectly, but by insisting on perfect reconstruction they often leave some performance in terms of both speed and size on the table.</p>

<p>Therefore, looking into lossy compression formats often pays off if you are fine with paying a small penalty in terms of reconstruction error. Since our measurements of protein structures contain measurement errors anyway, we can often pay this penalty and still get great results for biological problems such as <a href="https://onlinelibrary.wiley.com/doi/full/10.1002/pro.4511">energy calculations from MD trajectories</a>.</p>

<h3 id="the-foldcomp-compression-scheme">The FoldComp compression scheme</h3>

<p>In this spirit, Kim et al. from the Steinegger lab decided to build a lossy compression format that converts the nearly 100 bytes of 3D coordinates per residue into only 13 bytes of compressed internal coordinates (in this case torsion angles).</p>

<p><img src="/assets/img/blog/prot_representation/foldcomp.png" alt="FoldComp" /></p>

<p class="figcaption">FoldComp compression scheme. Source: <a href="https://www.youtube.com/watch?v=IJtWTxhuunk">YouTube</a></p>

<p>As you can see in that graphic, they do not only save the backbone and side-chain torsion angles, but also bond angles. This should in theory not be necessary since one should be able to reconstruct the full-atom structure by just using torsion angles. However, this theory assumes an idealised protein backbone geometry with constant bond angles, which is a bit too simplistic in practice to get very low reconstruction error. Encoding these bond angles improves the reconstruction a lot.</p>

<p>To keep the space occupied by both torsion and bond angles from becoming too demanding, they employ a quantisation step in which both of these quantities are saved as discretised, pre-defined values. This procedure is also commonly known as binning and has been used extensively in machine learning for <a href="https://huggingface.co/docs/optimum/concept_guides/quantization">weights and activations</a> as well as for <a href="https://arxiv.org/abs/2110.02861">optimiser states</a>, up to the extreme of recent <a href="https://arxiv.org/abs/2402.17764">1-bit LLMs</a>.</p>
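
<p>To make the binning idea concrete, here is a minimal sketch in Python. It assumes a uniform grid over the full 360 degree range and a free parameter <code>n_bits</code>; both are illustrative assumptions of mine, and the actual bit allocation and value ranges used by FoldComp may well differ. The function names are hypothetical.</p>

```python
def quantise_angle(angle_deg, n_bits):
    """Map an angle in degrees to an integer bin index on a uniform grid."""
    n_bins = 2 ** n_bits
    bin_width = 360.0 / n_bins
    wrapped = (angle_deg + 180.0) % 360.0  # shift the angle into [0, 360)
    # The final modulo guards against float rounding at the upper edge.
    return int(wrapped // bin_width) % n_bins


def dequantise_angle(index, n_bits):
    """Return the centre of the bin, in degrees within [-180, 180)."""
    n_bins = 2 ** n_bits
    bin_width = 360.0 / n_bins
    return (index + 0.5) * bin_width - 180.0


def angular_error(a_deg, b_deg):
    """Smallest absolute difference between two angles, in degrees."""
    diff = (a_deg - b_deg + 180.0) % 360.0 - 180.0
    return abs(diff)


# Round-trip a torsion angle through an 8-bit representation (assumed width).
phi = -63.2
stored = quantise_angle(phi, 8)
restored = dequantise_angle(stored, 8)
print(stored, restored, angular_error(phi, restored))
```

<p>With this kind of scheme the worst-case round-trip error is half a bin width, so every additional bit halves the maximum angular error at the cost of a larger file.</p>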

<h3 id="nerf-and-the-lever-arm-effect">NeRF and the lever-arm effect</h3>

<p>Saving the actual bond angles helped to lower the reconstruction error for the first few residues that were reconstructed. However, the longer the polymer chain gets, the bigger the reconstruction error becomes down the line. This problem is related to a phenomenon known as the <a href="https://mphy0026.readthedocs.io/en/latest/tracking/errors.html">lever-arm effect</a> in engineering. It describes how an error in a rotation measurement propagates through a series of successive measurements, with the error magnitude increasing with the distance between the original measurement and the reconstruction.</p>
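
<p>The growth of the error with chain length can be illustrated with a toy example. The sketch below is not FoldComp's reconstruction code: it assumes a simplified 2D chain with a fixed, made-up bond length of 1.5 and introduces a single angular error of 0.5 degrees at the first joint. All names and numbers are hypothetical and chosen only for illustration.</p>

```python
import math


def chain_positions(n_bonds, bond_length, joint_angles_deg):
    """Build a 2D chain; each joint angle is a turn relative to the previous bond."""
    x, y, heading = 0.0, 0.0, 0.0
    positions = [(x, y)]
    for i in range(n_bonds):
        heading += math.radians(joint_angles_deg[i])
        x += bond_length * math.cos(heading)
        y += bond_length * math.sin(heading)
        positions.append((x, y))
    return positions


def deviations(reference, perturbed):
    """Distance between corresponding atoms of two chains."""
    return [math.dist(p, q) for p, q in zip(reference, perturbed)]


n = 100
exact = [0.0] * n                # straight reference chain
noisy = [0.5] + [0.0] * (n - 1)  # one 0.5 degree error at the first joint

dev = deviations(chain_positions(n, 1.5, exact), chain_positions(n, 1.5, noisy))
for i in (1, 10, 100):
    print(f"atom {i}: deviation {dev[i]:.4f}")
```

<p>In this toy setting the deviation of atom <code>i</code> is proportional to its distance from the faulty joint, so the same small angular error that is negligible for the first few atoms becomes substantial at the far end of the chain.</p>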

<p>To understand this in the context of proteins, let’s look at the method FoldComp and others in the field use to convert the stored torsion angles back into 3D coordinates: the <a href="https://people.tamu.edu/~rojas/chemtorsion.pdf">NeRF (Natural Extension Reference Frame) method</a> (unrelated to the NeRF in machine learning which stands for <a href="https://datagen.tech/guides/synthetic-data/neural-radiance-field-nerf/#">Neural Radiance Fields</a>).</p>

<p>There have been multiple versions of NeRF such as <a href="https://onlinelibrary.wiley.com/doi/abs/10.1002/jcc.25772">pNeRF</a> and <a href="https://pubmed.ncbi.nlm.nih.gov/34709663/">MP-NeRF</a> that make it more efficient via parallelisation, but the <a href="https://sbl.inria.fr/doc/Molecular_coordinates-user-manual.html#fig-nerf-embedding">basic algorithmic ideas</a> stay the same:</p>

<ol>
  <li>We can place our first backbone atom wherever we want and define this as our origin: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>1</mn></msub><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mn>0</mn><mo separator="true">,</mo><mn>0</mn><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">A_1(0,0,0)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">0</span><span class="mclose">)</span></span></span></span></li>
  <li>Given a first backbone atom (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>1</mn></msub></mrow><annotation encoding="application/x-tex">A_1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>), we can place the second one arbitrarily in space and just constrain its position by the known bond distance <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>d</mi><mn>1</mn></msub></mrow><annotation encoding="application/x-tex">d_1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>2</mn></msub><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mn>0</mn><mo separator="true">,</mo><msub><mi>d</mi><mn>1</mn></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">A_2(0,0,d_1)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mclose">)</span></span></span></span></li>
  <li>Given the first two backbone atoms (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>1</mn></msub><mo separator="true">,</mo><msub><mi>A</mi><mn>2</mn></msub></mrow><annotation encoding="application/x-tex">A_1, A_2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8778em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>), we can place the third one in space by using the literature bond distance <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>d</mi><mn>2</mn></msub></mrow><annotation encoding="application/x-tex">d_2</annotation></semantics></math></span><span 
class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and angle <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>θ</mi><mn>1</mn></msub></mrow><annotation encoding="application/x-tex">\theta_1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>3</mn></msub><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>θ</mi><mn>1</mn></msub><mo 
stretchy="false">)</mo><mo>∗</mo><msub><mi>d</mi><mn>2</mn></msub><mo separator="true">,</mo><msub><mi>d</mi><mn>1</mn></msub><mo>−</mo><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>θ</mi><mn>1</mn></msub><mo stretchy="false">)</mo><mo>∗</mo><msub><mi>d</mi><mn>2</mn></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">A_3(0, \sin(\theta_1) * d_2, d_1 - \cos(\theta_1) * d_2)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span 
class="mbin">∗</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" 
style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></li>
  <li>Given the first three backbone atoms (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>1</mn></msub><mo separator="true">,</mo><msub><mi>A</mi><mn>2</mn></msub><mo separator="true">,</mo><msub><mi>A</mi><mn>3</mn></msub></mrow><annotation encoding="application/x-tex">A_1, A_2, A_3</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8778em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>), we can place the fourth one in space by using the literature bond distance (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>d</mi><mn>3</mn></msub></mrow><annotation encoding="application/x-tex">d_3</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>), the literature (or saved in the case of FoldComp) angle <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>θ</mi><mn>2</mn></msub></mrow><annotation encoding="application/x-tex">\theta_2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and the stored torsion angle <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>τ</mi><mn>1</mn></msub></mrow><annotation encoding="application/x-tex">\tau_1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.1132em;">τ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.1132em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>. 
To do this, we first define a new coordinate system called the <em>specialised reference frame</em> centered at <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>3</mn></msub></mrow><annotation encoding="application/x-tex">A_3</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> using spherical coordinates and place <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>A</mi><mn>4</mn><mo>∗</mo></msubsup></mrow><annotation encoding="application/x-tex">A_4^*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9368em;vertical-align:-0.2481em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.4519em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 
mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span></span></span></span> there:</li>
</ol>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left" columnspacing="0em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><msubsup><mi>A</mi><mn>4</mn><mo>∗</mo></msubsup></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><mo stretchy="false">(</mo><msub><mi>d</mi><mn>3</mn></msub><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>θ</mi><mn>2</mn></msub><mo stretchy="false">)</mo><mo separator="true">,</mo><msub><mi>d</mi><mn>3</mn></msub><mi>cos</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>τ</mi><mn>1</mn></msub><mo stretchy="false">)</mo><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>θ</mi><mn>2</mn></msub><mo stretchy="false">)</mo><mo separator="true">,</mo><msub><mi>d</mi><mn>3</mn></msub><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>τ</mi><mn>1</mn></msub><mo stretchy="false">)</mo><mi>sin</mi><mo>⁡</mo><mo stretchy="false">(</mo><msub><mi>θ</mi><mn>2</mn></msub><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{aligned} %!!15
A_4^* &amp;= (d_3 \cos(\theta_2), d_3 \cos(\tau_1) \sin(\theta_2), d_3 \sin(\tau_1) \sin(\theta_2))
\end{aligned}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.5em;vertical-align:-0.5em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3.16em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7387em;"><span style="top:-2.453em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3.16em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">cos</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.1132em;">τ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span 
style="top:-2.55em;margin-left:-0.1132em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.1132em;">τ</span><span class="msupsub"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.1132em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">sin</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">))</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5em;"><span></span></span></span></span></span></span></span></span></span></span></span>

<p class="figcaption">Calculation of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>A</mi><mn>4</mn><mo>∗</mo></msubsup></mrow><annotation encoding="application/x-tex">A_4^*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9368em;vertical-align:-0.2481em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.4519em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span></span></span></span> in the specialised reference frame.</p>
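
<p>The four steps and the local-frame formula above can be sketched in a few lines of plain Python. The bond lengths and angles below are approximate backbone values picked for illustration, the torsion angle is made up, and the function names are my own. The sketch stops at the position in the specialised reference frame.</p>

<pre><code class="language-python">import math


def place_first_three(d1, d2, theta1):
    """Steps 1 to 3: A1 at the origin, A2 on the z-axis, A3 in the yz-plane."""
    a1 = (0.0, 0.0, 0.0)
    a2 = (0.0, 0.0, d1)
    a3 = (0.0, math.sin(theta1) * d2, d1 - math.cos(theta1) * d2)
    return a1, a2, a3


def place_in_local_frame(d3, theta2, tau1):
    """Step 4: A4* in the specialised reference frame centred at A3."""
    return (
        d3 * math.cos(theta2),
        d3 * math.cos(tau1) * math.sin(theta2),
        d3 * math.sin(tau1) * math.sin(theta2),
    )


# Approximate values for the backbone atoms N, CA, C, N (angstrom, degrees).
d1, d2, d3 = 1.46, 1.52, 1.33
theta1, theta2 = math.radians(111.0), math.radians(116.0)
tau1 = math.radians(-45.0)  # hypothetical torsion angle

a1, a2, a3 = place_first_three(d1, d2, theta1)
a4_star = place_in_local_frame(d3, theta2, tau1)

print("A3  =", a3)
print("A4* =", a4_star)
print("|A2A3| =", math.dist(a2, a3), " |A4*| =", math.hypot(*a4_star))
</code></pre>

<p>The last line is a sanity check: the distance between the second and third atom should equal the second bond length, and the length of the local-frame vector should equal the third bond length, no matter which angles are plugged in.</p>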

<p>We then rototranslate <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>A</mi><mn>4</mn><mo>∗</mo></msubsup></mrow><annotation encoding="application/x-tex">A_4^*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9368em;vertical-align:-0.2481em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.4519em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span></span></span></span> back from that specialised reference frame back to our original coordinate system via <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>4</mn></msub><mo>=</mo><mi>R</mi><msubsup><mi>A</mi><mn>4</mn><mo>∗</mo></msubsup><mo>+</mo><msub><mi>A</mi><mn>3</mn></msub></mrow><annotation encoding="application/x-tex">A_4 = RA_4^* + A_3</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.9368em;vertical-align:-0.2481em;"></span><span class="mord mathnormal" style="margin-right:0.00773em;">R</span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.4519em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 
mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and with</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left" columnspacing="0em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mi>R</mi></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><mo stretchy="false">[</mo><msub><mover accent="true"><mi>A</mi><mo>^</mo></mover><mrow><mn>2</mn><mo>−</mo><mn>3</mn></mrow></msub><mo separator="true">,</mo><mover accent="true"><mi>n</mi><mo>^</mo></mover><mo>×</mo><msub><mover accent="true"><mi>A</mi><mo>^</mo></mover><mrow><mn>2</mn><mo>−</mo><mn>3</mn></mrow></msub><mo separator="true">,</mo><mover accent="true"><mi>n</mi><mo>^</mo></mover><mo stretchy="false">]</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><msub><mover accent="true"><mi>A</mi><mo>^</mo></mover><mrow><mn>2</mn><mo>−</mo><mn>3</mn></mrow></msub></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><mfrac><mrow><msub><mi>A</mi><mn>2</mn></msub><msub><mi>A</mi><mn>3</mn></msub></mrow><mrow><mo>∣</mo><msub><mi>A</mi><mn>2</mn></msub><msub><mi>A</mi><mn>3</mn></msub><mo>∣</mo></mrow></mfrac></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mover accent="true"><mi>n</mi><mo>^</mo></mover></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><mfrac><mrow><msub><mi>A</mi><mn>1</mn></msub><msub><mi>A</mi><mn>2</mn></msub><mo>×</mo><msub><mover accent="true"><mi>A</mi><mo>^</mo></mover><mrow><mn>2</mn><mo>−</mo><mn>3</mn></mrow></msub></mrow><mrow><mo>∣</mo><msub><mi>A</mi><mn>1</mn></msub><msub><mi>A</mi><mn>2</mn></msub><mo>×</mo><msub><mover accent="true"><mi>A</mi><mo>^</mo></mover><mrow><mn>2</mn><mo>−</mo><mn>3</mn></mrow></msub><mo>∣</mo></mrow></mfrac></mrow></mstyle></mtd></mtr></mtable><annotation 
encoding="application/x-tex">\begin{aligned} %!!15 
R &amp;= [\hat{A}_{2-3}, \hat{n} \times \hat{A}_{2-3}, \hat{n}] \\
\hat{A}_{2-3} &amp;= \frac{A_2 A_3}{\mid A_2 A_3 \mid}\\
\hat{n} &amp;= \frac{A_1 A_2 \times \hat{A}_{2-3}}{\mid A_1 A_2 \times \hat{A}_{2-3} \mid }
\end{aligned}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:7.2136em;vertical-align:-3.3568em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.8568em;"><span style="top:-6.5338em;"><span class="pstrut" style="height:3.6238em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.00773em;">R</span></span></span><span style="top:-4.5135em;"><span class="pstrut" style="height:3.6238em;"></span><span class="mord"><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.1111em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mbin mtight">−</span><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span><span style="top:-1.6537em;"><span class="pstrut" style="height:3.6238em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6944em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord 
mathnormal">n</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.3568em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.8568em;"><span style="top:-6.5338em;"><span class="pstrut" style="height:3.6238em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mopen">[</span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.1111em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mbin mtight">−</span><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord accent"><span class="vlist-t"><span 
class="vlist-r"><span class="vlist" style="height:0.6944em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">n</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.1111em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mbin mtight">−</span><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6944em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">n</span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span 
class="mord">^</span></span></span></span></span></span></span><span class="mclose">]</span></span></span><span style="top:-4.5135em;"><span class="pstrut" style="height:3.6238em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3603em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mrel">∣</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∣</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span 
class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.936em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span><span style="top:-1.6537em;"><span class="pstrut" style="height:3.6238em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6238em;"><span style="top:-2.1632em;"><span class="pstrut" 
style="height:3em;"></span><span class="mord"><span class="mrel">∣</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.1111em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mbin mtight">−</span><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∣</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span 
class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.1111em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mbin mtight">−</span><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.0868em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.3568em;"><span></span></span></span></span></span></span></span></span></span></span></span>

<p class="figcaption">Rototranslation of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>A</mi><mn>4</mn><mo>∗</mo></msubsup></mrow><annotation encoding="application/x-tex">A_4^*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9368em;vertical-align:-0.2481em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.4519em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">4</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span></span></span></span> back to the original coordinate system to form <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>A</mi><mn>4</mn></msub></mrow><annotation encoding="application/x-tex">A_4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord 
mtight">4</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>.</p>

<ol>
  <li>We then repeat step 4 for every subsequent atom until we reach the end of the polymer chain.</li>
</ol>
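
<p>To make the placement step concrete, here is a minimal sketch in plain Python. It assumes the sign convention of the original NeRF paper for the local coordinates (the first component is the negative cosine of the bond angle), and the helper names <code>place_atom</code> and <code>build_chain</code> are made up for this illustration; treat it as a sketch of the idea rather than a reference implementation.</p>

```python
import math

def sub(a, b):
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])

def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def unit(a):
    length = math.sqrt(sum(x * x for x in a))
    return (a[0] / length, a[1] / length, a[2] / length)

def place_atom(a1, a2, a3, bond_length, bond_angle, torsion):
    """Place A4 from the three previous atoms and its internal coordinates (radians)."""
    # A4* in the specialised reference frame (assumed sign convention, see text).
    local = (-bond_length * math.cos(bond_angle),
             bond_length * math.sin(bond_angle) * math.cos(torsion),
             bond_length * math.sin(bond_angle) * math.sin(torsion))
    # Columns of R: unit bond vector A2->A3, in-plane perpendicular, plane normal.
    bc = unit(sub(a3, a2))
    n = unit(cross(sub(a2, a1), bc))
    m = cross(n, bc)
    # A4 = R A4* + A3
    return tuple(a3[i] + local[0] * bc[i] + local[1] * m[i] + local[2] * n[i]
                 for i in range(3))

def build_chain(start, internal_coords):
    """Repeat the placement step for each (bond_length, bond_angle, torsion) triple."""
    atoms = list(start)
    for bond_length, bond_angle, torsion in internal_coords:
        atoms.append(place_atom(atoms[-3], atoms[-2], atoms[-1],
                                bond_length, bond_angle, torsion))
    return atoms

# Three arbitrary (non-collinear) starting atoms, then five identical steps.
start = [(0.0, 1.0, 0.0), (0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
chain = build_chain(start, [(1.5, math.radians(110), math.radians(-60))] * 5)
```

<p>Every new atom only depends on the three atoms placed before it, which is exactly why the procedure is sequential.</p>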

<p>Reconstruction of the backbone works in a similar way, just using different values for bond distances, bond angles and torsion angles.</p>

<p align="center">
  <img src="/assets/img/blog/prot_representation/nerf.png" width="50%" height="50%" />
</p>

<p class="figcaption">NeRF algorithm. Source: <a href="https://sbl.inria.fr/doc/Molecular_coordinates-user-manual.html#fig-nerf-embedding">Structural Bioinformatics Library</a></p>

<p class="note" title="Fun Fact">As a fun fact, and another anecdote about how small the world of science is: Charlie Strauss, the senior author of the original NeRF paper from 2005, is from Seattle and <a href="https://www-k12.atmos.washington.edu/k12/mars/tillmans_reports/strauss.html">did a summer job as a high schooler with Prof Tillman at the University of Washington</a> working on Mars meteorology. That gave him both the inspiration and the grit to go into science, ending up at the Los Alamos National Laboratory where he supervised the NeRF paper that was published in 2005. As another unexpected twist of events, in the 90s he took a year-long sabbatical at UW working in the lab of David Baker and <a href="https://www-k12.atmos.washington.edu/k12/mars/tillmans_reports/Baker_Laboratory_files/newindex.html">improving their Rosetta algorithm for protein structure prediction</a>. Yes, you heard right, the David Baker whose lab became synonymous with protein design and is still a pioneer in that field. What a funny world we live in.</p>

<p>Now with this NeRF algorithm at hand, we can go about reconstructing our 3D Cartesian coordinates from our internal coordinates represented as torsion and bond angles. There is only one problem: the previously mentioned lever-arm effect.</p>

<p>We will get a pretty accurate reconstruction for the first few residues, but small errors will accumulate since every reconstruction step is only <em>relative</em> to the previous ones. You can imagine this with an analogy: let’s say you want to follow a route on Google Maps. The routing instructions are successive relative statements (“turn left in 100 meters”, “go straight for 1km”, …), similar to how the torsion and bond angles during NeRF reconstruction are relative reconstruction steps. What you want in the end, however, is the full path to your correct destination; you are therefore “reconstructing” the correct path from these relative instructions.</p>

<p>Now if you follow the instructions very carefully at the start and only turn left instead of right close to the destination, you will still end up very close to your actual destination; your reconstruction error is low. However, if you start off your journey by taking the wrong exit in a roundabout and just keep following the instructions (ignoring that Google Maps will try to course-correct you), you will end up god knows where! The error you made at the beginning propagates down to all of your successive steps and will accumulate, leading to a massive reconstruction error at the end.</p>

<p>The same is happening for the NeRF algorithm: a small reconstruction error at the start will lead to a large reconstruction error later along the peptide backbone, leaving you with a poorly reconstructed protein.</p>
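
<p>A small numerical experiment illustrates this. The sketch below rebuilds a toy chain of 100 atoms twice, once with a 2 degree error in the very first torsion angle and once with the same error in the very last one, and compares both to the unperturbed chain. Bond length, bond angle, the torsion values and all function names are arbitrary choices for this illustration, and the RMSD is computed without superposition because all chains share the same three starting atoms.</p>

```python
import math

def _unit(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def _cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]

def place_atom(a1, a2, a3, r, theta, phi):
    """One NeRF step: A4 = R A4* + A3 (assumed sign convention for A4*)."""
    bc = _unit([a3[i] - a2[i] for i in range(3)])
    n = _unit(_cross([a2[i] - a1[i] for i in range(3)], bc))
    m = _cross(n, bc)
    local = (-r * math.cos(theta),
             r * math.sin(theta) * math.cos(phi),
             r * math.sin(theta) * math.sin(phi))
    return [a3[i] + local[0] * bc[i] + local[1] * m[i] + local[2] * n[i]
            for i in range(3)]

def build_chain(torsions, r=1.5, theta=math.radians(110)):
    """Build a chain from fixed starting atoms, one atom per torsion angle."""
    atoms = [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
    for phi in torsions:
        atoms.append(place_atom(atoms[-3], atoms[-2], atoms[-1], r, theta, phi))
    return atoms

def rmsd(xs, ys):
    """Plain coordinate RMSD, no superposition."""
    sq = sum((p[i] - q[i]) ** 2 for p, q in zip(xs, ys) for i in range(3))
    return math.sqrt(sq / len(xs))

def perturbed(torsions, index, delta):
    """Copy of the torsion list with a single angle shifted by delta."""
    out = list(torsions)
    out[index] += delta
    return out

torsions = [math.radians(-60 + 35 * (i % 7)) for i in range(100)]
reference = build_chain(torsions)
delta = math.radians(2.0)

rmsd_early = rmsd(reference, build_chain(perturbed(torsions, 0, delta)))
rmsd_late = rmsd(reference, build_chain(perturbed(torsions, 99, delta)))
print(f"error in first torsion: RMSD = {rmsd_early:.3f}")
print(f"error in last torsion:  RMSD = {rmsd_late:.3f}")
```

<p>The early error rotates every downstream atom, while the late error only moves the final atom, so the first RMSD is necessarily the larger of the two.</p>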

<p>This phenomenon is not new by any means, not even in the protein community: while earlier protein structure prediction methods like <a href="https://www.cell.com/cell-systems/pdf/S2405-4712(19)30076-6.pdf">RGN</a> used recurrent networks based on torsion angle prediction for reconstructing protein backbones, later methods like AlphaFold2 instead leveraged transformer-based architectures that utilise parallel reconstruction directly in Cartesian space instead of sequential reconstruction via internal coordinates. Similar observations were made in protein structure generation: <a href="https://www.nature.com/articles/s41467-024-45051-2">FoldingDiff</a>, a diffusion model by Microsoft Research, leveraged a torsion-angle based representation to generate protein backbones, and while that worked well for relatively short proteins, they note on <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41467-024-45051-2/MediaObjects/41467_2024_45051_MOESM1_ESM.pdf">page 4 of the SI</a> that for larger proteins lever-arm effects play a role (although the model seems to be relatively robust in some cases).</p>

<h3 id="the-lever-arm-solution-bidirectional-nerf-and-anchoring">The lever-arm solution: bidirectional NeRF and anchoring</h3>

<p>While some machine learning algorithms like <a href="https://pubmed.ncbi.nlm.nih.gov/36749957/">Int2Cart</a> were developed to ameliorate the lever-arm problem, the FoldComp authors decided to stick with good old NeRF and instead give it a boost via two approaches:</p>
<ol>
  <li><strong>Bidirectionality</strong>: They start NeRF from both the N- and the C-terminus of the polypeptide chain and use a weighted average of both reconstructions at each position to get a better consensus position. This requires us to save the position of the first and last residue in Cartesian coordinates, since now we cannot place them arbitrarily in space, but need them to be at the correct distance and orientation to each other. This helps a lot with lowering the reconstruction error at the start and the end of the protein backbone, but leaves the center still relatively vulnerable to lever-arm effects.</li>
  <li><strong>Anchoring</strong>: If we already save the first and the last amino acid, why stop there? Of course we do not want to save the 3D coordinates of <em>every</em> residue; if we did that, we would not need a NeRF reconstruction to begin with. But the authors found that even doing that for every 25th amino acid in the backbone improved results dramatically, landing in a sweet spot where memory requirements are still reduced a lot while the reconstruction error stays way below experimental resolution accuracy (around 0.1 Angstrom for the backbone and around 0.15 Angstrom for all-atom RMSD).</li>
</ol>
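
<p>The two ideas can be sketched in a few lines of Python. Note that the linear weighting and the function names <code>blend_bidirectional</code> and <code>anchor_indices</code> are assumptions made for this illustration; the text above only says that a weighted average is used, so the actual weights in FoldComp may well differ.</p>

```python
def blend_bidirectional(forward, backward):
    """Position-weighted average of two reconstructions of the same segment.

    forward:  coordinates rebuilt starting from the N-terminal anchor
    backward: coordinates rebuilt starting from the C-terminal anchor
    """
    n = len(forward)
    blended = []
    for i, (f, b) in enumerate(zip(forward, backward)):
        # Assumed scheme: trust the forward pass near its own anchor and
        # shift linearly towards the backward pass along the segment.
        w = i / (n - 1) if n > 1 else 0.0
        blended.append(tuple((1.0 - w) * fc + w * bc for fc, bc in zip(f, b)))
    return blended

def anchor_indices(n_residues, every=25):
    """Indices of residues whose Cartesian coordinates are stored explicitly."""
    if n_residues == 0:
        return []
    anchors = list(range(0, n_residues, every))
    if anchors[-1] != n_residues - 1:
        anchors.append(n_residues - 1)  # the last residue is always an anchor
    return anchors

# Toy example: two slightly different reconstructions of a 3-atom segment.
forward = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
backward = [(0.0, 2.0, 0.0), (1.0, 2.0, 0.0), (2.0, 2.0, 0.0)]
consensus = blend_bidirectional(forward, backward)
anchors = anchor_indices(60)
```

<p>With anchors in place, the bidirectional reconstruction runs independently between each pair of neighbouring anchors, so an error can never travel further than one segment.</p>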

<p>With these two improvements, they managed to strike a good balance: decompression is as fast as gzip, compression is a lot faster than other tools (10% of gzip), and storage requirements are reduced by a lot (2.9 GB vs the original 23 TB for the AFDB), all while maintaining very low reconstruction errors. This makes FoldComp a very useful tool for large-scale structural bioinformatics.</p>

<h2 id="mmseqs2-sequence-alignment-in-speed-mode">MMseqs2: sequence alignment in speed-mode</h2>

<ul>
  <li><a href="https://www.nature.com/articles/nbt.3988">Paper</a></li>
  <li><a href="https://www.youtube.com/watch?v=lMq89wEPuaU">Talk</a></li>
</ul>

<p><em>Wait</em>, you might say, <em>you promised tools for large-scale protein structure analysis; why are we discussing a sequence alignment method</em>?</p>

<p>Bear with me, for I have my reasons:</p>
<ol>
  <li>Sequence alignment and clustering are among the most-studied topics in bioinformatics and underpin many of the technologies and scientific discoveries made in the last decades, so they are generally something to be aware of.</li>
  <li>Even as part of machine learning approaches for protein structure, sequence alignment and clustering are often used to create meaningful splits for training and test datasets (for more info I gave <a href="https://structural-bioinformatics.netlify.app/blog/proteins/2023-08-02-lesson4/">this lecture</a> about that topic).</li>
  <li>We will see later that structure alignment tools like <a href="https://www.nature.com/articles/s41587-023-01773-0">FoldSeek</a> reuse many of the components and ideas from MMseqs2, so it is useful to have it in the back of your mind.</li>
</ol>

<h3 id="why-do-we-need-fast-sequence-alignment">Why do we need fast sequence alignment?</h3>

<p>With that out of the way, what is MMseqs2 and which problem does it solve?</p>

<p>MMseqs2 (Many-against-Many sequence searching) is a tool that allows you to align and search protein sequences in a high-throughput manner while still retaining sensitivity. One application of this is metagenomics, where we get billions of possible ORFs (Open Reading Frames) from cheap DNA sequencing, but then need to search for potential hits in massive online databases like UniProt or KEGG to confirm that these potential ORFs are actually real genes. The exponential growth of sequencing data leads to a rare situation here where the cost for the computational analysis by far exceeds the actual sequencing cost, making the sequence search part of the pipeline the real bottleneck.</p>

<p>Another application that might be a bit closer to home is MSA (multiple sequence alignment) generation. Algorithms like AlphaFold heavily rely on MSAs as input to <a href="https://structural-bioinformatics.netlify.app/blog/proteins/2023-08-03-lesson6/">extract coevolutionary information and predict the structure of the input sequence</a>. While the original AlphaFold2 used tools like <a href="https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-11-431">JackHMMER</a> and <a href="https://www.nature.com/articles/nmeth.1818">HHBlits</a> for MSA generation, these <a href="https://www.youtube.com/watch?v=vO_6xfLwGao">profile-HMM</a>-based tools are still relatively slow (although a lot faster than the original <a href="https://en.wikipedia.org/wiki/Viterbi_algorithm">Viterbi</a> or <a href="https://en.wikipedia.org/wiki/Forward_algorithm">Forward</a> algorithms that are classically used for scoring in hidden Markov models). By using MMseqs2 instead for this particular application, <a href="https://www.nature.com/articles/s41592-022-01488-1">ColabFold</a> achieved a 40-60-fold faster search and enabled everyone to predict protein structures via <a href="https://github.com/sokrypton/ColabFold">Google Colaboratory</a>.</p>

<h3 id="prefiltering-is-key">Prefiltering is key</h3>

<p>How does it get this massive speed-up? The gold standard for sequence alignment is still dynamic programming in the form of the Needleman-Wunsch algorithm for global sequence alignment or the <a href="https://www.youtube.com/watch?v=lu9ScxSejSE&amp;list=TLPQMjEwNzIwMjPG6vLp-w7KnQ&amp;index=3">Smith-Waterman algorithm</a> for local sequence alignment. These algorithms give the optimal alignment, but take <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>O</mi><mo stretchy="false">(</mo><mi>n</mi><mi>m</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">O(nm)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord mathnormal">nm</span><span class="mclose">)</span></span></span></span> time for aligning two sequences of length <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>n</mi></mrow><annotation encoding="application/x-tex">n</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">n</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>m</mi></mrow><annotation encoding="application/x-tex">m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">m</span></span></span></span>. Since this cost is paid for every query-target pair, they are impractical for large-scale searches.</p>
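<p>To make the quadratic cost concrete, here is a minimal sketch of the Smith-Waterman recurrence that only returns the best local alignment score. The scoring values (<code>match</code>, <code>mismatch</code>, <code>gap</code>) and the linear gap penalty are illustrative assumptions; real tools typically use a substitution matrix such as BLOSUM62 and affine gap penalties.</p>

```python
def smith_waterman_score(query: str, target: str,
                         match: int = 2, mismatch: int = -1, gap: int = -2) -> int:
    """Best local alignment score under a simple linear gap penalty."""
    n, m = len(query), len(target)
    # Only the previous row is kept in memory, but every one of the
    # n * m cells of the dynamic programming matrix is still visited.
    previous = [0] * (m + 1)
    best = 0
    for i in range(1, n + 1):
        current = [0] * (m + 1)
        for j in range(1, m + 1):
            substitution = match if query[i - 1] == target[j - 1] else mismatch
            current[j] = max(
                0,                               # start a new local alignment
                previous[j - 1] + substitution,  # align query[i-1] with target[j-1]
                previous[j] + gap,               # gap in the target
                current[j - 1] + gap,            # gap in the query
            )
            best = max(best, current[j])
        previous = current
    return best


print(smith_waterman_score("ACGT", "TTACGTTT"))
```

<p>The two nested loops are exactly where the quadratic runtime comes from, and they run once for every pair of sequences we want to compare.</p>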

<p>Many new tools still use these algorithms in the backend, but put a harsh prefilter before them so that the search space is reduced by multiple orders of magnitude while discarding as few true positives as possible, passing only the most promising candidates for alignment to the expensive dynamic programming algorithms. MMseqs2 is no different: its biggest selling point is the strong prefilter that is based on k-mers; to be more precise, it looks for 2 consecutive 7-mers on a diagonal, and we will now spend some time trying to understand that statement.</p>

<p>The MMSeqs2 prefilter is divided into 4 different stages that correspond to nested for-loops:</p>

<ol>
  <li>As a preprocessing step, we take all the target sequences we might align query sequences to and create a precomputed index table of 7-mers that will allow fast 7-mer lookup. Each k-mer acts as a key in this index table, and the corresponding value contains an index for the target sequence and an index for the position in that target sequence, uniquely identifying the position of that k-mer in the target sequence database.</li>
  <li>We now enter our first for-loop by processing the query sequences one by one and producing all possible 7-mers of each query in a sliding-window fashion.</li>
  <li>For each of these k-mers, we now produce a list of similar k-mers, where similarity is judged by some score threshold, either via a <a href="https://www.nature.com/articles/nbt0804-1035">BLOSUM score</a> or some <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2279992/">profile score</a> that judges how similar the generated k-mer is to the query k-mer.</li>
  <li>For each k-mer in that list, we now query the precomputed index table and see if we find a hit via our k-mer lookup. If we find a hit, we proceed to our fourth and last nested for-loop.</li>
  <li>We now compute the offset between the position of that k-mer in the query and its position in the target sequence and check whether the same offset was observed for this target the last time we checked. If that is the case, it means we have already found two k-mers that match between the two sequences at the same offset, which is a sign that these two sequences have a high chance of having a good alignment. This process is often visualised by plotting the query position on the x axis and the target position on the y axis and checking whether both k-mers occur on the same diagonal. If we find two such matches, MMseqs2 calls this a <em>double diagonal hit</em> and saves that target sequence for more detailed analysis later.</li>
</ol>

<p><img src="/assets/img/blog/prot_representation/mmseqs2_prefilter.png" alt="mmseqs2_prefilter" /></p>

<p class="figcaption">MMSeqs2 Prefilter algorithm (a lot going on in that figure, but hopefully the description helps). Source: <a href="https://www.nature.com/articles/nbt.3988">MMSeqs2 Paper</a></p>
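<p>The stages described above can be sketched in a few lines of Python. This is a toy illustration and not the MMseqs2 implementation: the function names are made up for this post, the example uses 3-mers instead of 7-mers to keep the sequences short, and <code>similar_kmers</code> is a placeholder that only returns the exact k-mer instead of generating a scored list of similar k-mers.</p>

```python
from collections import defaultdict


def build_index(targets: dict[str, str], k: int) -> dict[str, list[tuple[str, int]]]:
    """Preprocessing: map every k-mer to its (target id, position) occurrences."""
    index = defaultdict(list)
    for target_id, sequence in targets.items():
        for j in range(len(sequence) - k + 1):
            index[sequence[j:j + k]].append((target_id, j))
    return index


def similar_kmers(kmer: str) -> list[str]:
    """Placeholder (assumption): a real prefilter would also return
    k-mers that score above a threshold against the query k-mer."""
    return [kmer]


def double_diagonal_hits(query: str, index: dict, k: int) -> set[str]:
    """Return the targets that share two k-mer matches on the same diagonal."""
    last_diagonal = {}  # target id -> diagonal of the most recent k-mer match
    hits = set()
    for i in range(len(query) - k + 1):             # loop over query k-mers
        for kmer in similar_kmers(query[i:i + k]):  # loop over similar k-mers
            for target_id, j in index.get(kmer, []):  # loop over index hits
                diagonal = i - j
                if last_diagonal.get(target_id) == diagonal:
                    hits.add(target_id)             # double diagonal hit
                last_diagonal[target_id] = diagonal
    return hits


targets = {"t1": "MKTAYIAKQR", "t2": "GGGGGGGGGG"}
index = build_index(targets, k=3)
print(double_diagonal_hits("AAMKTAYAA", index, k=3))
```

<p>Note that the sketch only remembers the most recent diagonal per target, mirroring the "last time we checked" description above; it trades a few missed candidates for a very cheap inner loop.</p>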

<p>This prefilter already cuts down the number of hits by a lot. However, the result is still too expensive for a full Smith-Waterman alignment. Part of what makes this dynamic programming algorithm very expensive is the possibility of including gaps in the alignment. Therefore, as an additional filter, the sequences that gave double diagonal hits undergo an ungapped alignment that is relatively fast (although slower than the prefilter). If the best diagonal of that alignment has a score above a predefined threshold, we finally do a proper gapped alignment and get our final result out.</p>

<p><img src="/assets/img/blog/prot_representation/mmseqs2_pipeline.png" alt="mmseqs2_pipeline" /></p>

<p class="figcaption">MMSeqs2 progressively filters out hits and passes them to more and more expensive alignment stages. Source: <a href="https://www.youtube.com/watch?v=lMq89wEPuaU">YouTube</a></p>
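<p>To see why the ungapped stage is so much cheaper than the gapped one, here is a minimal sketch that scores a single diagonal. Without gaps, the alignment can never leave its diagonal, so one linear pass is enough. The function name and the match/mismatch scores are assumptions for illustration; a real implementation would score residue pairs with a substitution matrix.</p>

```python
def ungapped_diagonal_score(query: str, target: str, diagonal: int,
                            match: int = 2, mismatch: int = -1) -> int:
    """Best-scoring contiguous segment along one diagonal.

    The diagonal is defined as (query position - target position).
    """
    i = max(diagonal, 0)   # first query position on this diagonal
    j = max(-diagonal, 0)  # first target position on this diagonal
    best = running = 0
    for a, b in zip(query[i:], target[j:]):
        # Reset to zero whenever the running score would become negative.
        running = max(0, running + (match if a == b else mismatch))
        best = max(best, running)
    return best


print(ungapped_diagonal_score("AAMKTAYAA", "MKTAYIAKQR", diagonal=2))
```

<p>Only candidates whose best diagonal score clears the threshold are passed on to the full gapped alignment.</p>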

<p>In addition to that, the authors play all the tricks in the hardware book to be fast, from <a href="https://en.wikipedia.org/wiki/Advanced_Vector_Extensions">AVX2</a> vector instructions that allow 32 1-byte operations like add/mult/max to be computed in parallel per CPU clock cycle, to optimising CPU cache usage in the double diagonal hit matching stage and vectorizing both the ungapped and gapped alignment stages.</p>

<h3 id="use-the-prefilter-for-clustering">Use the prefilter for clustering</h3>

<p>The prefiltering algorithm is not only useful for alignments, but also for sequence clustering, a task that is useful for example for creating biologically relevant train-test splits in machine learning. To cluster a sequence set with MMseqs2, we run it either just through the prefiltering or optionally also through the alignment module and then use the output similarity graph as an input to a clustering algorithm of our choice.</p>

<p>If we choose the <code class="language-plaintext highlighter-rouge">easy-cluster</code> mode of MMSeqs2, it will just pass that similarity graph to a classic cascaded clustering algorithm. If we want to cluster large datasets, we can instead use the <code class="language-plaintext highlighter-rouge">easy-linclust</code> command that leverages the <a href="https://www.nature.com/articles/s41467-018-04964-5">Linclust</a> algorithm to cluster sequence sets in linear time, again using k-mer based analysis workflows.</p>
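<p>As an illustration of what "clustering a similarity graph" can look like, here is a toy greedy set-cover clustering in Python. It is a sketch of one possible clustering strategy, not the code that MMseqs2 runs: the function name, the input format (a dictionary mapping each sequence id to the set of ids it is similar to) and the alphabetical tie-breaking are all assumptions made for this example.</p>

```python
def greedy_set_cover(neighbours: dict[str, set[str]]) -> dict[str, set[str]]:
    """Repeatedly pick the sequence that covers the most unassigned
    sequences and make it the representative of a new cluster."""
    unassigned = set(neighbours)
    clusters = {}
    while unassigned:
        # sorted() makes ties deterministic: the alphabetically first id wins.
        representative = max(
            sorted(unassigned),
            key=lambda s: len(neighbours[s].union({s}).intersection(unassigned)),
        )
        members = neighbours[representative].union({representative})
        members = members.intersection(unassigned)
        clusters[representative] = members
        unassigned -= members
    return clusters


similarity_graph = {
    "a": {"b", "c"},
    "b": {"a"},
    "c": {"a"},
    "d": {"e"},
    "e": {"d"},
    "f": set(),
}
print(greedy_set_cover(similarity_graph))
```

<p>Every sequence ends up in exactly one cluster, and every cluster member is directly connected to its representative, which is the property we want when building train-test splits.</p>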

<p>Another cool property of MMseqs2 clustering is the possibility to <a href="https://mmseqs.com/latest/userguide.pdf">add new sequences to an existing clustering while maintaining stable cluster identifiers</a>, eliminating the need to recluster the entire sequence set.</p>

<h2 id="foldseek-structural-clustering-of-the-protein-universe">FoldSeek: structural clustering of the protein universe</h2>

<ul>
  <li><a href="https://www.nature.com/articles/s41587-023-01773-0">Paper</a></li>
  <li><a href="https://www.youtube.com/watch?v=IJtWTxhuunk">Talk</a></li>
</ul>

<p>Sequence alignment as described before is one of the main pillars in bioinformatics and useful for a variety of applications, from detecting homology to creating training splits for machine learning models.</p>

<p>However, when talking about protein structure, sequence alignments do not always tell the full story: in many cases, proteins may have very different sequences but very similar structures. This could be due to remote homology such as in the case of <a href="https://www.nature.com/articles/35056591">ubiquitin and its mysterious cousin SUMO</a>, which have been separated by more than 1 billion years of evolutionary history but are still strikingly similar in structure despite a sequence identity of only 16%.</p>

<p>This makes the idea of <em>structural alignment</em> and <em>structural clustering</em> very appealing: with this, you could detect these remote homologies while also preventing your machine learning models that deal with protein structures from cheating via such examples.</p>

<p>However, structure alignment is quite complex: as described before, we can find an optimal solution for aligning a sequence of length <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>n</mi></mrow><annotation encoding="application/x-tex">n</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">n</span></span></span></span> to a sequence of length <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>m</mi></mrow><annotation encoding="application/x-tex">m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">m</span></span></span></span> via dynamic programming in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>O</mi><mo stretchy="false">(</mo><mi>n</mi><mi>m</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">O(nm)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord mathnormal">nm</span><span class="mclose">)</span></span></span></span> time since we need <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>n</mi><mo>∗</mo><mi>m</mi></mrow><annotation encoding="application/x-tex">n*m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4653em;"></span><span class="mord mathnormal">n</span><span class="mspace" 
style="margin-right:0.2222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">m</span></span></span></span> operations to populate the whole dynamic programming matrix.</p>

<p>For structure alignment, the problem is a lot more complicated due to the <a href="https://arxiv.org/pdf/2307.02170.pdf"><em>absence of natural local bounds</em></a>: if we change a sequence alignment at some position, a previously aligned segment somewhere else stays unchanged. Since structural alignment operates via concerted 3D rototranslations, the introduction of gaps outside an aligned region might still affect the already aligned region due to residues that are close in 3D but far in sequence space.</p>

<p>Therefore, structural alignment algorithms like <a href="https://pubmed.ncbi.nlm.nih.gov/32006276/">Dali</a> based on distance matrices and <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1084323/">TM-align</a> based on the TM-score are relatively slow, preventing their application at the scale of data we now face (TM-align would need around a year to search through the AFDB on a single CPU core). Foldseek, on the other hand, is 4-5 orders of magnitude faster and therefore suitable for such large-scale searches.</p>

<h3 id="structure-to-sequence-the-3di-alphabet">Structure to Sequence: the 3Di alphabet</h3>

<p>How is that done? The main idea is to translate structural information into some kind of sequence-based representation that allows the use of fast sequence alignment tools. This has been tried before with tools like <a href="https://www.worldscientific.com/doi/abs/10.1142/S0219720008003461">CLePAPS</a> and <a href="https://www.tandfonline.com/doi/abs/10.1080/07391102.2013.787026?casa_token=w4NYfkI5VQcAAAAA:koU6p0Ju-1Ymkip1kZzljuiYCYiT3fpfTZPYD91mTWKP6RuVZLqQ5Khhc_Xp1IxXnI76-XtCYA">mulPBA</a>, but these have not found widespread use because they only describe <em>secondary backbone structure</em>. These tools build on the three-letter code of helix, sheet and coil and refine it further by describing the backbone around a single residue by one of 10-20 letters. This increases the speed by reducing the problem to sequence alignment, but only captures helical and sheet-like regions well, while the large amount of information in loop regions is not captured well due to the structure there being mostly determined by <em>interactions</em> between different residues. In addition, neighboring residues are highly correlated (helices and sheets stretch over many residues in a protein), making that encoding even less informative.</p>

<p>FoldSeek does away with this and describes the <em>tertiary</em> rather than the secondary backbone structure via a 20-letter alphabet called 3Di. More specifically, it does the following:</p>

<ol>
  <li>Select a residue to encode and its nearest 3D neighbor. They started defining “nearest” as “smallest CB-CB distance”, but then replaced that with the concept of a <em>virtual center</em> for reasons explained later.</li>
  <li>Get the CA atoms of these two residues as well as the CA atoms of the residues before and after them in the sequence (in total 6), extract distance- and angle-based features from this 6-atom constellation and collect them in a 10D-descriptor.</li>
  <li>Discretise this information into one of the 20 letters from the 3Di alphabet.</li>
</ol>

<p align="center">
  <img src="/assets/img/blog/prot_representation/foldseek_algo.png" width="90%" height="90%" />
</p>

<p class="figcaption">FoldSeek stages in part b of the figure. We will come back to part a. Source: <a href="https://www.nature.com/articles/s41587-023-01773-0">FoldSeek Paper</a></p>

<p>We will talk in more detail about steps 1 and 3 of this process, but you can see how the resulting 3Di sequence can be fed into any sequence-based program to get a structural alignment or clustering. In the paper, the authors show that they can do that with similar sensitivity to actual structural alignment programs, but at a fraction of the computational cost.</p>

<h3 id="virtual-centers-optimise-conservation-of-interactions-and-tertiary-vs-local-interactions">Virtual centers optimise conservation of interactions and tertiary vs. local interactions</h3>

<p>The virtual center described above is determined by a <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41587-023-01773-0/MediaObjects/41587_2023_1773_MOESM1_ESM.pdf">pre-specified procedure described in the SI (Suppl. Fig. 1)</a>:</p>
<ol>
  <li>It lies on the plane defined by N, CA and CB</li>
  <li>CB, CA and the virtual center form a 90 degree angle</li>
  <li>The CA-virtual center distance is twice the CA-CB distance</li>
</ol>

<p align="center">
  <img src="/assets/img/blog/prot_representation/virtual_center.png" width="50%" height="50%" />
</p>

<p class="figcaption">Construction of the virtual center in FoldSeek. Source: <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41587-023-01773-0/MediaObjects/41587_2023_1773_MOESM1_ESM.pdf">(Suppl. Fig. 1)</a></p>

<p>In the case of glycine, a virtual CB is approximated by idealising the backbone geometry as a tetrahedron.</p>
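<p>The three geometric rules above translate directly into a bit of vector arithmetic. The following is a minimal sketch under stated assumptions, not the FoldSeek source code: the function and helper names are made up, coordinates are plain 3-tuples, and of the two points in the plane that satisfy the rules, the sketch picks the one on the same side as the N atom (this choice of side is an assumption).</p>

```python
import math

Vec = tuple[float, float, float]


def sub(a: Vec, b: Vec) -> Vec:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def add(a: Vec, b: Vec) -> Vec:
    return (a[0] + b[0], a[1] + b[1], a[2] + b[2])


def scale(a: Vec, s: float) -> Vec:
    return (a[0] * s, a[1] * s, a[2] * s)


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def virtual_center(n: Vec, ca: Vec, cb: Vec, distance_factor: float = 2.0) -> Vec:
    """Point in the N-CA-CB plane, at 90 degrees to CA->CB as seen from CA,
    at distance_factor times the CA-CB distance from CA."""
    ca_cb = sub(cb, ca)
    ca_n = sub(n, ca)
    # Remove the part of CA->N that is parallel to CA->CB. What remains lies
    # in the N-CA-CB plane and is perpendicular to CA->CB.
    parallel = scale(ca_cb, dot(ca_n, ca_cb) / dot(ca_cb, ca_cb))
    perpendicular = sub(ca_n, parallel)
    direction = scale(perpendicular, 1.0 / math.sqrt(dot(perpendicular, perpendicular)))
    ca_cb_length = math.sqrt(dot(ca_cb, ca_cb))
    return add(ca, scale(direction, distance_factor * ca_cb_length))


# Toy coordinates (not taken from a real structure).
print(virtual_center(n=(-0.5, 1.4, 0.0), ca=(0.0, 0.0, 0.0), cb=(1.5, 0.0, 0.0)))
```

<p>For the toy coordinates, CA-CB has length 1.5, so the virtual center ends up 3.0 away from CA, perpendicular to the CA-CB bond.</p>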

<p>Why is this better than just taking the CB-CB distance? Two reasons:</p>

<ol>
  <li><strong>Conservation of Interactions</strong>: we want to make sure that when structurally aligning two homologs, the nearest neighbor of residue <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span> in structure 1 is the same as for residue <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span> in structure 2. If this were not the case and a different nearest neighbor were chosen, the extracted 10D descriptor would look different, we would assign the residue different 3Di letters in the two structures and the structural alignment would fail. Empirically, they found that the CB-CB distance is not a great criterion for that and therefore came up with the virtual center definition that fulfills this desideratum more often.</li>
  <li><strong>Tertiary vs. local interactions</strong>: One of the downsides of the previous alphabets such as CLePAPS and mulPBA was that they have a lot of repeated information encoded by only describing local interactions as part of the secondary structure description (e.g. “these 10 residues all are in a helix”). If our 3Di alphabet ends up encoding mainly local interactions between neighbors in sequence (as would often be the case if we choose the CB-CB distance as criterion for nearest neighbor) then we end up in the same spot of mainly describing redundant local interactions. One can think about it from an information theoretic perspective in terms of mutual information: in the case of only encoding the amino acid identity, the mutual information between structurally aligned residues is the same no matter if we correct by correlation between neighbouring letters to account for local interactions or not. Other structural alphabets show a higher mutual information than pure amino acid encoding (i.e. performing only classic sequence alignment), but that difference shrinks a lot when we correct for the neighbor letter correlation. FoldSeek therefore aims to minimise the amount of local interactions it encodes and maximise the amount of tertiary interaction that is encoded. By moving the virtual center further away from the backbone and orienting it into a different direction than the CB, we achieve this goal of often encoding interactions between residues that are not neighbors in sequence.</li>
</ol>

<h3 id="learning-the-3di-alphabet-via-a-vq-vae">Learning the 3Di alphabet via a VQ-VAE</h3>

<p>Given the 10-dimensional descriptor that encodes distance- and angle-based features from the residue and its nearest neighbour as judged by the virtual center, how do we actually decide which of the 20 letters of the alphabet we assign this residue to? Well, one could do something simple like k-means clustering (which the authors started out with), but you can be smarter than that by considering the fact that our 3Di alphabet should learn <em>maximally conserved structural states</em> between homologs.</p>

<p>Therefore, the authors leverage a <a href="https://www.youtube.com/watch?v=1ZHzAOutcnw">VQ-VAE</a> (vector-quantized variational autoencoder) to learn the 3Di alphabet by first encoding the 10D descriptor via a 3-layer neural network encoder into a bottlenecked representation, then mapping it to one of the 20 discrete 3Di states (that is where the VQ part comes in) and then reconstructing the 10D descriptor again via a 3-layer neural network decoder. The crucial part here is that <em>the reconstruction target is not the input</em> as it is for classic VAEs. Just reconstructing the exact 10D descriptor could lead to overfitting on the exact values instead of encoding features that allow us to identify conserved states between structures. Therefore, our reconstruction target is the 10D descriptor of a <em>structurally aligned homolog</em>. The structural alignment in this case was part of training dataset preparation via one of the more expensive classical tools.</p>

<p>By targeting not the same 10D descriptor but the descriptor of a homolog, the VQ-VAE is forced to encode a discretised representation that is useful for identifying homologs, exactly the use case we are building this algorithm for. This procedure is quite clever and can be seen as similar to <a href="https://lilianweng.github.io/posts/2018-08-12-vae/">denoising autoencoders</a>, where instead of swapping out the output, the input is corrupted with some noise in order for the network to learn a useful representation and avoid overfitting.</p>
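<p>The "VQ" step itself is conceptually simple: once the encoder has produced a latent vector, we look up the closest entry in a learned codebook and emit its letter. The sketch below shows only this lookup. The encoder and decoder networks are not shown, the function names are made up, and the three-letter, two-dimensional codebook is a toy assumption standing in for the 20 learned 3Di states.</p>

```python
import math


def quantise(latent: tuple[float, ...], codebook: dict[str, tuple[float, ...]]) -> str:
    """Return the letter whose codebook vector is closest to the latent vector."""
    return min(codebook, key=lambda letter: math.dist(latent, codebook[letter]))


def encode_structure(latents: list[tuple[float, ...]],
                     codebook: dict[str, tuple[float, ...]]) -> str:
    """Turn one latent vector per residue into a discrete letter sequence."""
    return "".join(quantise(latent, codebook) for latent in latents)


# Toy codebook (assumption): the real alphabet has 20 learned states.
toy_codebook = {"A": (0.0, 0.0), "B": (1.0, 0.0), "C": (0.0, 1.0)}
per_residue_latents = [(0.9, 0.1), (0.1, 0.2), (0.2, 0.8)]
print(encode_structure(per_residue_latents, toy_codebook))
```

<p>During training, the decoder would take the selected codebook vector and try to reconstruct the descriptor of the aligned residue in the homolog, which is what pushes the codebook towards conserved structural states.</p>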

<h3 id="speeding-things-up-by-building-on-mmseqs2">Speeding things up by building on mmseqs2</h3>

<p>We have now trained our VQ-VAE and can use it to encode a protein structure into a 3Di sequence. We could just leave it there and leverage good-old dynamic programming via <a href="https://en.wikipedia.org/wiki/Smith%E2%80%93Waterman_algorithm">Smith-Waterman</a> to get local alignments. But the authors were aiming for speed, so they did not stop there and took inspiration from their MMseqs2 sequence aligner described above. In fact, they use exactly the same pipeline!</p>

<p align="center">
  <img src="/assets/img/blog/prot_representation/foldseek_algo.png" width="90%" height="90%" />
</p>

<p class="figcaption">In part a, we can see that Foldseek uses the same prefilter and alignment modules as MMseqs2. Source: <a href="https://www.nature.com/articles/s41587-023-01773-0">FoldSeek Paper</a></p>

<p>Since the 3Di representation is just a sequence, we can plug that sequence into the MMseqs2 prefilter and alignment modules and get ultra-fast structural alignment. We can benefit from the clever prefilter design as well as the hardware optimisations like <a href="https://en.wikipedia.org/wiki/Advanced_Vector_Extensions">AVX2</a> instructions, optimised CPU cache, vectorisation and so on.</p>

<h2 id="applications-clustering-the-protein-universe">Applications: clustering the protein universe</h2>

<p>Using the turbo tandem of MMSeqs2 and FoldSeek as well as integrating these advancements into structure prediction methods via ColabFold has led to a flurry of new research directions.</p>

<p>For one, both sequence and structure clustering are now possible on scales that were not imaginable before. The <a href="https://academic.oup.com/nar/article/45/D1/D170/2605730?login=false">Uniclust</a> databases were created by clustering sequences via MMseqs2 at 90%, 50% and 30% pairwise sequence identity. The resulting databases showed better consistency of functional annotations than the corresponding UniRef databases, arguably due to the better clustering algorithms.</p>

<p>Using a combination of MMSeqs2 and Foldseek, it was possible to perform <a href="https://www.nature.com/articles/s41586-023-06510-w">clustering on the whole AlphaFold database</a>, identifying new putative homologs that demonstrate the value of such a resource for studying protein evolution and function on such a large scale.</p>

<p>Other applications opened up in phylogenetics, the <a href="https://www.ebi.ac.uk/training/online/courses/introduction-to-phylogenetics/what-is-phylogenetics/#:~:text=Phylogenetics%20is%20the%20study%20of,be%20referred%20to%20as%20taxa">study of evolutionary relationships among biological entities such as species or individuals</a>: the use of Foldseek enabled fast homology detection via <a href="https://www.biorxiv.org/content/10.1101/2023.12.12.571181v2.full.pdf">structural phylogenetics</a> for proteins in the <em>twilight zone</em>, meaning that their sequence similarity is already very low but detecting remote homology via structural similarity is still possible. In another study, a combination of MMseqs2, ColabFold and FoldSeek enabled <a href="https://link.springer.com/article/10.1186/s13059-023-02942-9">cross-phyla protein annotation</a>, a task considered very challenging. Even more, protein structure prediction methods themselves were improved by applying MMseqs2 to the <a href="https://www.ncbi.nlm.nih.gov/sra/docs/">Sequence Read Archive (SRA)</a>, resulting in <a href="https://cshperspectives.cshlp.org/content/early/2024/02/05/cshperspect.a041465.abstract">petabase-scale homology search</a> and the construction of better MSAs (seems like in protein structure prediction we are now back to the old game of “who has the bigger MSA”).</p>

<p>While the tools as they stand right now are amazing, the algorithms behind them can still be improved. This includes making the last Smith-Waterman alignment more efficient via algorithms such as <a href="https://academic.oup.com/bioinformatics/article/39/8/btad487/7236499">BlockAligner that uses adaptive dynamic programming with SIMD-acceleration</a>, or even making it <a href="https://academic.oup.com/bioinformatics/article/39/1/btac724/6820925?login=false">differentiable</a> in order to backpropagate through the MSA construction step and enable full end-to-end-learning.</p>

<p>At the same time, it is still worthwhile looking for other approaches to these challenges. Some of them include <a href="https://www.mlsb.io/papers_2022/SWAMPNN_End_to_end_protein_structures_alignment.pdf">SWAMPNN structure alignment via ProteinMPNN</a> that is more sensitive than FoldSeek while still being faster than many of the classical algorithms, as well as <a href="https://www.biorxiv.org/content/10.1101/2023.11.26.568742v1">language models used to perform protein search and annotation</a>. All in all, one can say that we can now indeed do many of the things we can do with sequences also with structures, and it will be exciting to see the scientific discoveries that result from that endeavour!</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="proteins" /><summary type="html"><![CDATA[How structure prediction changed the questions we ask and the tools we use]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/prot_representation/protein_pile.png" /><media:content medium="image" url="/assets/img/blog/prot_representation/protein_pile.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">How to accelerate PyTorch on your GPU</title><link href="/blog/ml/2024-03-04-cuda-kernels/" rel="alternate" type="text/html" title="How to accelerate PyTorch on your GPU" /><published>2024-03-04T00:00:00+00:00</published><updated>2024-04-07T17:56:12+00:00</updated><id>/blog/ml/cuda-kernels</id><content type="html" xml:base="/blog/ml/2024-03-04-cuda-kernels/"><![CDATA[<p>Recently the <a href="https://www.youtube.com/@CUDAMODE">CUDA MODE</a> lecture series started with some amazing talks about how you can use tools like CUDA or Triton to speed up your PyTorch programs (join the <a href="https://discord.com/invite/XsdDHGtk9N">Discord</a> if you are interested in learning more). 
Here I want to summarise and review some of the concepts and tools from the lectures and tie them together in a coherent blog post.</p>

<ul id="markdown-toc">
  <li><a href="#1-profiling" id="markdown-toc-1-profiling">1. Profiling</a>    <ul>
      <li><a href="#11-torchcudaevent" id="markdown-toc-11-torchcudaevent">1.1 <code class="language-plaintext highlighter-rouge">torch.cuda.Event</code></a></li>
      <li><a href="#12-torchautogradprofiler" id="markdown-toc-12-torchautogradprofiler">1.2 <code class="language-plaintext highlighter-rouge">torch.autograd.profiler</code></a></li>
      <li><a href="#13-torchprofiler" id="markdown-toc-13-torchprofiler">1.3 <code class="language-plaintext highlighter-rouge">torch.profiler</code></a></li>
      <li><a href="#14-ncu-profiler" id="markdown-toc-14-ncu-profiler">1.4 <code class="language-plaintext highlighter-rouge">ncu</code> profiler</a></li>
    </ul>
  </li>
  <li><a href="#2-integrating-cuda-kernels-into-pytorch" id="markdown-toc-2-integrating-cuda-kernels-into-pytorch">2. Integrating CUDA kernels into PyTorch</a>    <ul>
      <li><a href="#21-load_inline-function" id="markdown-toc-21-load_inline-function">2.1 <code class="language-plaintext highlighter-rouge">load_inline</code> function</a></li>
      <li><a href="#22-numba" id="markdown-toc-22-numba">2.2 Numba</a></li>
    </ul>
  </li>
  <li><a href="#3-integrate-triton-kernels-into-pytorch" id="markdown-toc-3-integrate-triton-kernels-into-pytorch">3. Integrate Triton kernels into PyTorch</a>    <ul>
      <li><a href="#31-using-triton" id="markdown-toc-31-using-triton">3.1 Using Triton</a></li>
      <li><a href="#32-debugging-triton" id="markdown-toc-32-debugging-triton">3.2 Debugging Triton</a></li>
      <li><a href="#33-triton-deep-dive" id="markdown-toc-33-triton-deep-dive">3.3 Triton Deep-Dive</a></li>
      <li><a href="#34-benchmarking-triton" id="markdown-toc-34-benchmarking-triton">3.4 Benchmarking Triton</a></li>
    </ul>
  </li>
  <li><a href="#4-torchcompile" id="markdown-toc-4-torchcompile">4. <code class="language-plaintext highlighter-rouge">torch.compile</code></a></li>
  <li><a href="#credits" id="markdown-toc-credits">Credits</a></li>
</ul>

<h2 id="1-profiling">1. Profiling</h2>

<p>Profiling is the process of measuring the time and resources that a program uses. It is a crucial step in the development of any software, as it allows you to identify bottlenecks and areas for improvement. In the context of GPU programming, profiling is especially important, as the performance of a GPU program can be highly dependent on factors such as memory access patterns, kernel launch configurations, and the specific hardware being used. It is also not trivial to profile GPU code, as the operations are executed asynchronously on the GPU and we cannot simply measure execution time like we would with CPU code. The following sections introduce a few tools to get you started with profiling PyTorch code (for this you need to have access to a GPU, e.g. via Google Colab or a local machine with a CUDA-enabled GPU).</p>

<h3 id="11-torchcudaevent">1.1 <code class="language-plaintext highlighter-rouge">torch.cuda.Event</code></h3>

<p>To profile the time a torch operation takes, you can use <code class="language-plaintext highlighter-rouge">torch.cuda.Event</code>. We cannot simply use the <code class="language-plaintext highlighter-rouge">time</code> module for this, because the operations are executed asynchronously on the GPU, so the Python call can return before the kernel has finished. Let us write a short function to profile the time a function call takes:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="k">def</span> <span class="nf">time_pytorch_function</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="nb">input</span><span class="p">):</span>
    <span class="n">start</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="n">end</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="c1"># Warmup
</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
        <span class="n">func</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
    <span class="n">start</span><span class="p">.</span><span class="n">record</span><span class="p">()</span>
    <span class="n">func</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
    <span class="n">end</span><span class="p">.</span><span class="n">record</span><span class="p">()</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">synchronize</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">start</span><span class="p">.</span><span class="n">elapsed_time</span><span class="p">(</span><span class="n">end</span><span class="p">)</span>
</code></pre></div></div>
<p>We do a few warmup steps at the start so that one-off costs such as memory allocation calls and PyTorch’s JIT fuser are not included in the timing. Then we record an event before and after the function call and synchronize the GPU so that both events have completed before we read the elapsed time, which is reported in milliseconds. For more details on these things see <a href="https://www.speechmatics.com/company/articles-and-news/timing-operations-in-pytorch">this blog post</a>.</p>
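<p>The need for the final synchronization can be illustrated without a GPU. The sketch below is a CPU-only analogy, not PyTorch code: a single worker thread stands in for the CUDA stream, and the hypothetical helper <code class="language-plaintext highlighter-rouge">launch</code> returns immediately, much like an asynchronous kernel launch. Stopping the clock right after the launch measures only the time needed to queue the work.</p>

```python
import time
from concurrent.futures import ThreadPoolExecutor

# A single worker thread plays the role of the GPU stream (analogy only).
fake_stream = ThreadPoolExecutor(max_workers=1)

def launch(seconds):
    """Queue 'work' and return immediately, like an asynchronous kernel launch."""
    return fake_stream.submit(time.sleep, seconds)

# Naive timing: the clock stops as soon as the work has been queued.
t0 = time.perf_counter()
pending = launch(0.05)
naive_elapsed = time.perf_counter() - t0

# Waiting for the result is the analogue of torch.cuda.synchronize().
pending.result()
synced_elapsed = time.perf_counter() - t0

print(f"without waiting: {naive_elapsed * 1e3:.2f} ms")
print(f"after waiting:   {synced_elapsed * 1e3:.2f} ms")
fake_stream.shutdown()
```

<p>Only the second measurement covers the 50 ms of queued work, which is why the timing function has to synchronize before reading the elapsed time.</p>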

<p>Let’s try this with a simple toy example:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="mi">10000</span><span class="p">).</span><span class="n">cuda</span><span class="p">()</span>

<span class="k">def</span> <span class="nf">square_2</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">a</span>

<span class="k">def</span> <span class="nf">square_3</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">a</span> <span class="o">**</span> <span class="mi">2</span>

<span class="k">print</span><span class="p">(</span><span class="n">time_pytorch_function</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">square</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">time_pytorch_function</span><span class="p">(</span><span class="n">square_2</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">time_pytorch_function</span><span class="p">(</span><span class="n">square_3</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
<span class="c1">#output:
# 3.2753279209136963
# 3.272671937942505
# 3.2755520343780518
</span></code></pre></div></div>
<p>The three timings are almost identical: in this run the multiplication <code class="language-plaintext highlighter-rouge">a * a</code> is marginally faster than the power operation <code class="language-plaintext highlighter-rouge">a ** 2</code>, although the gap is small enough to be close to measurement noise. Timing alone also cannot tell us why there is any difference at all; it is the same mathematical operation, so are they using different CUDA kernels? We can use the <code class="language-plaintext highlighter-rouge">torch.autograd.profiler</code> to find out.</p>
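<p>To put the size of the gap into perspective, the following sketch computes the relative slowdown of each variant with respect to the fastest one. It only uses the three numbers printed above; the variable names are made up for illustration.</p>

```python
# Timings in milliseconds, copied from the output above.
timings_ms = {
    "torch.square": 3.2753279209136963,
    "a * a": 3.272671937942505,
    "a ** 2": 3.2755520343780518,
}

fastest = min(timings_ms.values())
# Relative slowdown of each variant compared with the fastest one, in percent.
slowdown_pct = {name: 100 * (t - fastest) / fastest for name, t in timings_ms.items()}

for name, pct in slowdown_pct.items():
    print(f"{name:14s} +{pct:.3f} %")
```

<p>All three variants land within a tenth of a percent of each other, so a single measurement like this should be read with some caution.</p>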

<h3 id="12-torchautogradprofiler">1.2 <code class="language-plaintext highlighter-rouge">torch.autograd.profiler</code></h3>

<p>Fortunately, we do not have to write all profiling tools ourselves: PyTorch has a built-in profiler. Let us look again at the same operations:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="s">"============="</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Profiling torch.square"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"============="</span><span class="p">)</span>

<span class="c1"># Now profile each function using pytorch profiler
</span><span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">autograd</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">profile</span><span class="p">(</span><span class="n">use_cuda</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">as</span> <span class="n">prof</span><span class="p">:</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="n">prof</span><span class="p">.</span><span class="n">key_averages</span><span class="p">().</span><span class="n">table</span><span class="p">(</span><span class="n">sort_by</span><span class="o">=</span><span class="s">"cuda_time_total"</span><span class="p">,</span> <span class="n">row_limit</span><span class="o">=</span><span class="mi">10</span><span class="p">))</span>

<span class="k">print</span><span class="p">(</span><span class="s">"============="</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Profiling a * a"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"============="</span><span class="p">)</span>

<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">autograd</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">profile</span><span class="p">(</span><span class="n">use_cuda</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">as</span> <span class="n">prof</span><span class="p">:</span>
    <span class="n">square_2</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="n">prof</span><span class="p">.</span><span class="n">key_averages</span><span class="p">().</span><span class="n">table</span><span class="p">(</span><span class="n">sort_by</span><span class="o">=</span><span class="s">"cuda_time_total"</span><span class="p">,</span> <span class="n">row_limit</span><span class="o">=</span><span class="mi">10</span><span class="p">))</span>

<span class="k">print</span><span class="p">(</span><span class="s">"============="</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Profiling a ** 2"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"============="</span><span class="p">)</span>

<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">autograd</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">profile</span><span class="p">(</span><span class="n">use_cuda</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">as</span> <span class="n">prof</span><span class="p">:</span>
    <span class="n">square_3</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="n">prof</span><span class="p">.</span><span class="n">key_averages</span><span class="p">().</span><span class="n">table</span><span class="p">(</span><span class="n">sort_by</span><span class="o">=</span><span class="s">"cuda_time_total"</span><span class="p">,</span> <span class="n">row_limit</span><span class="o">=</span><span class="mi">10</span><span class="p">))</span>
</code></pre></div></div>

<p>This gives us the following output:</p>

<p><img src="/assets/img/blog/gpu_profiling/simple_profiling.png" alt="Profiling output" /></p>

<p>We can see that <code class="language-plaintext highlighter-rouge">a * a</code> calls the faster <code class="language-plaintext highlighter-rouge">aten::mul</code> operation, while <code class="language-plaintext highlighter-rouge">a ** 2</code> calls the slower <code class="language-plaintext highlighter-rouge">aten::pow</code> operation, explaining our previous results.</p>

<p class="note" title="Aside">ATen is a C++ library that is part of the <a href="https://pytorch.org/cppdocs/">PyTorch C++ API</a>. It is the foundational tensor and math library on which PyTorch is built and exposes the Tensor operations in PyTorch <a href="https://pytorch.org/cppdocs/notes/tensor_basics.html">directly in C++</a>. ATen is a very creative name, as it stands for “A tensor library”. You can hear more about the differences between the torch API and the ATen API in <a href="https://pytorch-dev-podcast.simplecast.com/episodes/torch-vs-aten-apis">this podcast episode</a>.</p>

<p>Let us now profile a simple neural network forward pass:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>

<span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">).</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">autograd</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">profile</span><span class="p">(</span><span class="n">use_cuda</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">as</span> <span class="n">prof</span><span class="p">:</span>
  <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">).</span><span class="n">cuda</span><span class="p">()(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">prof</span><span class="p">.</span><span class="n">key_averages</span><span class="p">().</span><span class="n">table</span><span class="p">(</span><span class="n">sort_by</span><span class="o">=</span><span class="s">"cuda_time_total"</span><span class="p">,</span> <span class="n">row_limit</span><span class="o">=</span><span class="mi">10</span><span class="p">))</span>
</code></pre></div></div>

<p>This gives us the following output:</p>

<p><img src="/assets/img/blog/gpu_profiling/linear_profiling.png" alt="Profiling output" /></p>

<p>We can see that the <code class="language-plaintext highlighter-rouge">aten::linear</code> and <code class="language-plaintext highlighter-rouge">aten::addmm</code> operations are the most time-consuming operations in this forward pass. In <a href="">another post</a> I dig into how one can find the actual implementation of these functions in the PyTorch codebase to understand what they actually do, but for now it is enough to know that <code class="language-plaintext highlighter-rouge">aten::linear</code> is the operation that applies a linear transformation to the input data and <code class="language-plaintext highlighter-rouge">aten::addmm</code> is the operation that performs a matrix multiplication of the input data with the weight matrix and adds a bias term.</p>
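<p>As a rough mental model of what these two operations compute, here is a pure-Python sketch on nested lists. It is not how ATen implements them: the helper names <code class="language-plaintext highlighter-rouge">addmm</code> and <code class="language-plaintext highlighter-rouge">linear</code> are only borrowed for illustration, and the scaling factors that the real <code class="language-plaintext highlighter-rouge">addmm</code> accepts are left out.</p>

```python
def addmm(bias, mat1, mat2):
    """Return bias + mat1 @ mat2 for nested lists (bias is broadcast over rows)."""
    return [
        [b + sum(x * w for x, w in zip(row, col)) for b, col in zip(bias, zip(*mat2))]
        for row in mat1
    ]

def linear(x, weight, bias):
    """y = x @ weight^T + bias, with weight stored as (out_features, in_features)."""
    weight_t = [list(col) for col in zip(*weight)]  # transpose to (in, out)
    return addmm(bias, x, weight_t)

x = [[1.0, 2.0]]                               # one sample with two input features
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three output features
bias = [0.5, 0.5, 0.5]
print(linear(x, weight, bias))
```

<p>The linear layer is therefore little more than a transpose followed by one fused multiply-and-add over matrices, which is why both operations show up together in the profile.</p>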

<h3 id="13-torchprofiler">1.3 <code class="language-plaintext highlighter-rouge">torch.profiler</code></h3>

<p>Another, more visual way to profile your code is to use <code class="language-plaintext highlighter-rouge">torch.profiler</code>. This is a higher-level interface to the profiler and allows you to export the profiling data to a Chrome trace file. Here is an example of how to use it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch.profiler</span> <span class="kn">import</span> <span class="n">profile</span><span class="p">,</span> <span class="n">record_function</span><span class="p">,</span> <span class="n">ProfilerActivity</span>

<span class="k">def</span> <span class="nf">trace_handler</span><span class="p">(</span><span class="n">prof</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="n">prof</span><span class="p">.</span><span class="n">key_averages</span><span class="p">().</span><span class="n">table</span><span class="p">(</span>
        <span class="n">sort_by</span><span class="o">=</span><span class="s">"self_cuda_time_total"</span><span class="p">,</span> <span class="n">row_limit</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">prof</span><span class="p">.</span><span class="n">export_chrome_trace</span><span class="p">(</span><span class="s">"/tmp/test_trace_"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">prof</span><span class="p">.</span><span class="n">step_num</span><span class="p">)</span> <span class="o">+</span> <span class="s">".json"</span><span class="p">)</span>

<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">profile</span><span class="p">(</span>
    <span class="n">activities</span><span class="o">=</span><span class="p">[</span>
        <span class="n">torch</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">ProfilerActivity</span><span class="p">.</span><span class="n">CPU</span><span class="p">,</span>
        <span class="n">torch</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">ProfilerActivity</span><span class="p">.</span><span class="n">CUDA</span><span class="p">,],</span>    
    <span class="n">schedule</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">profiler</span><span class="p">.</span><span class="n">schedule</span><span class="p">(</span>
        <span class="n">wait</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
        <span class="n">warmup</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
        <span class="n">active</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
        <span class="n">repeat</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
    <span class="n">on_trace_ready</span><span class="o">=</span><span class="n">trace_handler</span>
    <span class="p">)</span> <span class="k">as</span> <span class="n">p</span><span class="p">:</span>
        <span class="k">for</span> <span class="nb">iter</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">).</span><span class="n">cuda</span><span class="p">()(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">).</span><span class="n">cuda</span><span class="p">())</span>
            <span class="c1"># send a signal to the profiler that the next iteration has started
</span>            <span class="n">p</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
</code></pre></div></div>
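<p>The <code class="language-plaintext highlighter-rouge">schedule</code> arguments decide which of the ten loop iterations are actually recorded. The sketch below mimics the documented behaviour as I understand it (skip <code class="language-plaintext highlighter-rouge">wait</code> steps, warm up for <code class="language-plaintext highlighter-rouge">warmup</code> steps, record <code class="language-plaintext highlighter-rouge">active</code> steps, and run that cycle <code class="language-plaintext highlighter-rouge">repeat</code> times). The helper <code class="language-plaintext highlighter-rouge">schedule_plan</code> is hypothetical and not part of PyTorch.</p>

```python
def schedule_plan(num_steps, wait, warmup, active, repeat):
    """List the profiler action assumed for each step (repeat=0 means cycle forever)."""
    cycle = ["skip"] * wait + ["warmup"] * warmup + ["record"] * active
    if repeat == 0:
        return [cycle[step % len(cycle)] for step in range(num_steps)]
    plan = cycle * repeat
    # Pad with "skip" once all requested cycles are done, then trim to num_steps.
    return (plan + ["skip"] * num_steps)[:num_steps]

# Same settings as in the profiler call above.
plan = schedule_plan(num_steps=10, wait=1, warmup=1, active=2, repeat=1)
for step, action in enumerate(plan):
    print(step, action)
```

<p>Under these assumptions only iterations 2 and 3 end up in the trace, while the remaining iterations run without being recorded.</p>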

<p>We still get a terminal output:</p>

<p><img src="/assets/img/blog/gpu_profiling/pt_profiler.png" alt="Profiling output" /></p>

<p>However, we also get a Chrome trace file that we can open in Chrome to visualize the profiling data:</p>

<p><img src="/assets/img/blog/gpu_profiling/chrome_trace.png" alt="Profiling output" /></p>

<p>We can see that the majority of the time is actually spent on the CPU, moving data to the GPU, whereas the actual matrix multiplication is quite fast (and uses a special CUDA kernel called <code class="language-plaintext highlighter-rouge">volta_sgemm_32x32_sliced1x4_tn</code>).</p>

<h3 id="14-ncu-profiler">1.4 <code class="language-plaintext highlighter-rouge">ncu</code> profiler</h3>

<p>The <code class="language-plaintext highlighter-rouge">ncu</code> profiler is the command-line interface of Nsight Compute and comes with the CUDA toolkit. It is a very powerful tool that allows you to profile your CUDA kernels in great detail. You invoke it by running <code class="language-plaintext highlighter-rouge">ncu python script.py</code>, which runs your script and profiles all the CUDA kernels that are launched. It then generates a report in the form of an ncu_logs file that contains helpful numbers and recommendations on how to optimize your code.</p>

<p>A similar tool from the CUDA toolkit is <code class="language-plaintext highlighter-rouge">nsys</code>, which also allows you to profile your code. It is, however, less focused on detailed CUDA kernel performance analysis and more on overall system-wide performance, including how the communication between CPU and GPU impacts performance.</p>

<p>We can mark code we want to profile via the <code class="language-plaintext highlighter-rouge">torch.cuda.nvtx</code> API, which allows us to open a named range via <code class="language-plaintext highlighter-rouge">range_push()</code> and close it again via <code class="language-plaintext highlighter-rouge">range_pop()</code>. In the code example below, we profile a single linear layer; we also delay the start of profiling until iteration 10 to allow for warm-up time.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>

<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">20</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">10</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">cudart</span><span class="p">().</span><span class="n">cudaProfilerStart</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">10</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">nvtx</span><span class="p">.</span><span class="n">range_push</span><span class="p">(</span><span class="sa">f</span><span class="s">"Iteration </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">).</span><span class="n">cuda</span><span class="p">()</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">).</span><span class="n">cuda</span><span class="p">()(</span><span class="n">data</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">10</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">nvtx</span><span class="p">.</span><span class="n">range_pop</span><span class="p">()</span>
</code></pre></div></div>

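<p>The call to <code class="language-plaintext highlighter-rouge">torch.cuda.cudart().cudaProfilerStart()</code> tells NSys to only collect data from this iteration on (for NSys to honor this, the capture range typically has to be set on the command line via <code class="language-plaintext highlighter-rouge">--capture-range=cudaProfilerApi</code>).</p>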
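<p>Every <code class="language-plaintext highlighter-rouge">range_push()</code> needs a matching <code class="language-plaintext highlighter-rouge">range_pop()</code>, and ranges nest like a stack. One way to keep them balanced is a small context manager. The sketch below is a hypothetical helper (<code class="language-plaintext highlighter-rouge">marked_range</code> is not a PyTorch function) that takes the push and pop callables as arguments, so it can be tried without a GPU by passing stand-ins that merely log the calls.</p>

```python
from contextlib import contextmanager

@contextmanager
def marked_range(name, push, pop):
    """Open a named range, run the body, and always close the range again."""
    push(name)
    try:
        yield
    finally:
        pop()

# Stand-ins that log calls; with PyTorch and a GPU one would pass
# torch.cuda.nvtx.range_push and torch.cuda.nvtx.range_pop instead.
events = []

def fake_push(name):
    events.append(("push", name))

def fake_pop():
    events.append(("pop",))

with marked_range("Iteration 10", fake_push, fake_pop):
    with marked_range("forward", fake_push, fake_pop):
        pass  # the work to be profiled would go here

print(events)
```

<p>Because the pop happens in a <code class="language-plaintext highlighter-rouge">finally</code> block, the range is closed even if the profiled code raises an exception.</p>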

<p>To get the profiling output now, we need to install and use the <code class="language-plaintext highlighter-rouge">nsys</code> toolkit. There are <a href="https://gist.github.com/mcarilli/376821aa1a7182dfcf59928a7cde3223">many CLI options you can choose for it</a>, but one of the simplest calls might be <code class="language-plaintext highlighter-rouge">nsys profile -o output_profile python script.py</code>. This will produce a file called <code class="language-plaintext highlighter-rouge">output_profile.nsys-rep</code> which you can then open in the Nsight Systems UI (if you run your profiling on a remote machine, transfer the report file to your local machine so that you can run the GUI application). For a simple linear layer it will look something like this:</p>

<p><img src="/assets/img/blog/gpu_profiling/nsys_linear.jpeg" alt="NSys Linear Layer Report" /></p>

<p class="figcaption">NSys Profiling report for a single linear layer in PyTorch.</p>

<p>We can see that the actual computation only takes a small fraction of the time, while a long stretch before it is spent on data transfer via calls to the CUDA API like <code class="language-plaintext highlighter-rouge">MemCopy</code>. Only at the end is the <code class="language-plaintext highlighter-rouge">ampere_sgemm_32x32_sliced</code> kernel called that performs the actual matrix multiplication in tiles of 32 by 32.</p>

<p>To profile more complex code, for example a whole ResNet, we can either still set the profiling points manually as described in <a href="https://dev-discuss.pytorch.org/t/using-nsight-systems-to-profile-gpu-workload/59">this community post</a> or we can use tools such as <a href="https://github.com/zasdfgbnm/autonvtx">autonvtx</a> that just wrap our model and deal with the profiling setup for us. Doing this for a simple ResNet results in the following profiler output:</p>

<p><img src="/assets/img/blog/gpu_profiling/nsys_resnet.jpeg" alt="NSys ResNet Report" /></p>

<p class="figcaption">NSys Profiling report for a ResNet in PyTorch.</p>

<p>In this case we can see that a much larger chunk of time is spent on CUDA calls and actual computation. We also see that there are calls to the <code class="language-plaintext highlighter-rouge">cuDNN</code> backend for operations such as batch normalization.</p>

<p>NSys can seem overwhelming and takes a bit more effort to set up than the options presented before, but it can give you detailed insights as well as suggestions on what to improve in your code.</p>

<h2 id="2-integrating-cuda-kernels-into-pytorch">2. Integrating CUDA kernels into PyTorch</h2>

<p>CUDA is a parallel computing platform and application programming interface (API) model created by Nvidia whose kernels are typically written in C++. It allows software developers to use a CUDA-enabled graphics processing unit (GPU) for general purpose processing. Since kernels are written in C++, it is not immediately obvious how to integrate them into our ML code, which is normally written with Python libraries like PyTorch. However, there are several options.</p>

<h3 id="21-load_inline-function">2.1 <code class="language-plaintext highlighter-rouge">load_inline</code> function</h3>

<p>The easiest way to integrate CUDA kernels into PyTorch is to use the <code class="language-plaintext highlighter-rouge">torch.utils.cpp_extension</code> module. This module allows you to compile C++ code into a shared library and then load it into Python. Here is an example of how to do this via the <code class="language-plaintext highlighter-rouge">load_inline</code> function for a simple matrix squaring operation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch.utils.cpp_extension</span> <span class="kn">import</span> <span class="n">load_inline</span>

<span class="c1"># Define the CUDA kernel and C++ wrapper
</span><span class="n">cuda_source</span> <span class="o">=</span> <span class="s">'''
__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row &lt; height &amp;&amp; col &lt; width) {
        int idx = row * width + col;
        result[idx] = matrix[idx] * matrix[idx];
    }
}

torch::Tensor square_matrix(torch::Tensor matrix) {
    const auto height = matrix.size(0);
    const auto width = matrix.size(1);

    auto result = torch::empty_like(matrix);

    dim3 threads_per_block(16, 16);
    dim3 number_of_blocks((width + threads_per_block.x - 1) / threads_per_block.x,
                          (height + threads_per_block.y - 1) / threads_per_block.y);

    square_matrix_kernel&lt;&lt;&lt;number_of_blocks, threads_per_block&gt;&gt;&gt;(
        matrix.data_ptr&lt;float&gt;(), result.data_ptr&lt;float&gt;(), width, height);

    return result;
}
'''</span>

<span class="n">cpp_source</span> <span class="o">=</span> <span class="s">"torch::Tensor square_matrix(torch::Tensor matrix);"</span>

<span class="c1"># Load the CUDA kernel as a PyTorch extension
</span><span class="n">square_matrix_extension</span> <span class="o">=</span> <span class="n">load_inline</span><span class="p">(</span>
    <span class="n">name</span><span class="o">=</span><span class="s">'square_matrix_extension'</span><span class="p">,</span>
    <span class="n">cpp_sources</span><span class="o">=</span><span class="n">cpp_source</span><span class="p">,</span>
    <span class="n">cuda_sources</span><span class="o">=</span><span class="n">cuda_source</span><span class="p">,</span>
    <span class="n">functions</span><span class="o">=</span><span class="p">[</span><span class="s">'square_matrix'</span><span class="p">],</span>
    <span class="n">with_cuda</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">extra_cuda_cflags</span><span class="o">=</span><span class="p">[</span><span class="s">"-O2"</span><span class="p">],</span>
    <span class="c1"># build_directory='./load_inline_cuda',
</span>    <span class="c1"># extra_cuda_cflags=['--expt-relaxed-constexpr']
</span><span class="p">)</span>

<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">3.</span><span class="p">],</span> <span class="p">[</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">5.</span><span class="p">,</span> <span class="mf">6.</span><span class="p">]],</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">square_matrix_extension</span><span class="p">.</span><span class="n">square_matrix</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>

<span class="c1"># Output:
# tensor([[ 1.,  4.,  9.],
#         [16., 25., 36.]], device='cuda:0')
</span></code></pre></div></div>

<p>We see that the output is the same as if we used a PyTorch function. If we want to inspect the generated code, we can set the <code class="language-plaintext highlighter-rouge">build_directory</code> argument of the <code class="language-plaintext highlighter-rouge">load_inline</code> function; the generated sources and build files are then written to the specified directory.</p>
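<p>The launch configuration and index arithmetic in the kernel can be traced on the CPU. The following pure-Python sketch emulates the launch sequentially (a GPU would run the threads in parallel) and uses the same 16 by 16 block size, ceiling division and row-major index as the code above. The function names <code class="language-plaintext highlighter-rouge">grid_size</code> and <code class="language-plaintext highlighter-rouge">emulate_square_kernel</code> are made up for this illustration.</p>

```python
def grid_size(width, height, block=(16, 16)):
    """Ceiling division, as in the wrapper: enough blocks to cover every element."""
    return ((width + block[0] - 1) // block[0], (height + block[1] - 1) // block[1])

def emulate_square_kernel(matrix, width, height, block=(16, 16)):
    """Visit every (block, thread) pair sequentially and apply the kernel body."""
    result = [0.0] * (width * height)
    blocks_x, blocks_y = grid_size(width, height, block)
    for block_y in range(blocks_y):
        for block_x in range(blocks_x):
            for thread_y in range(block[1]):
                for thread_x in range(block[0]):
                    row = block_y * block[1] + thread_y
                    col = block_x * block[0] + thread_x
                    # Same bounds check as the kernel: surplus threads do nothing.
                    if row in range(height) and col in range(width):
                        idx = row * width + col  # row-major flat index
                        result[idx] = matrix[idx] * matrix[idx]
    return result

flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # the 2 x 3 example matrix, flattened
print(grid_size(width=3, height=2))
print(emulate_square_kernel(flat, width=3, height=2))
```

<p>For the 2 by 3 example a single block of 256 threads is launched and only six of them pass the bounds check, which shows why the check is needed whenever the matrix size is not a multiple of the block size.</p>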

<h3 id="22-numba">2.2 Numba</h3>

<p>Another way to integrate CUDA kernels into PyTorch is to use the <code class="language-plaintext highlighter-rouge">numba</code> library. This is a just-in-time (JIT) compiler that translates Python functions to optimized machine code at runtime using the industry-standard LLVM compiler library. It can also be used to generate CUDA kernels.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">numba</span> <span class="kn">import</span> <span class="n">cuda</span>

<span class="c1"># CUDA kernel
</span><span class="o">@</span><span class="n">cuda</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">square_matrix_kernel</span><span class="p">(</span><span class="n">matrix</span><span class="p">,</span> <span class="n">result</span><span class="p">):</span>
    <span class="c1"># Calculate the row and column index for each thread
</span>    <span class="n">row</span><span class="p">,</span> <span class="n">col</span> <span class="o">=</span> <span class="n">cuda</span><span class="p">.</span><span class="n">grid</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>

    <span class="c1"># Check if the thread's indices are within the bounds of the matrix
</span>    <span class="k">if</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">and</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
        <span class="c1"># Perform the square operation
</span>        <span class="n">result</span><span class="p">[</span><span class="n">row</span><span class="p">,</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">matrix</span><span class="p">[</span><span class="n">row</span><span class="p">,</span> <span class="n">col</span><span class="p">]</span> <span class="o">**</span> <span class="mi">2</span>

<span class="c1"># Example usage
</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="c1"># Create a sample matrix
</span><span class="n">matrix</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>

<span class="c1"># Allocate memory on the device
</span><span class="n">d_matrix</span> <span class="o">=</span> <span class="n">cuda</span><span class="p">.</span><span class="n">to_device</span><span class="p">(</span><span class="n">matrix</span><span class="p">)</span>
<span class="n">d_result</span> <span class="o">=</span> <span class="n">cuda</span><span class="p">.</span><span class="n">device_array</span><span class="p">(</span><span class="n">matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>

<span class="c1"># Configure the blocks
</span><span class="n">threads_per_block</span> <span class="o">=</span> <span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">16</span><span class="p">)</span>
<span class="n">blocks_per_grid_x</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">threads_per_block</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
<span class="n">blocks_per_grid_y</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">threads_per_block</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">blocks_per_grid</span> <span class="o">=</span> <span class="p">(</span><span class="n">blocks_per_grid_x</span><span class="p">,</span> <span class="n">blocks_per_grid_y</span><span class="p">)</span>

<span class="c1"># Launch the kernel
</span><span class="n">square_matrix_kernel</span><span class="p">[</span><span class="n">blocks_per_grid</span><span class="p">,</span> <span class="n">threads_per_block</span><span class="p">](</span><span class="n">d_matrix</span><span class="p">,</span> <span class="n">d_result</span><span class="p">)</span>

<span class="c1"># Copy the result back to the host
</span><span class="n">result</span> <span class="o">=</span> <span class="n">d_result</span><span class="p">.</span><span class="n">copy_to_host</span><span class="p">()</span>

<span class="c1"># Result is now in 'result' array
</span><span class="k">print</span><span class="p">(</span><span class="n">matrix</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">result</span><span class="p">)</span>
</code></pre></div></div>
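<p>The launch configuration above is easy to get wrong, so here is a minimal pure-Python sketch of the same arithmetic. It needs neither a GPU nor Numba; the helper names <code class="language-plaintext highlighter-rouge">blocks_per_grid</code> and <code class="language-plaintext highlighter-rouge">launched_and_active</code> are made up for illustration. It shows why the bounds check in the kernel matters: for our tiny 2x3 matrix, a single 16x16 block launches far more threads than there are elements.</p>

```python
def blocks_per_grid(shape, threads_per_block):
    """Blocks needed per axis so that blocks * threads covers every element."""
    # Integer ceiling division: (n + t - 1) // t
    return tuple((n + t - 1) // t for n, t in zip(shape, threads_per_block))


def launched_and_active(shape, threads_per_block):
    """Count the launched threads and the ones that pass the bounds check."""
    grid = blocks_per_grid(shape, threads_per_block)
    launched = 1
    for blocks, threads in zip(grid, threads_per_block):
        launched *= blocks * threads
    # Only threads whose (row, col) lies inside the matrix do any work
    active = shape[0] * shape[1]
    return grid, launched, active


grid, launched, active = launched_and_active((2, 3), (16, 16))
print(grid, launched, active)  # (1, 1) 256 6
```

<p>Only 6 of the 256 launched threads write a result; the remaining ones are stopped by the <code class="language-plaintext highlighter-rouge">if</code> statement in the kernel.</p>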

<h2 id="3-integrate-triton-kernels-into-pytorch">3. Integrate Triton kernels into PyTorch</h2>

<h3 id="31-using-triton">3.1 Using Triton</h3>

<p><a href="https://openai.com/research/triton">Triton</a> is both a domain-specific language (DSL) and a compiler for writing highly efficient GPU code. It does not generate CUDA code, but PTX code, which is a lower-level intermediate representation of the CUDA code (basically the assembly language of CUDA). Newer features in PyTorch like <code class="language-plaintext highlighter-rouge">torch.compile</code> <a href="https://pytorch.org/assets/pytorch2-2.pdf">leverage Triton kernels under the hood</a>, so it is worth understanding how it works. Since Triton kernels are written in Python, they are easy to integrate with PyTorch. Here is an example of how to use Triton to write a simple matrix squaring operation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Adapted straight from https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html
</span><span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="n">tl</span>
<span class="kn">import</span> <span class="nn">torch</span>

<span class="o">@</span><span class="n">triton</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">square_kernel</span><span class="p">(</span><span class="n">output_ptr</span><span class="p">,</span> <span class="n">input_ptr</span><span class="p">,</span> <span class="n">input_row_stride</span><span class="p">,</span> <span class="n">output_row_stride</span><span class="p">,</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="p">.</span><span class="n">constexpr</span><span class="p">):</span>
    <span class="c1"># The rows of the matrix are independent, so we parallelize across those
</span>    <span class="n">row_idx</span> <span class="o">=</span> <span class="n">tl</span><span class="p">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># The stride represents how much we need to increase the pointer to advance 1 row
</span>    <span class="n">row_start_ptr</span> <span class="o">=</span> <span class="n">input_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">input_row_stride</span>
    <span class="c1"># The block size is the next power of two greater than n_cols, so we can fit each
</span>    <span class="c1"># row in a single block
</span>    <span class="n">col_offsets</span> <span class="o">=</span> <span class="n">tl</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
    <span class="n">input_ptrs</span> <span class="o">=</span> <span class="n">row_start_ptr</span> <span class="o">+</span> <span class="n">col_offsets</span>
    <span class="c1"># Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
</span>    <span class="n">row</span> <span class="o">=</span> <span class="n">tl</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">input_ptrs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">col_offsets</span> <span class="o">&lt;</span> <span class="n">n_cols</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s">'inf'</span><span class="p">))</span>

    <span class="n">square_output</span> <span class="o">=</span> <span class="n">row</span> <span class="o">*</span> <span class="n">row</span>
    
    <span class="c1"># Write back output to DRAM
</span>    <span class="n">output_row_start_ptr</span> <span class="o">=</span> <span class="n">output_ptr</span> <span class="o">+</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">output_row_stride</span>
    <span class="n">output_ptrs</span> <span class="o">=</span> <span class="n">output_row_start_ptr</span> <span class="o">+</span> <span class="n">col_offsets</span>
    <span class="n">tl</span><span class="p">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptrs</span><span class="p">,</span> <span class="n">square_output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">col_offsets</span> <span class="o">&lt;</span> <span class="n">n_cols</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">square</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="n">n_rows</span><span class="p">,</span> <span class="n">n_cols</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span>
    <span class="c1"># The block size is the smallest power of two greater than or equal to the number of columns in x
</span>    <span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="n">triton</span><span class="p">.</span><span class="n">next_power_of_2</span><span class="p">(</span><span class="n">n_cols</span><span class="p">)</span>
    <span class="c1"># Another trick we can use is to ask the compiler to use more threads per row by
</span>    <span class="c1"># increasing the number of warps (num_warps) over which each row is distributed.
</span>    <span class="c1"># The Triton tutorials show how to auto-tune this value in a more natural
</span>    <span class="c1"># way so you don't have to come up with manual heuristics yourself.
</span>    <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
    <span class="k">if</span> <span class="n">BLOCK_SIZE</span> <span class="o">&gt;=</span> <span class="mi">2048</span><span class="p">:</span>
        <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
    <span class="k">if</span> <span class="n">BLOCK_SIZE</span> <span class="o">&gt;=</span> <span class="mi">4096</span><span class="p">:</span>
        <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
    <span class="c1"># Allocate output
</span>    <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="c1"># Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row
</span>    <span class="c1"># of the input matrix
</span>    <span class="n">square_kernel</span><span class="p">[(</span><span class="n">n_rows</span><span class="p">,</span> <span class="p">)](</span>
        <span class="n">y</span><span class="p">,</span>
        <span class="n">x</span><span class="p">,</span>
        <span class="n">x</span><span class="p">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
        <span class="n">y</span><span class="p">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
        <span class="n">n_cols</span><span class="p">,</span>
        <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">,</span>
        <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="n">BLOCK_SIZE</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">y</span>


<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="n">y_triton</span> <span class="o">=</span> <span class="n">square</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">square</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">torch</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">),</span> <span class="p">(</span><span class="n">y_triton</span><span class="p">,</span> <span class="n">y_torch</span><span class="p">)</span>
</code></pre></div></div>

<p>We see that the output of the Triton kernel is the same as the output of the PyTorch function.</p>
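<p>To make the role of <code class="language-plaintext highlighter-rouge">BLOCK_SIZE</code> and the mask more concrete, here is a rough pure-Python sketch of what a single program instance does with one row. It is only an illustration: <code class="language-plaintext highlighter-rouge">next_power_of_2</code> is a hand-written stand-in for <code class="language-plaintext highlighter-rouge">triton.next_power_of_2</code>, <code class="language-plaintext highlighter-rouge">square_row</code> is a hypothetical helper, and the loop runs sequentially, whereas Triton processes all lanes of the block in parallel.</p>

```python
def next_power_of_2(n):
    """Smallest power of two that is >= n (stand-in for triton.next_power_of_2)."""
    return 1 if n == 0 else 2 ** (n - 1).bit_length()


def square_row(row, block_size):
    """Emulate one program instance squaring a single row."""
    n_cols = len(row)
    out = [None] * n_cols
    for lane in range(block_size):   # plays the role of tl.arange(0, BLOCK_SIZE)
        if n_cols > lane:            # the mask: only lanes inside the row load and store
            out[lane] = row[lane] * row[lane]
    return out


block_size = next_power_of_2(781)
print(block_size, block_size - 781)  # 1024 243

print(square_row([1.0, 2.0, 3.0], next_power_of_2(3)))  # [1.0, 4.0, 9.0]
```

<p>For our 781 columns the block has 1024 lanes, so 243 of them are masked out in every row.</p>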

<h3 id="32-debugging-triton">3.2 Debugging Triton</h3>

<p>Once we go to compiled code, we hopefully gain speed, but lose some of the flexibility that comes with eager execution, e.g. easy debugging via <code class="language-plaintext highlighter-rouge">pdb</code> and other Python debuggers or simple <code class="language-plaintext highlighter-rouge">print</code> statements.</p>

<p>Fortunately, Triton has a debugger now: we can invoke it by changing the <code class="language-plaintext highlighter-rouge">triton.jit</code> decorator to <code class="language-plaintext highlighter-rouge">triton.jit(interpret=True)</code>. This will allow you to set normal <code class="language-plaintext highlighter-rouge">Python</code> breakpoints and step through the kernel line by line.</p>

<p class="note" title="Attention">The <code class="language-plaintext highlighter-rouge">interpret=True</code> option was recently deprecated, so you can instead use <code class="language-plaintext highlighter-rouge">os.environ["TRITON_INTERPRET"] = "1"</code>.</p>

<p>When doing that, you will see that most objects in the kernel are of the type <code class="language-plaintext highlighter-rouge">WrappedTensor</code>. So if you want to inspect a variable, you have to access its <code class="language-plaintext highlighter-rouge">.tensor</code> attribute.</p>

<p>Let’s look at this in action with a simple vector addition kernel from the <a href="https://triton-lang.org/main/index.html">Triton Docs</a>.</p>

<p>If you do not have a GPU available, you can run this code in a Google Colab by first choosing a GPU runtime and then executing the following lines to get the latest Triton version and set up your CUDA libraries correctly:</p>
<ul class="note" title="Attention">
  <li><code class="language-plaintext highlighter-rouge">!ldconfig /usr/lib64-nvidia</code></li>
  <li><code class="language-plaintext highlighter-rouge">!ldconfig -p | grep libcud</code></li>
  <li><code class="language-plaintext highlighter-rouge">!pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly</code></li>
</ul>

<p>Let us implement a simple vector addition kernel together with a helper function to call the kernel as well as some code to generate data and call that function:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">triton</span>
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="n">tl</span>
<span class="kn">import</span> <span class="nn">torch</span>

<span class="o">@</span><span class="n">triton</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">add_kernel</span><span class="p">(</span><span class="n">x_ptr</span><span class="p">,</span> <span class="n">y_ptr</span><span class="p">,</span> <span class="n">output_ptr</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">:</span> <span class="n">tl</span><span class="p">.</span><span class="n">constexpr</span><span class="p">):</span>
    <span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="p">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">breakpoint</span><span class="p">()</span>
    <span class="n">block_start</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK_SIZE</span>
    <span class="n">offsets</span> <span class="o">=</span> <span class="n">block_start</span> <span class="o">+</span> <span class="n">tl</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="p">)</span>
    <span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">n_elements</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">x_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">y_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
    <span class="n">tl</span><span class="p">.</span><span class="n">store</span><span class="p">(</span><span class="n">output_ptr</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">x</span><span class="p">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">y</span><span class="p">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">output</span><span class="p">.</span><span class="n">is_cuda</span>
    <span class="n">n_elements</span> <span class="o">=</span> <span class="n">output</span><span class="p">.</span><span class="n">numel</span><span class="p">()</span>
    <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="p">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">n_elements</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s">'BLOCK_SIZE'</span><span class="p">]),</span> <span class="p">)</span>
    <span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">synchronize</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">output</span>

<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">)</span>
<span class="n">output_torch</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">output_triton</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'The maximum difference between torch and triton is '</span>
      <span class="sa">f</span><span class="s">'</span><span class="si">{</span><span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">output_torch</span> <span class="o">-</span> <span class="n">output_triton</span><span class="p">))</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
</code></pre></div></div>

<p>This does not seem too complicated, but what are all these built-in Triton variables like <code class="language-plaintext highlighter-rouge">tl.program_id</code>? What do the offsets look like? And what values do my data pointers have? If we try to set <code class="language-plaintext highlighter-rouge">breakpoint()</code> to enter the <a href="https://docs.python.org/3/library/pdb.html">PDB debugger</a>, we get a <code class="language-plaintext highlighter-rouge">NameError</code>.</p>

<p>To answer these questions, <code class="language-plaintext highlighter-rouge">import os</code> and set the interpret flag for Triton to true: <code class="language-plaintext highlighter-rouge">os.environ["TRITON_INTERPRET"] = "1"</code>. Now our <code class="language-plaintext highlighter-rouge">breakpoint()</code> works like a charm and we can interactively debug our Triton kernel (even inside a notebook!).</p>
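<p>One detail that is easy to miss is the ordering. As far as I can tell, Triton checks the flag when the <code class="language-plaintext highlighter-rouge">@triton.jit</code> decorator runs, so a safe pattern (an assumption on my side, not something the docs guarantee) is to set the variable at the very top of the script or notebook, before Triton is imported and before any kernel is defined:</p>

```python
import os

# Set the flag first, so it is already in place when kernels get decorated
os.environ["TRITON_INTERPRET"] = "1"

# Only now import Triton and define kernels, e.g.:
# import triton
# import triton.language as tl

print(os.environ["TRITON_INTERPRET"])  # 1
```

<p>In a notebook this usually means restarting the runtime and re-running the cell that defines the kernel after setting the flag.</p>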

<p>Via this, we learn for example that our <code class="language-plaintext highlighter-rouge">pid</code> is 0 in the first iteration, 1 in the second and so on! These iterations correspond to the workgroup/tile id, similar to the <code class="language-plaintext highlighter-rouge">blockIdx</code> in CUDA.</p>

<p>We also see that the offsets are a contiguous array of indices that are later used to access the vectors. We can also see that <code class="language-plaintext highlighter-rouge">x_ptr</code> and <code class="language-plaintext highlighter-rouge">y_ptr</code> contain memory addresses. So what happens in <code class="language-plaintext highlighter-rouge">x = tl.load(x_ptr + offsets, mask=mask)</code> is that Triton loads a whole block of memory at once: every element at <code class="language-plaintext highlighter-rouge">x_ptr</code> plus one of the offsets, as long as the mask is true for that offset. The compiler makes sure that these memory accesses are efficient, e.g. via memory coalescing.</p>
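<p>If you do not have a debugger session at hand, the values of <code class="language-plaintext highlighter-rouge">pid</code>, the offsets and the mask can also be reproduced with plain Python. The following is a small sketch that mirrors the index arithmetic of <code class="language-plaintext highlighter-rouge">add_kernel</code> for our vector of 98432 elements; <code class="language-plaintext highlighter-rouge">cdiv</code> is a hand-written equivalent of <code class="language-plaintext highlighter-rouge">triton.cdiv</code> and <code class="language-plaintext highlighter-rouge">program_view</code> is a hypothetical helper, not part of Triton.</p>

```python
def cdiv(a, b):
    """Ceiling division, mirroring what triton.cdiv computes."""
    return (a + b - 1) // b


def program_view(pid, n_elements, block_size):
    """Offsets and mask that the program instance with this pid would compute."""
    block_start = pid * block_size
    offsets = range(block_start, block_start + block_size)
    mask = [n_elements > off for off in offsets]
    return offsets, mask


n_elements, block_size = 98432, 1024
num_programs = cdiv(n_elements, block_size)
print(num_programs)  # 97

# First program instance: every lane is valid
offsets, mask = program_view(0, n_elements, block_size)
print(offsets[0], offsets[-1], sum(mask))  # 0 1023 1024

# Last program instance: only the first 128 lanes are inside the vector
offsets, mask = program_view(num_programs - 1, n_elements, block_size)
print(offsets[0], offsets[-1], sum(mask))  # 98304 99327 128
```

<p>So the grid consists of 97 program instances, and only the last one actually needs the mask.</p>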

<h3 id="33-triton-deep-dive">3.3 Triton Deep-Dive</h3>

<p>What does Triton do under the hood? It first converts the Python code into a custom Triton IR, which the Triton compiler then lowers into the well-known LLVM IR. From there, PTX code is generated. Basically, Triton leverages LLVM heavily and (quote from the paper) “just a few data- and control-flow extensions to LLVM-IR could enable various tile-level optimization passes which jointly lead to performance on-par with vendor libraries.” These extensions allow Triton to handle things like shared memory allocation or memory coalescing, which in CUDA the GPU programmer has to take care of manually.</p>

<p><img src="/assets/img/blog/gpu_profiling/triton.png" alt="Triton under the hood" /></p>

<p class="figcaption">From <a href="https://international.binus.ac.id/computer-science/2022/09/02/openai-proposes-open-source-triton-language-as-an-alternative-to-nvidias-cuda/">this news article</a></p>

<p>We can look at all these different intermediate representations by saving the compiled kernel to a variable and then accessing the <code class="language-plaintext highlighter-rouge">asm</code> field that contains the IRs for the various levels.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Run this inside add(), where grid, output and n_elements are defined
</span><span class="n">compiled</span> <span class="o">=</span> <span class="n">add_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">n_elements</span><span class="p">,</span> <span class="n">BLOCK_SIZE</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"TTIR"</span><span class="p">,</span> <span class="n">compiled</span><span class="p">.</span><span class="n">asm</span><span class="p">[</span><span class="s">'ttir'</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="s">"TTGIR"</span><span class="p">,</span> <span class="n">compiled</span><span class="p">.</span><span class="n">asm</span><span class="p">[</span><span class="s">'ttgir'</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="s">"LLIR"</span><span class="p">,</span> <span class="n">compiled</span><span class="p">.</span><span class="n">asm</span><span class="p">[</span><span class="s">'llir'</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="s">"PTX"</span><span class="p">,</span> <span class="n">compiled</span><span class="p">.</span><span class="n">asm</span><span class="p">[</span><span class="s">'ptx'</span><span class="p">])</span>
</code></pre></div></div>

<ol>
  <li><a href="/assets/img/blog/gpu_profiling/ttir.txt"><strong>TTIR (Intermediate Representation)</strong></a>: This IR is what people generally refer to when they say Triton IR. Inspecting it you see that it looks relatively similar to the original Triton code, just with many of the operations split up into more fundamental steps like initialising constants, loading and broadcasting data and finally (after computation) storing it again. We see that our original kernel is now wrapped in an <code class="language-plaintext highlighter-rouge">IR module</code> as a <code class="language-plaintext highlighter-rouge">tt.func public @kernel_name</code>.</li>
  <li><a href="/assets/img/blog/gpu_profiling/ttgir.txt"><strong>TTGIR (Triton GPU Intermediate Representation)</strong></a>: Triton can be used for different accelerators, and the GPU is one of them. In that case, Triton will lower TTIR into TTGIR, where GPU-specific operations like thread synchronization, memory coalescing and shared memory allocation are performed.</li>
  <li><a href="/assets/img/blog/gpu_profiling/llir.txt"><strong>LLIR (Low-Level Intermediate Representation)</strong></a>: After TTGIR, the code is transformed into LLIR, the lowest level of IR. If we inspect the LLIR, we can see at the start that we use an <code class="language-plaintext highlighter-rouge">LLVMDialectModule</code>. This indicates that the IR we are talking about is the LLVM IR, part of the larger collection of modular and reusable compiler technologies that make up the <a href="https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html">LLVM project</a>. The idea is that no matter which IR we lower from, once we are in LLVM IR we can translate the code for different backends (for example into machine code for NVIDIA or AMD GPUs). The <em>dialect</em> part of <code class="language-plaintext highlighter-rouge">LLVMDialectModule</code> hints that we do not only leverage LLVM, but also <a href="https://mlir.llvm.org/">MLIR</a>, a successor project that tries to unify not only the toolset for going from IR to backend, but also the toolset for creating these IRs in the first place. You can read more about MLIR <a href="https://arxiv.org/pdf/2002.11054.pdf">in the original paper</a>, <a href="https://llvm.org/devmtg/2020-09/slides/MLIR_Tutorial.pdf">this developer presentation</a> or <a href="http://lastweek.io/notes/MLIR/">this blogpost</a>.</li>
  <li><a href="/assets/img/blog/gpu_profiling/ptx.txt"><strong>PTX (Parallel Thread Execution, also NVPTX)</strong></a>: <a href="https://en.wikipedia.org/wiki/Parallel_Thread_Execution">PTX</a> is the virtual ISA (<a href="https://en.wikipedia.org/wiki/Instruction_set_architecture">instruction set architecture</a>) used for NVIDIA GPUs. When we write normal CUDA kernels, the NVCC compiler translates the CUDA C++ code into PTX; here, Triton ends up at the same destination via a different route that passes through Triton IR and LLVM IR. PTX is a proper assembly language, represented in ASCII text and specific to NVIDIA GPUs. The graphics driver contains a compiler that translates PTX into the assembly language <a href="https://news.ycombinator.com/item?id=36168678">SASS</a>, which differs between GPU architectures to enable device-specific optimisations. This code is then finally transformed into binary code and executed by the GPU.</li>
</ol>

<p><img src="/assets/img/blog/triton/triton_compiler_pipeline.jpeg" alt="Triton Compiler Pipeline" /></p>

<p class="figcaption">Triton Compiler Pipeline (<a href="https://www.youtube.com/watch?v=AtbnRIzpwho">Link</a>)</p>

<p>Looking at the Triton Compiler Pipeline from Triton IR to LLVM IR, we see that many of the optimizations we would have to specify by hand in CUDA are performed automatically in this transformation process; for example memory coalescing, matmul acceleration and layout adaptations.</p>

<p>The interesting part about Triton is that it is not limited to a specific set of hardware architectures, but can in principle be used for a variety of <a href="https://en.wikipedia.org/wiki/Instruction_set_architecture">ISAs (Instruction Set Architectures)</a>.</p>

<p><img src="/assets/img/blog/triton/triton_compiler_architecture.jpeg" alt="Triton Compiler Ecosystem" /></p>

<p class="figcaption">Triton Compiler Ecosystem (<a href="https://www.youtube.com/watch?v=AtbnRIzpwho">Link</a>)</p>

<p>While most programs targeting GPUs will probably end up in LLVM IR and then get translated into the vendor-specific ISAs, code for CPUs, FPGAs and other hardware can be routed through other compiler backends, making the ecosystem modular.</p>

<h3 id="34-benchmarking-triton">3.4 Benchmarking Triton</h3>

<p>Just as with our CUDA kernels, we of course want to benchmark our Triton kernels; if they did not give us a speed-up, there would have been little point in dealing with them in the first place!</p>

<p>For profiling, we use the decorator <code class="language-plaintext highlighter-rouge">triton.testing.perf_report</code> to get a performance report of our kernel.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">triton</span><span class="p">.</span><span class="n">testing</span><span class="p">.</span><span class="n">perf_report</span><span class="p">(</span>
        <span class="n">triton</span><span class="p">.</span><span class="n">testing</span><span class="p">.</span><span class="n">Benchmark</span><span class="p">(</span>
            <span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s">'size'</span><span class="p">],</span>
            <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span>
            <span class="n">x_log</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
            <span class="n">line_arg</span><span class="o">=</span><span class="s">'provider'</span><span class="p">,</span>
            <span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s">'triton'</span><span class="p">,</span> <span class="s">'torch'</span><span class="p">],</span>
            <span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s">'Triton'</span><span class="p">,</span> <span class="s">'Torch'</span><span class="p">],</span>
            <span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s">'blue'</span><span class="p">,</span> <span class="s">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s">'green'</span><span class="p">,</span> <span class="s">'-'</span><span class="p">)],</span>  <span class="c1"># Line styles.
</span>            <span class="n">ylabel</span><span class="o">=</span><span class="s">'GB/s'</span><span class="p">,</span>  <span class="c1"># Label name for the y-axis.
</span>            <span class="n">plot_name</span><span class="o">=</span><span class="s">'matrix-square-performance'</span><span class="p">,</span>  <span class="c1"># Name for the plot. Used also as a file name for saving the plot.
</span>            <span class="n">args</span><span class="o">=</span><span class="p">{},</span>  <span class="c1"># Values for function arguments not in x_names and y_name.
</span>        <span class="p">))</span>

<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">quantiles</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s">'torch'</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="p">.</span><span class="n">testing</span><span class="p">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">quantiles</span> <span class="o">=</span> <span class="n">quantiles</span><span class="p">)</span>
    <span class="k">elif</span> <span class="n">provider</span> <span class="o">==</span> <span class="s">'triton'</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="p">.</span><span class="n">testing</span><span class="p">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">square</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">quantiles</span><span class="o">=</span><span class="n">quantiles</span><span class="p">)</span>
    <span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">8</span> <span class="o">*</span> <span class="n">size</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>  <span class="c1"># 4 bytes read + 4 bytes written per float32 element</span>
    <span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>

<span class="n">benchmark</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>
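<p>The GB/s conversion at the end of the benchmark function can be checked by hand. The following is a minimal sketch, assuming that an elementwise square reads each <code class="language-plaintext highlighter-rouge">float32</code> element once and writes it once (8 bytes of traffic per element); the helper name <code class="language-plaintext highlighter-rouge">effective_bandwidth_gbps</code> and the timing in the example are made up for illustration.</p>

```python
def effective_bandwidth_gbps(n_elements, ms, bytes_per_element=4, accesses_per_element=2):
    """Effective memory bandwidth in GB/s for an elementwise kernel.

    accesses_per_element=2 assumes one read and one write per element,
    which is an assumption about the square kernel, not a measured fact.
    """
    bytes_moved = n_elements * bytes_per_element * accesses_per_element
    seconds = ms * 1e-3
    # bytes / seconds / 1e9 is the same as bytes / ms * 1e-6.
    return bytes_moved / seconds / 1e9


# Hypothetical timing: 2**24 float32 elements processed in 0.25 ms.
print(effective_bandwidth_gbps(2**24, 0.25))
```

<p>Kernels with a different access pattern (for example two inputs and one output) would need a different <code class="language-plaintext highlighter-rouge">accesses_per_element</code>.</p>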

<p>With this benchmark, we get both a print-out of our data as well as a graphical representation.</p>

<p><img src="/assets/img/blog/triton/triton_benchmark.jpeg" alt="Triton benchmark plot" /></p>

<p class="figcaption">Triton benchmark plot.</p>

<p>We can see that in this case there is no significant speed-up over PyTorch; however, looking at <a href="https://triton-lang.org/main/getting-started/tutorials/index.html">more complex examples on the Triton page</a>, there are speed-ups to be achieved.</p>

<p>To read more about Triton, you can have a look at the <a href="https://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">original research paper</a>, a <a href="https://www.youtube.com/watch?v=G951lCm_qnk">video by the author Philippe Tillet</a> and a <a href="https://www.reddit.com/r/MachineLearning/comments/otdpkx/n_introducing_triton_opensource_gpu_programming/">Reddit discussion</a> where he himself gave some useful perspectives on the project.</p>

<h2 id="4-torchcompile">4. <code class="language-plaintext highlighter-rouge">torch.compile</code></h2>

<p>To get a feel for how Triton fits into the PyTorch 2 compilation stack, we can leverage the fact that <code class="language-plaintext highlighter-rouge">torch.compile</code> uses Triton under the hood when generating GPU code. We can just write a simple function and then call <code class="language-plaintext highlighter-rouge">torch.compile</code> on it. Then, when running the script, we either set the environment variable <code class="language-plaintext highlighter-rouge">TORCH_LOGS</code> (for example via <code class="language-plaintext highlighter-rouge">os.environ["TORCH_LOGS"]</code>) to different values, depending on which stage of the PyTorch compilation process we want to investigate, or configure the same logging directly in PyTorch via <code class="language-plaintext highlighter-rouge">torch._logging.set_logs(argument)</code> with different arguments.</p>

<table>
  <thead>
    <tr>
      <th>Stage</th>
      <th>Value for TORCH_LOGS<br />(Env. variable)</th>
      <th>Argument to <code class="language-plaintext highlighter-rouge">set_logs</code><br />(Python function)</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Dynamo Tracing</td>
      <td><code class="language-plaintext highlighter-rouge">+dynamo</code></td>
      <td><code class="language-plaintext highlighter-rouge">dynamo=logging.DEBUG</code></td>
    </tr>
    <tr>
      <td>Traced Graph</td>
      <td><code class="language-plaintext highlighter-rouge">graph</code></td>
      <td><code class="language-plaintext highlighter-rouge">graph=True</code></td>
    </tr>
    <tr>
      <td>Fusion Detections</td>
      <td><code class="language-plaintext highlighter-rouge">fusion</code></td>
      <td><code class="language-plaintext highlighter-rouge">fusion=True</code></td>
    </tr>
    <tr>
      <td>Triton Output Code</td>
      <td><code class="language-plaintext highlighter-rouge">output_code</code></td>
      <td><code class="language-plaintext highlighter-rouge">output_code=True</code></td>
    </tr>
  </tbody>
</table>
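<p>As a rough sketch of how this could look in a script: the helper names below (<code class="language-plaintext highlighter-rouge">torch_logs_value</code>, <code class="language-plaintext highlighter-rouge">run_compiled_square</code>) are made up for illustration, and the values are taken from the table above. The sketch assumes that <code class="language-plaintext highlighter-rouge">TORCH_LOGS</code> accepts a comma-separated list and that it should be set before <code class="language-plaintext highlighter-rouge">torch</code> is imported.</p>

```python
import os

# Stage -> TORCH_LOGS value, copied from the table above.
STAGE_TO_TORCH_LOGS = {
    "dynamo": "+dynamo",
    "graph": "graph",
    "fusion": "fusion",
    "output_code": "output_code",
}


def torch_logs_value(stages):
    """Join the TORCH_LOGS values of the requested stages with commas."""
    return ",".join(STAGE_TO_TORCH_LOGS[stage] for stage in stages)


def run_compiled_square(stages=("graph", "output_code")):
    """Compile a toy function with logging enabled for the given stages."""
    # Set the variable before importing torch so the logging setup sees it.
    os.environ["TORCH_LOGS"] = torch_logs_value(stages)
    import torch

    # In-process alternative to the environment variable:
    # torch._logging.set_logs(graph=True, output_code=True)

    @torch.compile
    def square(x):
        return x ** 2

    # Triton code is only generated for GPU tensors; on CPU the
    # compiler emits different backend code.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return square(torch.rand(1024, device=device))
```

<p>Calling <code class="language-plaintext highlighter-rouge">run_compiled_square()</code> on a machine with a GPU should then print the traced graph and the generated Triton code to the console.</p>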


<p><img src="/assets/img/blog/triton/triton_dl_stack.jpeg" alt="Triton in the DL Stack" /></p>

<p class="figcaption">Triton DL Stack (<a href="https://www.youtube.com/watch?v=AtbnRIzpwho">Link</a>)</p>

<p>Looking at this, we can see that Triton leverages some heuristics to enable autotuning and other efficiency improvements. For example, it infers data types and element numbers and then uses this information to optimize the kernel.</p>

<h2 id="credits">Credits</h2>

<p>Thanks to the <a href="https://www.youtube.com/@CUDAMODE">CUDA MODE</a> lecture series for the inspiration for this post and the community around that for interesting discussions!</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="ml" /><summary type="html"><![CDATA[A bit of background on GPU acceleration and how to use it with PyTorch]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/ssh_gpu/gpucluster.jpg" /><media:content medium="image" url="/assets/img/blog/ssh_gpu/gpucluster.jpg" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">How to represent protein structures in ML</title><link href="/blog/proteins/2024-02-03-protein-representations/" rel="alternate" type="text/html" title="How to represent protein structures in ML" /><published>2024-02-03T00:00:00+00:00</published><updated>2026-02-12T20:40:40+00:00</updated><id>/blog/proteins/protein-representations</id><content type="html" xml:base="/blog/proteins/2024-02-03-protein-representations/"><![CDATA[<p>Machine Learning approaches empower <a href="https://www.sciencedirect.com/science/article/abs/pii/S2405471223002983">a new suite of algorithms and applications</a> in structural biology and protein engineering/design. However, there is quite a gap between how protein structure data is classically stored in databases and how machine learning algorithms deal with data. Here, I want to bridge that gap and show how current algorithms such as AlphaFold2 make use of protein structure data in practice.</p>

<ol id="markdown-toc">
  <li><a href="#protein-structure-file-formats-pdb-vs-pdbxmmcif-vs-mmtf-vs-binarycif" id="markdown-toc-protein-structure-file-formats-pdb-vs-pdbxmmcif-vs-mmtf-vs-binarycif">Protein Structure File Formats: PDB vs PDBx/mmCIF vs MMTF vs BinaryCIF</a>    <ol>
      <li><a href="#pdb-format-legacy" id="markdown-toc-pdb-format-legacy">PDB format (legacy)</a></li>
      <li><a href="#pdbxmmcif-format" id="markdown-toc-pdbxmmcif-format">PDBx/mmCIF format</a></li>
      <li><a href="#mmtf-format-legacy" id="markdown-toc-mmtf-format-legacy">MMTF format (legacy)</a></li>
      <li><a href="#binarycif-format" id="markdown-toc-binarycif-format">BinaryCIF format</a></li>
    </ol>
  </li>
  <li><a href="#input-data-for-machine-learning-algorithms" id="markdown-toc-input-data-for-machine-learning-algorithms">Input Data for Machine Learning Algorithms</a>    <ol>
      <li><a href="#amino-acid-encodings" id="markdown-toc-amino-acid-encodings">Amino acid encodings</a></li>
      <li><a href="#coordinates-atom14-vs-atom37" id="markdown-toc-coordinates-atom14-vs-atom37">Coordinates: Atom14 vs Atom37</a></li>
      <li><a href="#boundary-conditions-oxt" id="markdown-toc-boundary-conditions-oxt">Boundary Conditions: OXT</a></li>
      <li><a href="#example-lysozyme-atom-numbering" id="markdown-toc-example-lysozyme-atom-numbering">Example: Lysozyme atom numbering</a></li>
    </ol>
  </li>
  <li><a href="#reference-systems-local-reference-frames-vs-reference-free-methods" id="markdown-toc-reference-systems-local-reference-frames-vs-reference-free-methods">Reference Systems: Local reference frames vs reference-free methods</a>    <ol>
      <li><a href="#local-reference-based-methods" id="markdown-toc-local-reference-based-methods">Local reference-based methods</a>        <ol>
          <li><a href="#why-se3-instead-of-e3-equivariance-can-be-important" id="markdown-toc-why-se3-instead-of-e3-equivariance-can-be-important">Why SE(3) instead of E(3) equivariance can be important</a></li>
          <li><a href="#ambivalent-mappings-from-frames-to-coordinates" id="markdown-toc-ambivalent-mappings-from-frames-to-coordinates">Ambivalent mappings from frames to coordinates</a></li>
        </ol>
      </li>
      <li><a href="#reference-free-methods-invariant-and-equivariant-update-functions" id="markdown-toc-reference-free-methods-invariant-and-equivariant-update-functions">Reference-free methods: Invariant and Equivariant Update Functions</a></li>
      <li><a href="#screw-these-symmetries-data-augmentation-and-other-strategies" id="markdown-toc-screw-these-symmetries-data-augmentation-and-other-strategies">Screw these symmetries: data augmentation and other strategies</a></li>
    </ol>
  </li>
  <li><a href="#batching-padded-versus-sparse" id="markdown-toc-batching-padded-versus-sparse">Batching: Padded versus sparse</a>    <ol>
      <li><a href="#the-batching-pain-with-variable-length-input" id="markdown-toc-the-batching-pain-with-variable-length-input">The batching pain with variable-length input</a></li>
      <li><a href="#efficient-padding-via-length-batching" id="markdown-toc-efficient-padding-via-length-batching">Efficient padding via length batching</a></li>
      <li><a href="#sparse-batching" id="markdown-toc-sparse-batching">Sparse batching</a></li>
    </ol>
  </li>
  <li><a href="#afdb-esmatlas--co-how-to-deal-with-large-databases" id="markdown-toc-afdb-esmatlas--co-how-to-deal-with-large-databases">AFDB, ESMAtlas &amp; co: how to deal with large databases</a></li>
  <li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
</ol>

<p>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</p>

<div class="language-bibtex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nc">@misc</span><span class="p">{</span><span class="nl">didi2024proteinrepresentations</span><span class="p">,</span>
  <span class="na">author</span> <span class="p">=</span> <span class="s">{Didi, Kieran}</span><span class="p">,</span>
  <span class="na">title</span> <span class="p">=</span> <span class="s">{How to represent protein structures in ML}</span><span class="p">,</span>
  <span class="na">url</span> <span class="p">=</span> <span class="s">{https://kdidi.netlify.app/blog/proteins/2024-02-03-protein-representations/}</span><span class="p">,</span>
  <span class="na">year</span> <span class="p">=</span> <span class="s">{2024}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="protein-structure-file-formats-pdb-vs-pdbxmmcif-vs-mmtf-vs-binarycif">Protein Structure File Formats: PDB vs PDBx/mmCIF vs MMTF vs BinaryCIF</h2>

<p>Before we turn to machine learning algorithms such as AlphaFold2, let’s briefly discuss how protein structure coordinates are stored in the <a href="https://www.rcsb.org/">PDB</a> to start with.</p>

<p>Over the years there has been quite an evolution with respect to data formats for protein structures.</p>

<h3 id="pdb-format-legacy">PDB format (legacy)</h3>

<p>The original PDB format <a href="https://en.wikipedia.org/wiki/Protein_Data_Bank_(file_format)">introduced in 1976</a> was intended as a human-readable file that would allow researchers to exchange data easily. While very successful, it is a very wasteful format by today’s standards in terms of whitespace and indentation, making automatic parsing relatively difficult.</p>

<p>Here is an excerpt of the PDB file of a <a href="https://www.rcsb.org/structure/168L">Lysozyme structure</a>:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># file: "168l.pdb"</span>
HEADER    HYDROLASE <span class="o">(</span>O-GLYCOSYL<span class="o">)</span>                  24-MAR-95   168L              
TITLE     PROTEIN FLEXIBILITY AND ADAPTABILITY SEEN IN 25 CRYSTAL FORMS OF T4   
TITLE    2 LYSOZYME                                                             
COMPND    MOL_ID: 1<span class="p">;</span>                                                            
COMPND   2 MOLECULE: T4 LYSOZYME<span class="p">;</span>                                               
COMPND   3 CHAIN: A, B, C, D, E<span class="p">;</span>                                                
COMPND   4 EC: 3.2.1.17<span class="p">;</span>                                                        
COMPND   5 ENGINEERED: YES                                                      
SOURCE    MOL_ID: 1<span class="p">;</span>                                                            
SOURCE   2 ORGANISM_SCIENTIFIC: ENTEROBACTERIA PHAGE T4<span class="p">;</span>                        
SOURCE   3 ORGANISM_TAXID: 10665<span class="p">;</span>                                               
SOURCE   4 EXPRESSION_SYSTEM_VECTOR_TYPE: PLASMID<span class="p">;</span>                              
SOURCE   5 EXPRESSION_SYSTEM_PLASMID: M13                                       
KEYWDS    HYDROLASE <span class="o">(</span>O-GLYCOSYL<span class="o">)</span>                                                
EXPDTA    X-RAY DIFFRACTION                                                     
AUTHOR    X.-J.ZHANG,B.W.MATTHEWS                                               
REVDAT   5   07-FEB-24 168L    1       REMARK SEQADV                            
REVDAT   4   29-NOV-17 168L    1       REMARK HELIX                             
REVDAT   3   24-FEB-09 168L    1       VERSN                                    
REVDAT   2   01-APR-03 168L    1       JRNL                                     
REVDAT   1   10-JUL-95 168L    0                                                
JRNL        AUTH   X.J.ZHANG,J.A.WOZNIAK,B.W.MATTHEWS                           
JRNL        TITL   PROTEIN FLEXIBILITY AND ADAPTABILITY SEEN IN 25 CRYSTAL      
JRNL        TITL 2 FORMS OF T4 LYSOZYME.                                        
JRNL        REF    J.MOL.BIOL.                   V. 250   527 1995              
JRNL        REFN                   ISSN 0022-2836                               
JRNL        PMID   7616572                                                      
JRNL        DOI    10.1006/JMBI.1995.0396                                       
REMARK   1                                                                      
REMARK   1 REFERENCE 1                                                          
REMARK   1  AUTH   L.H.WEAVER,B.W.MATTHEWS                                      
REMARK   1  TITL   STRUCTURE OF BACTERIOPHAGE T4 LYSOZYME REFINED AT 1.7        
REMARK   1  TITL 2 ANGSTROMS RESOLUTION                                         
REMARK   1  REF    J.MOL.BIOL.                   V. 193   189 1987              
REMARK   1  REFN                   ISSN 0022-2836                               
REMARK   2                                                                      
REMARK   2 RESOLUTION.    2.90 ANGSTROMS.
...
SEQRES   1 A  164  MET ASN ILE PHE GLU MET LEU ARG ILE ASP GLU GLY LEU          
SEQRES   2 A  164  ARG LEU LYS ILE TYR LYS ASP THR GLU GLY TYR TYR THR          
SEQRES   3 A  164  ILE GLY ILE GLY HIS LEU LEU THR LYS SER PRO SER LEU          
SEQRES   4 A  164  ASN ALA ALA LYS SER GLU LEU ASP LYS ALA ILE GLY ARG          
SEQRES   5 A  164  ASN CYS ASN GLY VAL ILE THR LYS ASP GLU ALA GLU LYS
...
HELIX    1  A1 ILE A    3  GLU A   11  1                                   9    
HELIX    2  A2 LEU A   39  ILE A   50  1                                  12    
HELIX    3  A3 LYS A   60  ARG A   80  1                                  21    
HELIX    4  A4 ALA A   82  SER A   90  1                                   9    
HELIX    5  A5 ALA A   93  MET A  106  1                                  14    
...
ATOM      1  N   MET A   1      74.851  69.339  <span class="nt">-6</span>.260  1.00 37.97           N  
ATOM      2  CA  MET A   1      75.137  68.258  <span class="nt">-5</span>.357  1.00 38.78           C  
ATOM      3  C   MET A   1      73.896  67.665  <span class="nt">-4</span>.750  1.00 40.36           C  
ATOM      4  O   MET A   1      72.862  68.348  <span class="nt">-4</span>.627  1.00 40.50           O  
ATOM      5  CB  MET A   1      76.039  68.696  <span class="nt">-4</span>.203  1.00 40.16           C      
</code></pre></div></div>

<p>You can imagine how parsing something like the resolution automatically from this might be quite a pain. The main structure of such a PDB file is as follows:</p>

<ul>
  <li>it starts with a <code class="language-plaintext highlighter-rouge">HEADER</code> and some additional metadata such as the authors and the journal where the structure was published</li>
  <li>then there are many <code class="language-plaintext highlighter-rouge">REMARK</code> records that give additional information like the resolution of the structure and the experimental method by which it was acquired</li>
  <li>what follows is the <code class="language-plaintext highlighter-rouge">SEQRES</code> (short for sequence representation) that lists the sequence for the structure for quick parsing (more information on this <a href="https://pdb101.rcsb.org/learn/guide-to-understanding-pdb-data/primary-sequences-and-the-pdb-format">here</a>)</li>
  <li>then some information about assigned secondary structure indicated via <code class="language-plaintext highlighter-rouge">HELIX</code> or <code class="language-plaintext highlighter-rouge">SHEET</code></li>
  <li>finally, we have the actual structure information, prefaced with the ATOM qualifier, describing the atom type, the residue name, which chain it is part of and of course the coordinates as well as additional metadata such as the <a href="https://proteopedia.org/wiki/index.php/Temperature_value">B-factor</a></li>
</ul>
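<p>To make the fixed-width layout concrete, here is a minimal sketch of how the <code class="language-plaintext highlighter-rouge">ATOM</code> records and the resolution remark could be parsed by hand. The column positions follow the PDB format specification; the names <code class="language-plaintext highlighter-rouge">AtomRecord</code>, <code class="language-plaintext highlighter-rouge">parse_atom_line</code> and <code class="language-plaintext highlighter-rouge">parse_resolution</code> are made up for this example, and in practice you would rather use an established parser library.</p>

```python
from dataclasses import dataclass


@dataclass
class AtomRecord:
    serial: int
    name: str
    res_name: str
    chain_id: str
    res_seq: int
    x: float
    y: float
    z: float
    occupancy: float
    b_factor: float
    element: str


def parse_atom_line(line):
    """Parse one fixed-width ATOM/HETATM line by slicing its columns."""
    if not line.startswith(("ATOM", "HETATM")):
        raise ValueError("not an ATOM/HETATM record")
    return AtomRecord(
        serial=int(line[6:11]),        # columns 7-11
        name=line[12:16].strip(),      # columns 13-16
        res_name=line[17:20].strip(),  # columns 18-20
        chain_id=line[21],             # column 22
        res_seq=int(line[22:26]),      # columns 23-26
        x=float(line[30:38]),          # columns 31-38
        y=float(line[38:46]),          # columns 39-46
        z=float(line[46:54]),          # columns 47-54
        occupancy=float(line[54:60]),  # columns 55-60
        b_factor=float(line[60:66]),   # columns 61-66
        element=line[76:78].strip(),   # columns 77-78
    )


def parse_resolution(lines):
    """Return the resolution from the 'REMARK   2 RESOLUTION.' line, or None."""
    for line in lines:
        if line.startswith("REMARK   2 RESOLUTION."):
            fields = line.split()
            try:
                return float(fields[3])
            except (IndexError, ValueError):
                return None  # e.g. no numeric resolution given
    return None


# Second ATOM line of the excerpt above.
ca_line = "ATOM      2  CA  MET A   1      75.137  68.258  -5.357  1.00 38.78           C"
print(parse_atom_line(ca_line))
print(parse_resolution(["REMARK   2 RESOLUTION.    2.90 ANGSTROMS."]))
```

<p>Note that splitting on whitespace instead of slicing columns would be fragile for <code class="language-plaintext highlighter-rouge">ATOM</code> lines, since neighbouring fields can touch each other without a separating space.</p>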

<p>Two important things to note at this point:</p>
<ol>
  <li>Counterintuitively, the <code class="language-plaintext highlighter-rouge">SEQRES</code> information does not always align with the sequence contained in the structure itself via the <code class="language-plaintext highlighter-rouge">ATOM</code> fields. This is a problem that plagues later data formats as well and can be <a href="https://pdb101.rcsb.org/learn/guide-to-understanding-pdb-data/primary-sequences-and-the-pdb-format">attributed to a variety of reasons</a>, mostly that flexible loops and chain ends are often not resolved in experimental structures but nonetheless present in the <code class="language-plaintext highlighter-rouge">SEQRES</code> representation. That is the reason why models like AlphaFold2 and OpenFold require tools like <a href="https://academic.oup.com/bioinformatics/article/36/6/1928/5607735">KAlign</a> to align the sequence representation to the structure representation in cases where they do not match during template search (see for example <a href="https://github.com/aqlaboratory/openfold/blob/main/openfold/data/tools/kalign.py">this file</a> in the OpenFold codebase or section 1.2.3 in the <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf">AlphaFold 2 SI</a> (page 5)).</li>
  <li>The atom names are not just the chemical elements (C, N, O, …), but have specific other descriptors depending on where in the amino acid this element occurs (C can be C, CA, CB, CG, …). How each of these atoms is named exactly is described in the <a href="https://www.wwpdb.org/data/ccd#pdbechem">PDB Chemical Component Dictionary</a>, but in general you can keep in mind that for many atoms we enumerate them with Greek characters after the atom symbol; CG then stands for “Carbon Gamma”, i.e. the third carbon atom in the chain when counting from the alpha carbon (CA, CB, CG).</li>
</ol>
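<p>To illustrate the first point, here is a small sketch of how the residues observed in the <code class="language-plaintext highlighter-rouge">ATOM</code> records could be mapped back onto the <code class="language-plaintext highlighter-rouge">SEQRES</code> sequence. Note that this uses <code class="language-plaintext highlighter-rouge">difflib</code> from the Python standard library as a simple stand-in and is not the KAlign-based procedure used by AlphaFold2 or OpenFold; the function name is made up for this example.</p>

```python
from difflib import SequenceMatcher


def map_observed_to_seqres(seqres, observed):
    """For each observed residue, return its index in SEQRES (or None)."""
    mapping = [None] * len(observed)
    matcher = SequenceMatcher(None, seqres, observed, autojunk=False)
    for block in matcher.get_matching_blocks():
        # block.a indexes into seqres, block.b into observed.
        for offset in range(block.size):
            mapping[block.b + offset] = block.a + offset
    return mapping


# Toy example: the first two residues are not resolved in the structure.
print(map_observed_to_seqres("MNIFEMLRIDEG", "IFEMLRIDEG"))
```

<p>A proper sequence aligner is more robust than this greedy block matching, especially when short resolved fragments could match at several positions in the full sequence.</p>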

<p>The PDB format does not support Greek characters, so the atom names are translated into the most similar Latin letters:</p>

<table>
  <thead>
    <tr>
      <th>Atom name</th>
      <th>Pronunciation</th>
      <th>PDB name</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>α</td>
      <td>alpha</td>
      <td>A</td>
    </tr>
    <tr>
      <td>β</td>
      <td>beta</td>
      <td>B</td>
    </tr>
    <tr>
      <td>γ</td>
      <td>gamma</td>
      <td>G</td>
    </tr>
    <tr>
      <td>δ</td>
      <td>delta</td>
      <td>D</td>
    </tr>
    <tr>
      <td>ε</td>
      <td>epsilon</td>
      <td>E</td>
    </tr>
    <tr>
      <td>ζ</td>
      <td>zeta</td>
      <td>Z</td>
    </tr>
    <tr>
      <td>η</td>
      <td>eta</td>
      <td>H</td>
    </tr>
  </tbody>
</table>

<p>C<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>α</mi></mrow><annotation encoding="application/x-tex">\alpha</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span> is thus called CA, O<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>γ</mi></mrow><annotation encoding="application/x-tex">\gamma</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span></span></span></span> is called OG and so on. Sometimes (e.g. in Asp) there may be two atoms of the same element at the same position, in which case they are numbered 1 and 2, e.g. the two carboxyl oxygen atoms in Asp are called OD1 and OD2. Later in this article we will see a representation of these atoms for all amino acids, but for now we can use the <a href="https://www.ebi.ac.uk/pdbe-srv/pdbechem/">PDBeChem interface</a> to look up this representation for the amino acid (or, in fact, any chemical component in the PDB) that we are interested in.</p>
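<p>The naming scheme can be sketched in a few lines of Python. This is only an illustration for heavy atoms of the standard amino acids; the function and dictionary names are made up, and the PDB letter H is read as the Greek letter eta (as in the OH atom of tyrosine).</p>

```python
GREEK_BY_PDB_LETTER = {
    "A": "alpha",
    "B": "beta",
    "G": "gamma",
    "D": "delta",
    "E": "epsilon",
    "Z": "zeta",
    "H": "eta",
}
ELEMENT_NAMES = {"C": "carbon", "N": "nitrogen", "O": "oxygen", "S": "sulfur"}


def describe_atom_name(name):
    """Spell out a PDB atom name such as 'CG' or 'OD1'."""
    element = ELEMENT_NAMES[name[0]]
    if len(name) == 1:
        # N, C and O without a position letter are backbone atoms.
        return f"backbone {element}"
    if name[1] not in GREEK_BY_PDB_LETTER:
        # e.g. the terminal oxygen OXT does not follow the Greek scheme.
        raise ValueError(f"no Greek position letter in atom name {name!r}")
    description = f"{element} {GREEK_BY_PDB_LETTER[name[1]]}"
    branch = name[2:]  # '1' or '2' when a position is occupied twice
    if branch:
        description += f" {branch}"
    return description


for atom_name in ["N", "CA", "CG", "OD1", "OD2"]:
    print(atom_name, "=", describe_atom_name(atom_name))
```

<p>Outside of standard amino acids this breaks down quickly: in a <code class="language-plaintext highlighter-rouge">HETATM</code> record, CA can just as well denote a calcium ion, which is why the element column is needed to disambiguate.</p>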

<p>If you insert <code class="language-plaintext highlighter-rouge">SER</code> for the amino acid serine in the “Code” search box, hit the <code class="language-plaintext highlighter-rouge">Search</code> button and upon getting the result click the <code class="language-plaintext highlighter-rouge">Atoms</code> tab on the left-hand side of the page, you will get all the atoms in that specific amino acid. We will see later that the representation in models such as AlphaFold2 is a bit shorter since a) they do not include hydrogens in the model and b) one oxygen atom is lost in the condensation of the individual amino acids into the backbone (one water molecule per bond formed to be precise).</p>

<h3 id="pdbxmmcif-format">PDBx/mmCIF format</h3>

<p>As mentioned, the PDB format has quite a few limitations when it comes to supporting large structures as well as complex chemistries. To improve on this, a new format called <a href="https://pdb101.rcsb.org/learn/guide-to-understanding-pdb-data/beginner%E2%80%99s-guide-to-pdb-structures-and-the-pdbx-mmcif-format">PDBx/mmCIF</a> was introduced and is currently the default format in the PDB. It uses the ASCII character set and is a tabular data format, in which data items have a name of the format <code class="language-plaintext highlighter-rouge">_categoryname.attributename</code>, for example <code class="language-plaintext highlighter-rouge">_citation_author.name</code>. If a data item has only one value, it is written on a single line as a key-value pair. If the items of a category have multiple values, a <code class="language-plaintext highlighter-rouge">loop_</code> token precedes the list of item names, followed by rows of data in which the individual values are separated by whitespace.</p>
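<p>The two layouts (single key-value pairs and <code class="language-plaintext highlighter-rouge">loop_</code> tables) can be illustrated with a deliberately simplified reader. This is a sketch and not a complete mmCIF parser: it ignores multi-line text fields delimited by semicolons, relies on <code class="language-plaintext highlighter-rouge">shlex</code> for quoted values (which fails on values containing stray apostrophes), and the function name is made up for this example.</p>

```python
import shlex


def parse_simple_mmcif(text):
    """Collect single-line key-value items and loop_ tables from mmCIF text."""
    items, tables = {}, []
    names, rows, in_loop = [], [], False
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line == "loop_":
            # Start a new table; rows are appended to it as they are read.
            in_loop, names, rows = True, [], []
            tables.append(rows)
        elif in_loop and line.startswith("_") and not rows:
            names.append(line)  # column names of the current table
        elif in_loop and line and not line.startswith(("#", "_")):
            rows.append(dict(zip(names, shlex.split(line))))
        else:
            in_loop = False
            parts = shlex.split(line) if line.startswith("_") else []
            if len(parts) == 2:
                items[parts[0]] = parts[1]
    return items, tables


SAMPLE = """data_168L
#
_entry.id   168L
#
loop_
_database_2.database_id
_database_2.database_code
PDB   168L
WWPDB D_1000170153
#
_entity.pdbx_description 'T4 LYSOZYME'
"""

items, tables = parse_simple_mmcif(SAMPLE)
print(items)
print(tables[0])
```

<p>For real files, a dedicated mmCIF parser from an established structural biology library is the safer choice, since the full grammar has several more quoting rules.</p>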

<p>Compared to the legacy PDB format where a structure is just described as a list of atoms and amino acids, PDBx/mmCIF has more semantics in its representation. One example of this is the concept of an <em>entity</em>, which is defined as a <a href="https://pdb101.rcsb.org/learn/guide-to-understanding-pdb-data/beginner%E2%80%99s-guide-to-pdb-structures-and-the-pdbx-mmcif-format">chemically distinct part of a structure as represented in the PDBx/mmCIF data file</a>. For example, a chemical ligand would be an entity, as would chains in a protein. Importantly, these entities can be present multiple times: a homodimer will have one entity since the same chain is present twice.</p>

<p>With this background, let us look at the PDBx/mmCIF file for the same lysozyme structure we looked at before:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># file: "168l.cif"</span>
data_168L
<span class="c"># </span>
_entry.id   168L 
<span class="c"># </span>
_audit_conform.dict_name       mmcif_pdbx.dic 
_audit_conform.dict_version    5.385 
_audit_conform.dict_location   http://mmcif.pdb.org/dictionaries/ascii/mmcif_pdbx.dic 
<span class="c"># </span>
loop_
_database_2.database_id 
_database_2.database_code 
_database_2.pdbx_database_accession 
_database_2.pdbx_DOI 
PDB   168L         pdb_0000168l 10.2210/pdb168l/pdb 
WWPDB D_1000170153 ?            ?                   
<span class="c"># </span>
...
_entity.id                         1 
_entity.type                       polymer 
_entity.src_method                 man 
_entity.pdbx_description           <span class="s1">'T4 LYSOZYME'</span> 
_entity.formula_weight             18373.139 
_entity.pdbx_number_of_molecules   5 
_entity.pdbx_ec                    3.2.1.17 
_entity.pdbx_mutation              ? 
_entity.pdbx_fragment              ? 
_entity.details                    ? 
<span class="c"># </span>
_entity_poly.entity_id                      1 
_entity_poly.type                           <span class="s1">'polypeptide(L)'</span> 
_entity_poly.nstd_linkage                   no 
_entity_poly.nstd_monomer                   no 
_entity_poly.pdbx_seq_one_letter_code       
<span class="p">;</span>MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNCNGVITKDEAEKLFNQDVDAAVRGILR
NAKLKPVYDSLDAVRRCALINMVFQMGETGVAGFTNSLRMLQQKRWDAAAAALAAAAWYNQTPNRAKRVITTFRTGTWDA
YKNL
<span class="p">;</span>
_entity_poly.pdbx_seq_one_letter_code_can   
<span class="p">;</span>MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNCNGVITKDEAEKLFNQDVDAAVRGILR
NAKLKPVYDSLDAVRRCALINMVFQMGETGVAGFTNSLRMLQQKRWDAAAAALAAAAWYNQTPNRAKRVITTFRTGTWDA
YKNL
<span class="p">;</span>
_entity_poly.pdbx_strand_id                 A,B,C,D,E 
_entity_poly.pdbx_target_identifier         ? 
<span class="c"># </span>
loop_
_entity_poly_seq.entity_id 
_entity_poly_seq.num 
_entity_poly_seq.mon_id 
_entity_poly_seq.hetero 
1 1   MET n 
1 2   ASN n 
1 3   ILE n 
1 4   PHE n 
1 5   GLU n 
1 6   MET n 
1 7   LEU n 
1 8   ARG n 
...
loop_
_chem_comp.id 
_chem_comp.type 
_chem_comp.mon_nstd_flag 
_chem_comp.name 
_chem_comp.pdbx_synonyms 
_chem_comp.formula 
_chem_comp.formula_weight 
ALA <span class="s1">'L-peptide linking'</span> y ALANINE         ? <span class="s1">'C3 H7 N O2'</span>     89.093  
ARG <span class="s1">'L-peptide linking'</span> y ARGININE        ? <span class="s1">'C6 H15 N4 O2 1'</span> 175.209 
ASN <span class="s1">'L-peptide linking'</span> y ASPARAGINE      ? <span class="s1">'C4 H8 N2 O3'</span>    132.118 
ASP <span class="s1">'L-peptide linking'</span> y <span class="s1">'ASPARTIC ACID'</span> ? <span class="s1">'C4 H7 N O4'</span>     133.103 
CYS <span class="s1">'L-peptide linking'</span> y CYSTEINE        ? <span class="s1">'C3 H7 N O2 S'</span>   121.158 
GLN <span class="s1">'L-peptide linking'</span> y GLUTAMINE       ? <span class="s1">'C5 H10 N2 O3'</span>   146.144 
GLU <span class="s1">'L-peptide linking'</span> y <span class="s1">'GLUTAMIC ACID'</span> ? <span class="s1">'C5 H9 N O4'</span>     147.129 
...
loop_
_atom_site.group_PDB 
_atom_site.id 
_atom_site.type_symbol 
_atom_site.label_atom_id 
_atom_site.label_alt_id 
_atom_site.label_comp_id 
_atom_site.label_asym_id 
_atom_site.label_entity_id 
_atom_site.label_seq_id 
_atom_site.pdbx_PDB_ins_code 
_atom_site.Cartn_x 
_atom_site.Cartn_y 
_atom_site.Cartn_z 
_atom_site.occupancy 
_atom_site.B_iso_or_equiv 
_atom_site.pdbx_formal_charge 
_atom_site.auth_seq_id 
_atom_site.auth_comp_id 
_atom_site.auth_asym_id 
_atom_site.auth_atom_id 
_atom_site.pdbx_PDB_model_num 
ATOM 1    N N   <span class="nb">.</span> MET A 1 1   ? 74.851  69.339  <span class="nt">-6</span>.260  1.00 37.97  ? 1   MET A N   1 
ATOM 2    C CA  <span class="nb">.</span> MET A 1 1   ? 75.137  68.258  <span class="nt">-5</span>.357  1.00 38.78  ? 1   MET A CA  1 
ATOM 3    C C   <span class="nb">.</span> MET A 1 1   ? 73.896  67.665  <span class="nt">-4</span>.750  1.00 40.36  ? 1   MET A C   1 
ATOM 4    O O   <span class="nb">.</span> MET A 1 1   ? 72.862  68.348  <span class="nt">-4</span>.627  1.00 40.50  ? 1   MET A O   1 
ATOM 5    C CB  <span class="nb">.</span> MET A 1 1   ? 76.039  68.696  <span class="nt">-4</span>.203  1.00 40.16  ? 1   MET A CB  1 
ATOM 6    C CG  <span class="nb">.</span> MET A 1 1   ? 76.921  67.555  <span class="nt">-3</span>.776  1.00 41.09  ? 1   MET A CG  1 
ATOM 7    S SD  <span class="nb">.</span> MET A 1 1   ? 77.902  67.038  <span class="nt">-5</span>.191  1.00 40.98  ? 1   MET A SD  1 
ATOM 8    C CE  <span class="nb">.</span> MET A 1 1   ? 78.748  65.645  <span class="nt">-4</span>.424  1.00 41.39  ? 1   MET A CE  1 
ATOM 9    N N   <span class="nb">.</span> ASN A 1 2   ? 74.139  66.409  <span class="nt">-4</span>.302  1.00 41.77  ? 2   ASN A N   1 
...
ATOM 6442 C CG  <span class="nb">.</span> LEU E 1 164 ? 95.884  25.834  <span class="nt">-10</span>.740 0.00 85.05  ? 164 LEU E CG  1 
ATOM 6443 C CD1 <span class="nb">.</span> LEU E 1 164 ? 96.110  27.302  <span class="nt">-11</span>.107 0.00 85.07  ? 164 LEU E CD1 1 
ATOM 6444 C CD2 <span class="nb">.</span> LEU E 1 164 ? 94.874  25.202  <span class="nt">-11</span>.694 0.00 85.06  ? 164 LEU E CD2 1 
ATOM 6445 O OXT <span class="nb">.</span> LEU E 1 164 ? 98.129  21.647  <span class="nt">-9</span>.779  0.00 84.32  ? 164 LEU E OXT 1 
<span class="c"># </span>
</code></pre></div></div>
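
<p>To make the two layouts (key-value pairs and <code class="language-plaintext highlighter-rouge">loop_</code> tables) more concrete, here is a minimal parsing sketch in plain Python. The helper <code class="language-plaintext highlighter-rouge">parse_simple_cif</code> is a hypothetical name made up for this post, and it deliberately ignores multi-line values delimited by <code class="language-plaintext highlighter-rouge">;</code>. For real work you should use an established mmCIF parser instead.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import shlex

# A few lines copied from the 168l.cif excerpt above.
CIF_SNIPPET = """\
_entry.id   168L
loop_
_entity_poly_seq.entity_id
_entity_poly_seq.num
_entity_poly_seq.mon_id
_entity_poly_seq.hetero
1 1   MET n
1 2   ASN n
1 3   ILE n
#
"""


def parse_simple_cif(text):
    """Parse single-line key-value items and loop_ tables (hypothetical helper).

    Multi-line values delimited by ';' are not handled in this sketch.
    """
    items, tables = {}, []
    columns, rows, in_loop = [], [], False

    def flush():
        # Store the finished loop_ table as a list of {column: value} rows.
        if columns:
            tables.append([dict(zip(columns, row)) for row in rows])

    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            # '#' lines separate categories and end the current loop.
            flush()
            columns, rows, in_loop = [], [], False
        elif line == "loop_":
            flush()
            columns, rows, in_loop = [], [], True
        elif line.startswith("_") and in_loop and not rows:
            # Column names of the loop, one per line.
            columns.append(line)
        elif line.startswith("_"):
            parts = line.split(None, 1)
            if len(parts) == 2:
                # shlex keeps quoted values such as 'T4 LYSOZYME' together.
                items[parts[0]] = shlex.split(parts[1])[0]
        elif in_loop:
            rows.append(shlex.split(line))
    flush()
    return items, tables


items, tables = parse_simple_cif(CIF_SNIPPET)
print(items["_entry.id"])
print([row["_entity_poly_seq.mon_id"] for row in tables[0]])
</code></pre></div></div>

<p>The key-value item ends up in a plain dictionary, while each <code class="language-plaintext highlighter-rouge">loop_</code> block becomes a list of rows keyed by the data item names.</p>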

<h3 id="mmtf-format-legacy">MMTF format (legacy)</h3>

<p>PDBx/mmCIF is now the standard format for storing macromolecular data. Its extensible and verbose design gives it rich metadata and makes it well suited for <em>archival</em> purposes, but it is not the best format to <em>transmit</em> large amounts of structural data, because of the redundant annotations and repetitive information you have seen above. Storing coordinates as human-readable, whitespace-separated text is another hurdle for fast data transmission.</p>

<p>Due to these limitations, the <a href="https://mmtf.rcsb.org/index.html">MMTF format</a> (Macromolecular Transmission Format) was introduced. It does not contain all data present in the PDBx/mmCIF files, but it does contain all the data necessary for most visualisation and structural analysis programs. The main advantages of MMTF are its compact encoding and fast parsing, since it uses a binary instead of a text representation.</p>

<p><img src="/assets/img/blog/prot_representation/mmtf_parsing.png" alt="MMTF Compression Pipeline" /></p>

<p class="figcaption">Overview of the MMTF compression pipeline. Source: <a href="https://github.com/sbl-sdsc/mmtf-workshop-2018/blob/master/0-introduction/MMTF2018-Introduction.pdf">UCSD Presentation</a></p>

<p>We can see that after some data preparation, the main steps in the MMTF pipeline are various ways of encoding to reduce the file size:</p>
<ul>
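  <li>integer encoding: floating point values are scaled by a fixed factor and stored as integers</li>
  <li>dictionary encoding: repeated values are stored once and then referenced by index</li>
  <li>run-length encoding: runs of identical values are stored as a value together with a count</li>
  <li>delta encoding: only the differences between consecutive values are stored</li>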
</ul>
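
<p>These four encodings are simple enough to sketch in a few lines of Python. The functions below are illustrative toy versions written for this post, not the actual MMTF implementation, and the scaling factor of 1000 is an assumption chosen to match the three decimal places of the coordinates shown earlier.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def integer_encode(values, factor=1000):
    """Store floats as integers by scaling and rounding (lossy beyond 1/factor)."""
    return [round(v * factor) for v in values]


def delta_encode(values):
    """Keep the first value, then only the differences between neighbours."""
    return values[:1] + [b - a for a, b in zip(values, values[1:])]


def delta_decode(deltas):
    """Invert delta_encode by accumulating the differences."""
    out = []
    for d in deltas:
        out.append(d if not out else out[-1] + d)
    return out


def run_length_encode(values):
    """Collapse runs of equal values into [value, count] pairs."""
    runs = []
    for v in values:
        if runs and runs[-1][0] == v:
            runs[-1][1] += 1
        else:
            runs.append([v, 1])
    return runs


def dictionary_encode(values):
    """Replace repeated values by indices into a table of unique values."""
    table, indices = [], []
    for v in values:
        if v not in table:
            table.append(v)
        indices.append(table.index(v))
    return table, indices


# x coordinates of the first three atoms in the 168L excerpt above.
x_int = integer_encode([74.851, 75.137, 73.896])
print(x_int, delta_encode(x_int))

# Atom serial numbers 1..9: delta encoding turns them into a run of ones,
# which run-length encoding then collapses into a single pair.
print(run_length_encode(delta_encode(list(range(1, 10)))))

# Residue names repeat for every atom, so a lookup table pays off.
print(dictionary_encode(["MET", "MET", "ASN"]))
</code></pre></div></div>

<p>The serial number example shows why the encodings are chained: after delta encoding, a long column of consecutive integers becomes a run of identical values that run-length encoding can store as one pair.</p>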

<p>After these encodings, the file size is reduced further by packing the data into the <a href="https://msgpack.org/">MessagePack format</a>. Its slogan is <em>It’s like JSON, but fast and small</em>, indicating its flexibility in storing data, e.g. as key-value pairs, but in a binary format.</p>

<p>MMTF is great for fast transmission of data and rethought quite a lot of things in clever ways. However, it deviated quite a bit from the mmCIF standard and therefore <a href="https://bioinformatics.stackexchange.com/questions/14738/binarycif-vs-mmtf-formats-which-one-to-choose">never really caught on in the community</a>. This has now been confirmed, with MMTF being <a href="https://www.mail-archive.com/ccp4bb@jiscmail.ac.uk/msg56121.html">deprecated from July 2024 onward</a>.</p>

<h3 id="binarycif-format">BinaryCIF format</h3>

<p>There was a need for an efficient binary format for transferring protein structure information that was more closely aligned with the PDBx/mmCIF file format specification. Enter <a href="https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008247">BinaryCIF</a>, a newer format that is easier to interconvert with the now standard PDBx/mmCIF. The <a href="https://github.com/dsehnal/BinaryCIF">BinaryCIF specification</a> is actually quite readable, so I recommend checking it out.</p>

<p>BinaryCIF was heavily inspired by MMTF, with many people working on both formats. This is visible in the usage of MessagePack and the different encodings employed.</p>

<p><img src="/assets/img/blog/prot_representation/binary_cif_compression.png" alt="BinaryCIF compressions" /></p>

<p class="figcaption">Encodings employed for BinaryCIF. Source: <a href="https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008247">BinaryCIF paper</a></p>

<p>BinaryCIF adds a few encodings that you can read up on in the specification on GitHub, but it mostly uses the same encodings as MMTF.</p>

<h2 id="input-data-for-machine-learning-algorithms">Input Data for Machine Learning Algorithms</h2>

<p>We’ve discussed how protein structures are stored in databases; with that done, let us talk about how they are represented in machine learning algorithms.</p>

<h3 id="amino-acid-encodings">Amino acid encodings</h3>

<p>Encoding the sequence information into a numerical format should not be too hard; our vocabulary size is only 20 and we do not have to deal with symmetries as we will see later with geometric information like coordinates.</p>

<p>However, if you actually look into different code bases, you will soon find a decade-old problem resurfacing:</p>

<p><img src="/assets/img/blog/prot_representation/standards_xkcd.png" alt="standardisation" /></p>

<p class="figcaption">The old ordeal of standardisation. Source: <a href="https://xkcd.com/927/">xkcd.com</a></p>

<p>Depending on which codebase you use, the ordering of amino acids used to encode them into numerical format might be different, introducing the possibility of silent but horrible bugs later down the line. Some alphabets even have a different vocabulary size since they deal with post-translational modifications, non-canonical amino acids or other phenomena you encounter in the wild west of structural biology.</p>

<p>For many applications, people use a de-facto standard by adapting the encoding defined by AlphaFold2. If we look at the OpenFold codebase, we can see that their ordering includes the 20 canonical amino acids together with an unknown residue token represented by <code class="language-plaintext highlighter-rouge">X</code>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
</span><span class="n">restypes</span> <span class="o">=</span> <span class="p">[</span>
    <span class="s">"A"</span><span class="p">,</span>
    <span class="s">"R"</span><span class="p">,</span>
    <span class="s">"N"</span><span class="p">,</span>
    <span class="s">"D"</span><span class="p">,</span>
    <span class="s">"C"</span><span class="p">,</span>
    <span class="s">"Q"</span><span class="p">,</span>
    <span class="s">"E"</span><span class="p">,</span>
    <span class="s">"G"</span><span class="p">,</span>
    <span class="s">"H"</span><span class="p">,</span>
    <span class="s">"I"</span><span class="p">,</span>
    <span class="s">"L"</span><span class="p">,</span>
    <span class="s">"K"</span><span class="p">,</span>
    <span class="s">"M"</span><span class="p">,</span>
    <span class="s">"F"</span><span class="p">,</span>
    <span class="s">"P"</span><span class="p">,</span>
    <span class="s">"S"</span><span class="p">,</span>
    <span class="s">"T"</span><span class="p">,</span>
    <span class="s">"W"</span><span class="p">,</span>
    <span class="s">"Y"</span><span class="p">,</span>
    <span class="s">"V"</span><span class="p">,</span>
<span class="p">]</span>
<span class="n">restype_order</span> <span class="o">=</span> <span class="p">{</span><span class="n">restype</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">restype</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">restypes</span><span class="p">)}</span>
<span class="n">restype_num</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">restypes</span><span class="p">)</span>  <span class="c1"># := 20.
</span><span class="n">unk_restype_index</span> <span class="o">=</span> <span class="n">restype_num</span>  <span class="c1"># Catch-all index for unknown restypes.
</span>
<span class="n">restypes_with_x</span> <span class="o">=</span> <span class="n">restypes</span> <span class="o">+</span> <span class="p">[</span><span class="s">"X"</span><span class="p">]</span>
<span class="n">restype_order_with_x</span> <span class="o">=</span> <span class="p">{</span><span class="n">restype</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">restype</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">restypes_with_x</span><span class="p">)}</span>
</code></pre></div></div>
<p class="figcaption">OpenFold amino acid encoding.</p>

<p>However, some other models/frameworks use an amino acid encoding that is created by sorting the <em>1-letter codes</em> instead of the <em>3-letter codes</em> alphabetically. If in doubt, check which encoding your data uses to avoid confusion.</p>
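
<p>To see how easily this goes wrong, the following sketch encodes the same sequence with both conventions. The variable names and the <code class="language-plaintext highlighter-rouge">encode</code> helper are made up for illustration; only the residue order in the first string is taken from the OpenFold listing above.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># One-letter codes in the AlphaFold2/OpenFold order shown above
# (alphabetical by three-letter code).
order_by_three_letter = "ARNDCQEGHILKMFPSTWYV"
# Alternative convention: sort the one-letter codes themselves.
order_by_one_letter = "".join(sorted(order_by_three_letter))


def encode(sequence, alphabet):
    """Map a sequence to integer indices; unknown residues get len(alphabet)."""
    lookup = {aa: i for i, aa in enumerate(alphabet)}
    return [lookup.get(aa, len(alphabet)) for aa in sequence]


sequence = "MNIFEM"  # first residues of the lysozyme example
print(order_by_one_letter)
print(encode(sequence, order_by_three_letter))
print(encode(sequence, order_by_one_letter))
</code></pre></div></div>

<p>Both results are valid lists of integers in the range 0 to 19, so nothing crashes if you feed one into a model expecting the other. The model simply sees a different protein.</p>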

<h3 id="coordinates-atom14-vs-atom37">Coordinates: Atom14 vs Atom37</h3>

<p>When looking at either the original <a href="https://github.com/google-deepmind/alphafold">AlphaFold codebase</a> or the open-source reproduction in PyTorch called <a href="https://github.com/aqlaboratory/openfold">OpenFold</a>, many people trip over how the coordinates from the file formats discussed earlier are represented inside the neural network. The confusion is compounded by the fact that there are two different network-internal representations, which are converted into each other depending on the use case.</p>

<p>Documentation on these two representations is sparse. One of the few descriptions is a docstring in AlphaFold’s <code class="language-plaintext highlighter-rouge">all_atom.py</code>, linked here via a <a href="https://huggingface.co/spaces/simonduerr/ProteinMPNN/blame/e65166bd70446c6fddcc1581dbc6dac06e7f8dca/alphafold/alphafold/model/all_atom.py">copy hosted on HuggingFace</a>:</p>

<p class="note" title="atom14 vs atom37">Generally we employ two different representations for all atom coordinates,
one is atom37 where each heavy atom corresponds to a given position in a 37
dimensional array, This mapping is non amino acid specific, but each slot
corresponds to an atom of a given name, for example slot 12 always corresponds
to ‘C delta 1’, positions that are not present for a given amino acid are
zeroed out and denoted by a mask.
The other representation we employ is called atom14, this is a more dense way
of representing atoms with 14 slots. Here a given slot will correspond to a
different kind of atom depending on amino acid type, for example slot 5
corresponds to ‘N delta 2’ for Aspargine, but to ‘C delta 1’ for Isoleucine.
14 is chosen because it is the maximum number of heavy atoms for any standard
amino acid.
The order of slots can be found in ‘residue_constants.residue_atoms’.
Internally the model uses the atom14 representation because it is
computationally more efficient.
The internal atom14 representation is turned into the atom37 at the output of
the network to facilitate easier conversion to existing protein datastructures.</p>

<p>What does this mean in practice? Let’s look at the code. When looking at <a href="https://github.com/aqlaboratory/openfold/blob/127f1e7023c380c01330cee45544c23c079babe9/openfold/np/residue_constants.py#L355"><code class="language-plaintext highlighter-rouge">residue_constants.residue_atoms</code></a>, we get the following description for the <code class="language-plaintext highlighter-rouge">atom14</code> representation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># file: "residue_constants.py"
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
</span><span class="n">residue_atoms</span> <span class="o">=</span> <span class="p">{</span>
    <span class="s">"ALA"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"ARG"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD"</span><span class="p">,</span> <span class="s">"CZ"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"NE"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"NH1"</span><span class="p">,</span> <span class="s">"NH2"</span><span class="p">],</span>
    <span class="s">"ASP"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OD1"</span><span class="p">,</span> <span class="s">"OD2"</span><span class="p">],</span>
    <span class="s">"ASN"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"ND2"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OD1"</span><span class="p">],</span>
    <span class="s">"CYS"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"SG"</span><span class="p">],</span>
    <span class="s">"GLU"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OE1"</span><span class="p">,</span> <span class="s">"OE2"</span><span class="p">],</span>
    <span class="s">"GLN"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"NE2"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OE1"</span><span class="p">],</span>
    <span class="s">"GLY"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"HIS"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD2"</span><span class="p">,</span> <span class="s">"CE1"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"ND1"</span><span class="p">,</span> <span class="s">"NE2"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"ILE"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG1"</span><span class="p">,</span> <span class="s">"CG2"</span><span class="p">,</span> <span class="s">"CD1"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"LEU"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD1"</span><span class="p">,</span> <span class="s">"CD2"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"LYS"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD"</span><span class="p">,</span> <span class="s">"CE"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"NZ"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"MET"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CE"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"SD"</span><span class="p">],</span>
    <span class="s">"PHE"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD1"</span><span class="p">,</span> <span class="s">"CD2"</span><span class="p">,</span> <span class="s">"CE1"</span><span class="p">,</span> <span class="s">"CE2"</span><span class="p">,</span> <span class="s">"CZ"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"PRO"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"SER"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OG"</span><span class="p">],</span>
    <span class="s">"THR"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG2"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OG1"</span><span class="p">],</span>
    <span class="s">"TRP"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD1"</span><span class="p">,</span> <span class="s">"CD2"</span><span class="p">,</span> <span class="s">"CE2"</span><span class="p">,</span> <span class="s">"CE3"</span><span class="p">,</span> <span class="s">"CZ2"</span><span class="p">,</span> <span class="s">"CZ3"</span><span class="p">,</span> <span class="s">"CH2"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"NE1"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">],</span>
    <span class="s">"TYR"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG"</span><span class="p">,</span> <span class="s">"CD1"</span><span class="p">,</span> <span class="s">"CD2"</span><span class="p">,</span> <span class="s">"CE1"</span><span class="p">,</span> <span class="s">"CE2"</span><span class="p">,</span> <span class="s">"CZ"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">,</span> <span class="s">"OH"</span><span class="p">],</span>
    <span class="s">"VAL"</span><span class="p">:</span> <span class="p">[</span><span class="s">"C"</span><span class="p">,</span> <span class="s">"CA"</span><span class="p">,</span> <span class="s">"CB"</span><span class="p">,</span> <span class="s">"CG1"</span><span class="p">,</span> <span class="s">"CG2"</span><span class="p">,</span> <span class="s">"N"</span><span class="p">,</span> <span class="s">"O"</span><span class="p">]}</span>
</code></pre></div></div>
<p class="figcaption"><code class="language-plaintext highlighter-rouge">atom14</code> ordering.</p>

<p>We see that depending on which amino acid is present, a given position in a residue array can represent a different atom (for example, counting from zero, position 3 is <code class="language-plaintext highlighter-rouge">CG2</code> for Threonine, <code class="language-plaintext highlighter-rouge">CG1</code> for Valine and <code class="language-plaintext highlighter-rouge">N</code> for Serine). This makes storing this information very efficient, but can be cumbersome if we need to retrieve the coordinates of a certain atom like <code class="language-plaintext highlighter-rouge">N</code> from our data structure.</p>
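
<p>The lookup problem is easy to demonstrate with three entries copied from the listing above. The helper <code class="language-plaintext highlighter-rouge">slot_of</code> is a made-up name for this sketch and not part of OpenFold.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Three entries copied from the residue_atoms listing above.
residue_atoms = {
    "SER": ["C", "CA", "CB", "N", "O", "OG"],
    "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
    "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}


def slot_of(resname, atom_name):
    """Return the zero-based slot of an atom within its residue's list."""
    return residue_atoms[resname].index(atom_name)


for resname, atoms in residue_atoms.items():
    print(resname, "slot 3 holds", atoms[3], "| N sits in slot", slot_of(resname, "N"))
</code></pre></div></div>

<p>To find the backbone nitrogen we first need to know the residue type, because in this listing it sits in slot 3, 4 or 5 depending on the amino acid.</p>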

<p>On the other hand, the <code class="language-plaintext highlighter-rouge">atom37</code> representation has a fixed atom data size for every residue. This ordering can be found in <a href="https://github.com/aqlaboratory/openfold/blob/127f1e7023c380c01330cee45544c23c079babe9/openfold/np/residue_constants.py#L555"><code class="language-plaintext highlighter-rouge">residue_constants.atom_types</code></a>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># file: "residue_constants.py"
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
</span><span class="n">atom_types</span> <span class="o">=</span> <span class="p">[</span>
    <span class="s">"N"</span><span class="p">,</span>
    <span class="s">"CA"</span><span class="p">,</span>
    <span class="s">"C"</span><span class="p">,</span>
    <span class="s">"CB"</span><span class="p">,</span>
    <span class="s">"O"</span><span class="p">,</span>
    <span class="s">"CG"</span><span class="p">,</span>
    <span class="s">"CG1"</span><span class="p">,</span>
    <span class="s">"CG2"</span><span class="p">,</span>
    <span class="s">"OG"</span><span class="p">,</span>
    <span class="s">"OG1"</span><span class="p">,</span>
    <span class="s">"SG"</span><span class="p">,</span>
    <span class="s">"CD"</span><span class="p">,</span>
    <span class="s">"CD1"</span><span class="p">,</span>
    <span class="s">"CD2"</span><span class="p">,</span>
    <span class="s">"ND1"</span><span class="p">,</span>
    <span class="s">"ND2"</span><span class="p">,</span>
    <span class="s">"OD1"</span><span class="p">,</span>
    <span class="s">"OD2"</span><span class="p">,</span>
    <span class="s">"SD"</span><span class="p">,</span>
    <span class="s">"CE"</span><span class="p">,</span>
    <span class="s">"CE1"</span><span class="p">,</span>
    <span class="s">"CE2"</span><span class="p">,</span>
    <span class="s">"CE3"</span><span class="p">,</span>
    <span class="s">"NE"</span><span class="p">,</span>
    <span class="s">"NE1"</span><span class="p">,</span>
    <span class="s">"NE2"</span><span class="p">,</span>
    <span class="s">"OE1"</span><span class="p">,</span>
    <span class="s">"OE2"</span><span class="p">,</span>
    <span class="s">"CH2"</span><span class="p">,</span>
    <span class="s">"NH1"</span><span class="p">,</span>
    <span class="s">"NH2"</span><span class="p">,</span>
    <span class="s">"OH"</span><span class="p">,</span>
    <span class="s">"CZ"</span><span class="p">,</span>
    <span class="s">"CZ2"</span><span class="p">,</span>
    <span class="s">"CZ3"</span><span class="p">,</span>
    <span class="s">"NZ"</span><span class="p">,</span>
    <span class="s">"OXT"</span><span class="p">,</span>
<span class="p">]</span>
<span class="n">atom_order</span> <span class="o">=</span> <span class="p">{</span><span class="n">atom_type</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">atom_type</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">atom_types</span><span class="p">)}</span>
<span class="n">atom_type_num</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">atom_types</span><span class="p">)</span>  <span class="c1"># := 37.
</span></code></pre></div></div>

<p class="figcaption"><code class="language-plaintext highlighter-rouge">atom37</code> ordering.</p>

<p>Here we can see that the ordering is always the same no matter which residue is represented; however, most of the slots will be empty for any given residue, since even the largest amino acid (tryptophan) has only 14 heavy atoms. We therefore trade efficiency for standardisation, which explains why AF2 often uses <code class="language-plaintext highlighter-rouge">atom14</code> internally, but often uses <code class="language-plaintext highlighter-rouge">atom37</code> when it interfaces with other programs at I/O.</p>
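
<p>As a rough sketch of what filling such a fixed-size layout looks like, the snippet below places the eight atoms of the first methionine from the 168L file into 37 slots and records a mask of occupied positions. The function <code class="language-plaintext highlighter-rouge">to_atom37</code> is a hypothetical helper written for this post; the real codebases do this with batched array operations.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># The atom37 ordering from residue_constants.atom_types shown above.
atom_types = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
    "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
    "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
    "CZ3", "NZ", "OXT",
]
atom_order = {name: i for i, name in enumerate(atom_types)}


def to_atom37(named_coords):
    """Place named atom coordinates into 37 fixed slots and build a mask.

    Slots without an atom stay zeroed and get a mask value of 0.
    """
    coords = [[0.0, 0.0, 0.0] for _ in atom_types]
    mask = [0] * len(atom_types)
    for name, xyz in named_coords.items():
        slot = atom_order[name]
        coords[slot] = list(xyz)
        mask[slot] = 1
    return coords, mask


# MET 1 of chain A, copied from the atom_site records above.
met1 = {
    "N": (74.851, 69.339, -6.260),
    "CA": (75.137, 68.258, -5.357),
    "C": (73.896, 67.665, -4.750),
    "O": (72.862, 68.348, -4.627),
    "CB": (76.039, 68.696, -4.203),
    "CG": (76.921, 67.555, -3.776),
    "SD": (77.902, 67.038, -5.191),
    "CE": (78.748, 65.645, -4.424),
}
coords, mask = to_atom37(met1)
print(sum(mask), "of", len(mask), "slots are occupied")
print("SD is stored in slot", atom_order["SD"])
</code></pre></div></div>

<p>Retrieving a specific atom is now a single index lookup that works for every residue type, at the cost of carrying 29 empty slots for this methionine.</p>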

<p>If we think about our example of <code class="language-plaintext highlighter-rouge">Ser</code> again, we can see how the machine representations map to the actual amino acid (again with the caveat that hydrogens are omitted and that one carboxyl oxygen is not counted, since within a peptide backbone it has left as part of a water molecule).</p>

<p><img src="/assets/img/blog/prot_representation/serine_repr.jpeg" alt="serine example" /></p>

<table>
  <thead>
    <tr>
      <th>Category</th>
      <th>Atom14</th>
      <th>Atom37</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Memory Requirements</td>
      <td>Efficient</td>
      <td>Wasteful</td>
    </tr>
    <tr>
      <td>Data Layout</td>
      <td>Varying Shape</td>
      <td>Fixed Shape</td>
    </tr>
    <tr>
      <td>Sequence Dependence</td>
      <td>Yes</td>
      <td>No</td>
    </tr>
  </tbody>
</table>
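<p>To make the comparison concrete, here is a minimal sketch of how serine lands in both layouts. The <code class="language-plaintext highlighter-rouge">atom_types</code> list repeats the <code class="language-plaintext highlighter-rouge">atom37</code> ordering shown above; the names <code class="language-plaintext highlighter-rouge">ser_atom14</code>, <code class="language-plaintext highlighter-rouge">ser_atom14_slots</code> and <code class="language-plaintext highlighter-rouge">ser_atom37_mask</code> are made up for this illustration and are not AF2 identifiers.</p>

```python
# atom37 ordering, repeated from the listing above (37 fixed slots).
atom_types = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
    "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
    "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
    "CZ3", "NZ", "OXT",
]
atom_order = {name: i for i, name in enumerate(atom_types)}

# Heavy atoms of a serine residue inside a chain, in atom14 order.
ser_atom14 = ["N", "CA", "C", "O", "CB", "OG"]

# atom14: the meaning of each slot depends on the residue; unused slots are padded.
ser_atom14_slots = ser_atom14 + [""] * (14 - len(ser_atom14))

# atom37: every slot has a fixed meaning; mark the ones serine actually fills.
ser_atom37_mask = [0] * len(atom_types)
for name in ser_atom14:
    ser_atom37_mask[atom_order[name]] = 1

filled_slots = [i for i, present in enumerate(ser_atom37_mask) if present]
print(ser_atom14_slots)
print(filled_slots)
```

<p>For serine, 8 of the 14 <code class="language-plaintext highlighter-rouge">atom14</code> slots are padding, while 31 of the 37 <code class="language-plaintext highlighter-rouge">atom37</code> slots stay empty, which is the memory trade-off summarised in the table.</p>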

<h3 id="boundary-conditions-oxt">Boundary Conditions: OXT</h3>

<p>Earlier I mentioned the oxygen atom of the carboxy group that is lost when two amino acids combine to form a peptide bond. That is true for all amino acids except the last one at the C-terminus, since there the carboxy group is still free and has two oxygen atoms. At physiological pH the carboxy group will be deprotonated so both oxygen atoms are chemically equivalent with equal bond lengths (as opposed to the single and double bond image we always draw on paper), but our <a href="https://chemistry.stackexchange.com/questions/22245/what-does-c-oxt-stand-for-in-pdb-files">file formats still require us</a> to name one of the oxygens at the terminus as a “normal” oxygen, i.e. <code class="language-plaintext highlighter-rouge">O</code> and the other one <code class="language-plaintext highlighter-rouge">OXT</code>. You will therefore often see the last atom in a protein structure being <code class="language-plaintext highlighter-rouge">OXT</code>, such as in <a href="https://biopandas.github.io/biopandas/tutorials/Working_with_PDB_Structures_in_DataFrames/">this Biopandas tutorial</a>. When I say “often”, I mean “not always”; the termini of proteins are known to often be too flexible to crystallise, therefore the structure in our PDB files will often end prior to the C-terminus and not contain an OXT. This is not super problematic since given the planarity of the delocalised carboxy electron system, one can place the OXT easily given the carbon and the other oxygen atom. Predicted structures such as the ones from AlphaFold2 on the other hand will always contain the OXT atom since they do not have to battle experimental resolution problems.</p>
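<p>The OXT placement mentioned above can be sketched with a little vector geometry. This is an idealised illustration and not what any particular tool does: it assumes we also know the CA position, and it simply mirrors <code class="language-plaintext highlighter-rouge">O</code> across the CA-C axis so that both oxygens lie in the same plane with equal bond lengths. The function name <code class="language-plaintext highlighter-rouge">place_oxt</code> and the toy coordinates are hypothetical.</p>

```python
import numpy as np

def place_oxt(ca, c, o):
    """Mirror O across the CA->C axis, staying in the CA-C-O plane."""
    axis = (c - ca) / np.linalg.norm(c - ca)  # unit vector along CA->C
    bond = o - c                              # C->O bond vector
    parallel = np.dot(bond, axis) * axis      # part of the bond along the axis
    perpendicular = bond - parallel           # in-plane part off the axis
    return c + parallel - perpendicular       # flip only the off-axis part

# Toy coordinates in Angstrom, roughly carboxylate-like and lying in the xy-plane.
ca = np.array([0.0, 0.0, 0.0])
c = np.array([1.52, 0.0, 0.0])
o = np.array([2.10, 1.08, 0.0])

oxt = place_oxt(ca, c, o)
print(oxt)
```

<p>By construction the C-OXT distance equals the C-O distance and the CA-C-OXT angle equals the CA-C-O angle, which matches the picture of two chemically equivalent oxygens.</p>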

<h3 id="example-lysozyme-atom-numbering">Example: Lysozyme atom numbering</h3>

<p>Let us now visualise the concepts we looked at so far (atom names and atom representations) with a concrete example, again based on the lysozyme structure with the PDB code <code class="language-plaintext highlighter-rouge">168l</code>. Install PyMol (either the <a href="https://pymol.org/">commercial</a> or the <a href="https://github.com/schrodinger/pymol-open-source?tab=readme-ov-file">open-source</a> version) and open the program.</p>

<p class="note">If you have not used PyMol before, you can either skip this section or look at <a href="https://structural-bioinformatics.netlify.app/blog/proteins/2023-02-01-lesson1/">this lesson from my Structural Bioinformatics course</a> that goes over this in detail.</p>

<p>Then, execute the following commands via the integrated terminal:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fetch</span> <span class="mi">168</span><span class="n">lA</span> <span class="c1"># get first chain of lysozyme assembly
</span><span class="n">select</span> <span class="n">selection</span><span class="p">,</span> <span class="n">resi</span> <span class="mi">11</span><span class="o">-</span><span class="mi">15</span> <span class="c1"># select a subset of residues for simplicity
</span><span class="n">hide</span> <span class="n">everything</span> <span class="c1"># hide the whole structure for clarity
</span><span class="n">show</span> <span class="n">sticks</span><span class="p">,</span> <span class="n">selection</span> <span class="c1"># show stick representation for the selected subset; carbon is green, oxygen is red, nitrogen is blue
</span><span class="n">color</span> <span class="n">yellow</span><span class="p">,</span> <span class="p">(</span><span class="n">name</span> <span class="n">CG</span><span class="p">)</span> <span class="c1"># color all CG atoms yellow
</span><span class="n">color</span> <span class="n">orange</span><span class="p">,</span>  <span class="p">(</span><span class="n">name</span> <span class="n">NH1</span><span class="p">)</span> <span class="c1"># color the single NH1 atom orange
</span></code></pre></div></div>

<p>After doing this, you should see something like this:</p>

<p><img src="/assets/img/blog/prot_representation/pymol_structure.png" alt="pymol_structure" /></p>

<p>We can compare this to a schematic sketch of this protein segment, similar to what we did before with serine:</p>

<p><img src="/assets/img/blog/prot_representation/chain_repr.jpeg" alt="chain example" /></p>

<p class="figcaption">Schematic representation of the selection from our protein, with the coloring imitating our color scheme in PyMol.</p>

<p>We can see that PyMol knows about the atom naming convention we discussed and can select and color residues accordingly. It does this by parsing the information it gets from the PDB file and storing this inside the structure object it displays.</p>

<p>We can do the same thing programmatically by using a library such as <a href="https://www.biotite-python.org/index.html">Biotite</a>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">biotite.structure</span> <span class="k">as</span> <span class="n">struc</span>
<span class="kn">import</span> <span class="nn">biotite.structure.io.pdbx</span> <span class="k">as</span> <span class="n">pdbx</span>
<span class="kn">import</span> <span class="nn">biotite.database.rcsb</span> <span class="k">as</span> <span class="n">rcsb</span>

<span class="c1"># MMTF is no longer served by the RCSB and was dropped from recent Biotite releases, so we fetch BinaryCIF instead
</span><span class="n">bcif_file</span> <span class="o">=</span> <span class="n">pdbx</span><span class="p">.</span><span class="n">BinaryCIFFile</span><span class="p">.</span><span class="n">read</span><span class="p">(</span><span class="n">rcsb</span><span class="p">.</span><span class="n">fetch</span><span class="p">(</span><span class="s">"168l"</span><span class="p">,</span> <span class="s">"bcif"</span><span class="p">))</span>
<span class="n">structure</span> <span class="o">=</span> <span class="n">pdbx</span><span class="p">.</span><span class="n">get_structure</span><span class="p">(</span><span class="n">bcif_file</span><span class="p">,</span> <span class="n">model</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

<span class="n">chain_A</span> <span class="o">=</span> <span class="n">structure</span><span class="p">[</span>
    <span class="p">(</span><span class="n">structure</span><span class="p">.</span><span class="n">chain_id</span> <span class="o">==</span> <span class="s">"A"</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">structure</span><span class="p">.</span><span class="n">hetero</span> <span class="o">==</span> <span class="bp">False</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">chain_A</span><span class="p">.</span><span class="n">res_id</span><span class="p">)</span> <span class="c1"># array([  1,   1,   1, ..., 164, 164, 164])
</span><span class="n">selection</span> <span class="o">=</span> <span class="n">chain_A</span><span class="p">[(</span><span class="n">chain_A</span><span class="p">.</span><span class="n">res_id</span> <span class="o">&gt;</span> <span class="mi">10</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">chain_A</span><span class="p">.</span><span class="n">res_id</span> <span class="o">&lt;=</span> <span class="mi">15</span><span class="p">)]</span>
<span class="k">print</span><span class="p">(</span><span class="n">selection</span><span class="p">.</span><span class="n">res_id</span><span class="p">)</span> <span class="c1"># [ 11 11 ... 15 15 ]
</span><span class="k">print</span><span class="p">(</span><span class="n">selection</span><span class="p">.</span><span class="n">array_length</span><span class="p">())</span> <span class="c1"># 40
</span></code></pre></div></div>

<p>We see that our selection contains 40 atoms. We can check whether that matches the amino acids we wanted to select by counting the non-hydrogen atoms of each free amino acid and subtracting one oxygen atom per residue, which is lost as water when the peptide bond forms.</p>

<p><img src="/assets/img/blog/prot_representation/amino_acids.png" alt="amino_acids" /></p>

<p class="figcaption">Proteinogenic amino acids and some of their properties. Source: <a href="https://en.wikipedia.org/wiki/File:Overview_proteinogenic_amino_acids-DE.svg">Wikipedia</a></p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left right" columnspacing="0em 1em"><mtr><mtd class ="mtr-glue"></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mi>E</mi><mo>+</mo><mi>G</mi><mo>+</mo><mn>2</mn><mi>L</mi><mo>+</mo><mi>R</mi><mo>−</mo><mn>5</mn></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><mn>10</mn><mo>+</mo><mn>5</mn><mo>+</mo><mn>2</mn><mo>∗</mo><mn>9</mn><mo>+</mo><mn>12</mn><mo>−</mo><mn>5</mn></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mo>=</mo><mn>40</mn></mrow></mstyle></mtd><mtd class ="mtr-glue"></mtd><mtd class ="mml-eqn-num"></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{align}
E + G + 2L + R - 5
&amp;= 10 + 5 + 2*9 + 12 - 5
&amp;= 40
\end{align}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.5em;vertical-align:-0.5em;"></span><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3.16em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord mathnormal">G</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord">2</span><span class="mord mathnormal">L</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord mathnormal" style="margin-right:0.00773em;">R</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord">5</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3.16em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord">10</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span 
class="mord">5</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord">2</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord">9</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord">12</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord">5</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5em;"><span></span></span></span></span></span><span class="arraycolsep" style="width:1em;"></span><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3.16em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord">40</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5em;"><span></span></span></span></span></span></span></span><span class="tag"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1em;"><span style="top:-3em;"><span class="pstrut" style="height:2.84em;"></span><span class="eqn-num"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.5em;"><span></span></span></span></span></span></span></span></span>
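<p>The same bookkeeping can be written as a few lines of Python. The heavy-atom counts are the ones used in the equation above; the residue order <code class="language-plaintext highlighter-rouge">EGLRL</code> for positions 11-15 is my assumption about the sequence and does not affect the sum.</p>

```python
# Heavy (non-hydrogen) atoms of the free amino acids, counting both carboxyl oxygens.
heavy_atoms_free = {"E": 10, "G": 5, "L": 9, "R": 12}

# Residues 11-15 of the selection (order assumed; only the composition matters here).
segment = "EGLRL"

# Each residue inside a chain has lost one carboxyl oxygen to the peptide bond.
expected_atoms = sum(heavy_atoms_free[aa] - 1 for aa in segment)
print(expected_atoms)  # 40
```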

<h2 id="reference-systems-local-reference-frames-vs-reference-free-methods">Reference Systems: Local reference frames vs reference-free methods</h2>

<p>We have now covered how we go from the database formats for protein structures (PDBx/mmCIF, MMTF and BinaryCIF) to the formats commonly used as inputs for machine learning models (atom14, atom37). The question now is: what do the machine learning models do with this input information?</p>

<p>Given that we deal with geometric quantities such as coordinates of protein structures, considerations like invariance and equivariance come into play. There is a whole field called <a href="https://geometricdeeplearning.com/"><em>Geometric Deep Learning</em></a> that deals with these considerations. For the usage of machine learning models for protein structure, it is important to understand the distinction between <em>reference-free</em> and <em>reference-based</em> methods.</p>

<p class="note">To learn more about geometric deep learning, you can either check out the <a href="https://arxiv.org/abs/2104.13478">protobook by Bronstein et al.</a>, the <a href="https://arxiv.org/abs/2312.07511">Hitchhiker’s guide to geometric GNNs</a> or <a href="https://structural-bioinformatics.netlify.app/blog/proteins/2023-08-02-lesson5/">this lecture</a> I gave on the topic.</p>

<p><img src="/assets/img/blog/prot_representation/geometric_gnn_overview.png" alt="geometric_gnn_overview" /></p>

<p>If we predict some molecular property (such as binding affinity, solubility or immunogenicity) it is quite obvious to a human that rotations or translations of the protein should not change the prediction of these quantities. A neural network, however, just sees different numbers when a protein is translated and therefore needs to learn that these different inputs correspond to the same protein. This can be done via <a href="https://en.wikipedia.org/wiki/Data_augmentation">data augmentation</a>, but this can become data-inefficient. Therefore, people looked for ways to build this inductive bias of invariance or equivariance to <a href="https://arxiv.org/abs/2103.15980">SE(3) group actions</a> (i.e. rotations and translations) into the model.</p>
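<p>A small NumPy experiment illustrates the point: after a random rotation and translation the raw coordinates look completely different to a network, while an invariant feature such as the pairwise distance matrix is unchanged. The toy coordinates and the helper <code class="language-plaintext highlighter-rouge">pairwise_distances</code> are placeholders for this sketch.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
coords = rng.normal(size=(5, 3))  # toy "protein" with 5 atoms

# Random rotation from a QR decomposition; flip one axis if needed so det = +1.
rotation, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(rotation) < 0:
    rotation[:, 0] *= -1
shift = rng.normal(size=3)

moved = coords @ rotation.T + shift  # same protein, different pose

def pairwise_distances(x):
    """All-against-all Euclidean distances, shape (N, N)."""
    return np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)

coords_unchanged = np.allclose(coords, moved)
distances_unchanged = np.allclose(pairwise_distances(coords), pairwise_distances(moved))
print(coords_unchanged, distances_unchanged)
```

<p>The first flag is <code class="language-plaintext highlighter-rouge">False</code> and the second is <code class="language-plaintext highlighter-rouge">True</code>: a model fed raw coordinates has to learn the symmetry, a model fed distances gets it for free.</p>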

<h3 id="local-reference-based-methods">Local reference-based methods</h3>

<p>On one hand, some models leverage <em>reference-based</em> methods, largely following the example of the original AlphaFold2 model. Here, a local reference frame for each residue is defined based on the backbone geometry, with the translational component being equal to the CA position and the rotational component originating from a Gram-Schmidt orthogonalisation of the CA-C and the CA-N bond vectors.</p>
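<p>A minimal sketch of this frame construction, assuming single residues given as NumPy vectors. The function name <code class="language-plaintext highlighter-rouge">backbone_frame</code> and the toy coordinates are mine; the choice of CA-C as the first axis follows my reading of the AF2 supplementary information.</p>

```python
import numpy as np

def backbone_frame(n, ca, c):
    """Return (rotation, translation) of the local frame of one residue."""
    v1 = c - ca                              # first bond vector: CA -> C
    v2 = n - ca                              # second bond vector: CA -> N
    e1 = v1 / np.linalg.norm(v1)             # first axis along CA -> C
    u2 = v2 - np.dot(v2, e1) * e1            # remove the e1 component (Gram-Schmidt)
    e2 = u2 / np.linalg.norm(u2)             # second axis, in the N-CA-C plane
    e3 = np.cross(e1, e2)                    # third axis completes a right-handed system
    rotation = np.stack([e1, e2, e3], axis=1)  # columns are the basis vectors
    return rotation, ca

# Toy residue lying in the xy-plane (Angstrom).
n = np.array([-0.53, 1.36, 0.0])
ca = np.array([0.0, 0.0, 0.0])
c = np.array([1.53, 0.0, 0.0])

rotation, translation = backbone_frame(n, ca, c)
print(rotation)
print(translation)
```

<p>Because the third axis comes from a cross product, the frame is always right-handed. This handedness is what later lets frame-based quantities distinguish a structure from its mirror image.</p>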

<p>Here is a paragraph from the <a href="https://arxiv.org/abs/2312.07511">Hitchhiker’s Guide to Geometric GNNs for 3D Atomic Systems</a> that summarizes the current state in this field of research:</p>

<p class="note">Canonical frame-based invariant GNNs. Canonical frame-based GNNs [Liu et al., 2022, Wang
et al., 2022a] use a local or global frame of reference to scalarise geometric quantities into invariant
features which are used for message passing, offering an alternative technique when canonical
reference frames can be defined. Most notably, the Invariant Point Attention layer (IPA) from
AlphaFold2 [Jumper et al., 2021] defines canonical local reference frames at each residue in the
protein backbone centred at the alpha Carbon atom and using the Nitrogen and adjacent Carbon atoms.
Other invariant GNNs for protein structure modelling also process similar local reference frames
[Ingraham et al., 2019, Wang et al., 2023b]. IPA is an invariant message passing layer operating
on an all-to-all graph of protein residues. In each IPA layer, each node creates a geometric feature
(position) in its local reference frame via a learnable linear transformation of its invariant features.
To aggregate features from neighbours, neighbouring nodes’ positions are first rotated into a global
reference frame where they can be composed with their invariant features (via an invariant attention
mechanism), followed by rotating the aggregated features back into local reference frames at each
node and projecting back to update the invariant features.</p>

<p>These canonical local reference frames <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>=</mo><mo stretchy="false">(</mo><mi>r</mi><mo separator="true">,</mo><mi>x</mi><mo stretchy="false">)</mo><mo>∈</mo><mtext>SE(3)</mtext></mrow><annotation encoding="application/x-tex">T = (r, x) \in \text{SE(3)}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">SE(3)</span></span></span></span></span> can be used to deal with quantities in a SE(3)-invariant way. Importantly, the orientational nature of the frame allows us to be SE(3)-invariant but not E(3)-invariant, i.e. reflections are still accounted for. This is important for biological applications since <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5765859/">chirality</a> plays a huge role in biomolecular interactions.</p>

<h4 id="why-se3-instead-of-e3-equivariance-can-be-important">Why SE(3) instead of E(3) equivariance can be important</h4>

<p>As an example of why this is important, we can look at the task of protein structure prediction that <a href="https://www.nature.com/articles/s41586-021-03819-2">AlphaFold2</a> tackled.</p>

<p class="note">To learn more about AlphaFold2 and the problem of protein structure prediction, you can either check out the <a href="https://www.youtube.com/watch?v=yqeUH4RsJp8">3-part lecture series about AF2 by Nazim Bouatta</a> or <a href="https://structural-bioinformatics.netlify.app/blog/proteins/2023-08-03-lesson6/">this lecture</a> I gave on the topic.</p>

<p>Here, one important metric for measuring prediction accuracy is the GDT score. A natural way to optimise for it is to take your predicted coordinates, compare them to the ground-truth coordinates and compute something like an RMSD loss. However, a plain RMSD on raw coordinates is not invariant to rototranslations: the same structure in a different pose yields a different loss. We can remedy that by calculating a <a href="https://web.stanford.edu/class/cs273/slides/conformational-space.ppt">dRMSD loss</a>, i.e. an RMSD loss on all pairwise distances in the structure. By using these internal coordinates, we are invariant to rototranslations.</p>

<p>However, we are also invariant to reflections! When training AlphaFold2, the team at DeepMind tested what would happen if they used this dRMSD loss for training a model.</p>
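<p>The reflection problem is easy to reproduce. The sketch below uses random toy coordinates and a hand-written <code class="language-plaintext highlighter-rouge">drmsd</code> helper (not the AF2 implementation) to compare a structure with its mirror image.</p>

```python
import numpy as np

def drmsd(x, y):
    """RMSD between the pairwise distance matrices of two structures."""
    dist_x = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    dist_y = np.linalg.norm(y[:, None, :] - y[None, :, :], axis=-1)
    upper = np.triu_indices(len(x), k=1)  # count every pair once
    return np.sqrt(np.mean((dist_x[upper] - dist_y[upper]) ** 2))

rng = np.random.default_rng(0)
truth = rng.normal(size=(8, 3))                # toy ground-truth coordinates
mirror = truth * np.array([1.0, 1.0, -1.0])    # reflect through the xy-plane

loss_against_mirror = drmsd(truth, mirror)
print(loss_against_mirror)
```

<p>The loss against the mirror image vanishes even though no rotation can superimpose the two point clouds, so a model trained with this loss has no incentive to prefer the correct handedness.</p>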

<p><img src="/assets/img/blog/prot_representation/gdt_fape_comparison.png" alt="gdt_fape_comparison" /></p>

<p class="figcaption">Results if AF2 is trained with dRMSD instead of FAPE as loss. Source: <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf">AF2 SI</a>, page 36, section 1.9.3.</p>

<p>You can see that while the local structure of predicted proteins (as measured by the <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3799472/">lDDT-CA score</a>) seems very good, the global structure (as measured by the <a href="https://en.wikipedia.org/wiki/Global_distance_test">GDT score</a>) seems to follow a bimodal distribution, with half the predictions performing well and the other half faring badly. Could this be due to the reflection invariance of the dRMSD loss? When calculating the GDT score with respect to the mirror image structure, the team observed a reversal of the distribution! Finally, when looking at the maximum of these two scores (one calculated with respect to the ground truth structure and one with respect to its mirror image), the model shows strong performance, indicating that the issue was indeed the reflection-invariant dRMSD loss.</p>

<p>Here frames come to our rescue and allow the definition of the so-called FAPE loss (frame-aligned point error, minimal implementation <a href="https://github.com/wangleiofficial/FAPEloss">here</a>). With their help, we can compute distance-like quantities, but in a reflection-aware way. How do we do that? We can take a predicted position <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mi>j</mi></msub></mrow><annotation encoding="application/x-tex">x_j</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span></span> and compute its position relative to the predicted frame of a different residue <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>T</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">T_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span 
style="top:-2.55em;margin-left:-0.1389em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>. With this, we effectively get a displacement vector which is however reflection-aware due to the rotational component of the frame transformation.</p>

<p>We can do the same thing for the ground-truth positions and frames of the same pair of residues and score the difference between the two displacement vectors as an RMSD-like quantity. This is what the FAPE loss amounts to.</p>

<p><img src="/assets/img/blog/prot_representation/fape_columbia.png" alt="fape_columbia" /></p>

<p class="figcaption">FAPE loss visualisation for a single pair of residues. Source: <a href="https://www.youtube.com/watch?v=ri39B0Voujc">YouTube talk at HMS</a>.</p>

<p>An equivalent way of visualising this involves not looking at a single pair of residues, but considering it in the context of the whole structure. Here, we align the predicted and the target structure based on frame <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>T</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">T_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.1389em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and then calculate the L2 norm of all the other residues with respect to this specific alignment. We can then repeat this for all residues in the sequence to calculate the overall FAPE loss.</p>

<p><img src="/assets/img/blog/prot_representation/fape_epfl.png" alt="fape_epfl" /></p>

<p class="figcaption">FAPE loss in the context of the whole structure. Source: <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10204179/">AF2Seq paper</a>.</p>

<p>Note that there are different versions of the FAPE loss used in different parts of the model; while the final FAPE loss computes these L2 norms for all atoms, the intermediate FAPE loss only considers the CA positions.</p>
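<p>Below is a minimal NumPy sketch of a CA-only FAPE, in the spirit of the intermediate loss. It is not the AF2 or OpenFold implementation: the function names are mine, the backbone is random toy data, and the clamp of 10 Å, the scale of 10 Å and the small epsilon follow my reading of the AF2 supplementary information.</p>

```python
import numpy as np

def backbone_frames(n, ca, c):
    """Per-residue frames: Gram-Schmidt on CA->C and CA->N, origin at CA."""
    e1 = (c - ca) / np.linalg.norm(c - ca, axis=-1, keepdims=True)
    v2 = n - ca
    u2 = v2 - np.sum(v2 * e1, axis=-1, keepdims=True) * e1
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    e3 = np.cross(e1, e2)
    return np.stack([e1, e2, e3], axis=-1), ca  # rotations (L, 3, 3), origins (L, 3)

def to_local(frames, x):
    """Express every point x_j in every frame T_i: R_i^T (x_j - t_i)."""
    rot, origin = frames
    diff = x[None, :, :] - origin[:, None, :]   # shape (frames, points, 3)
    return np.einsum("iab,ija->ijb", rot, diff)

def fape(frames_pred, x_pred, frames_true, x_true, clamp=10.0, scale=10.0, eps=1e-4):
    delta = to_local(frames_pred, x_pred) - to_local(frames_true, x_true)
    dist = np.sqrt(np.sum(delta ** 2, axis=-1) + eps)
    return np.mean(np.minimum(dist, clamp)) / scale

rng = np.random.default_rng(0)
n, ca, c = (rng.normal(scale=3.0, size=(6, 3)) for _ in range(3))  # toy backbone
mirror = np.array([1.0, 1.0, -1.0])                                 # reflect z

true_frames = backbone_frames(n, ca, c)
mirrored_frames = backbone_frames(n * mirror, ca * mirror, c * mirror)

fape_identical = fape(true_frames, ca, true_frames, ca)
fape_mirrored = fape(mirrored_frames, ca * mirror, true_frames, ca)
print(fape_identical, fape_mirrored)
```

<p>For identical structures only the epsilon term survives, whereas the mirror image is clearly penalised: the right-handed frames flip the sign of the third local coordinate, which a pure distance-based loss cannot see.</p>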

<p>This type of frame definition is by no means the only way you can construct frames; <a href="https://www.nature.com/articles/s41587-022-01432-w">RGN2</a>, another model for protein structure prediction, instead uses Frenet–Serret frames to model the protein backbone.</p>

<h4 id="ambivalent-mappings-from-frames-to-coordinates">Ambivalent mappings from frames to coordinates</h4>

<p>At the end of AlphaFold2, the algorithm has to map the frame-based representation back into 3D coordinates. This should not be a problem since our backbone frames allow us to reconstruct the backbone positions, and we predict the torsion angles of the rigid groups in the side chains so that we can place all atoms correctly according to the following table.</p>

<p><img src="/assets/img/blog/prot_representation/af_rigids_table.png" alt="af_rigids_table" /></p>

<p class="figcaption">Rigid groups for constructing all atoms from given torsion angles. Source: <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf">AF2 SI, Table 2</a></p>

<p>However, you can see a few boxed atoms in the table. These atoms are symmetric under 180-degree rotations, such as five of the six atoms in the phenyl ring of phenylalanine (PHE) and tyrosine (TYR) or the terminal carboxyl oxygens in aspartate (ASP) and glutamate (GLU).</p>

<p>Some of these atoms are on the rotation axis such as the terminal carbon atom in the phenyl rings and are therefore invariant to the rotation; some of the other atoms however swap positions due to the 180 degree rotation symmetry and their atom names are therefore ambiguous.</p>

<p>AlphaFold deals with this by renaming the atoms in a globally consistent way via lDDT loss computations (see <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf">algorithm 26 in the SI</a> and <a href="https://github.com/aqlaboratory/openfold/blob/127f1e7023c380c01330cee45544c23c079babe9/openfold/np/residue_constants.py#L1341">this part</a> of the OpenFold codebase).</p>

<p><img src="/assets/img/blog/prot_representation/af_renaming_table.png" alt="af_renaming_table" /></p>

<p class="figcaption">Renaming convention for ambivalent atom placements. Source: <a href="https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf">AF2 SI, Table 3</a></p>

<p>Another problem that comes from this ambiguity is that the network can in theory predict two valid values for the torsion angle of these rigid groups, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>χ</mi></mrow><annotation encoding="application/x-tex">\chi</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">χ</span></span></span></span> as well as <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>χ</mi><mo>+</mo><mi>π</mi></mrow><annotation encoding="application/x-tex">\chi + \pi</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">χ</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">π</span></span></span></span>. AlphaFold therefore accepts either prediction by scoring the predicted angle against both the ground-truth angle and its possible alternative (in the case of non-symmetric configurations, both are set to the ground-truth value). In this way, the network is allowed to learn both valid values.</p>
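<p>A minimal sketch of such a symmetry-aware torsion loss, assuming angles are compared as (sin, cos) pairs and the smaller of the two errors is kept. The function names and the example angles are hypothetical; this illustrates the idea rather than reproducing the AF2 loss term by term.</p>

```python
import numpy as np

def angle_to_vec(chi):
    """Represent an angle as a point on the unit circle."""
    return np.stack([np.sin(chi), np.cos(chi)], axis=-1)

def symmetric_torsion_loss(chi_pred, chi_true, is_symmetric):
    # The alternative target is chi + pi for symmetric groups, else the target itself.
    chi_alt = np.where(is_symmetric, chi_true + np.pi, chi_true)
    err_true = np.sum((angle_to_vec(chi_pred) - angle_to_vec(chi_true)) ** 2, axis=-1)
    err_alt = np.sum((angle_to_vec(chi_pred) - angle_to_vec(chi_alt)) ** 2, axis=-1)
    return np.mean(np.minimum(err_true, err_alt))  # keep the better of the two

chi_true = np.array([0.3, 1.2])
is_symmetric = np.array([True, False])  # e.g. a ring flip vs. an ordinary torsion

loss_exact = symmetric_torsion_loss(chi_true, chi_true, is_symmetric)
loss_flipped = symmetric_torsion_loss(chi_true + np.pi, chi_true, is_symmetric)
print(loss_exact, loss_flipped)
```

<p>Flipping both predictions by π costs nothing for the symmetric group but is fully penalised for the non-symmetric one, so only the second entry contributes to <code class="language-plaintext highlighter-rouge">loss_flipped</code>.</p>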

<h3 id="reference-free-methods-invariant-and-equivariant-update-functions">Reference-free methods: Invariant and Equivariant Update Functions</h3>

<p>We do not necessarily need to represent our structures as frames where we define a local reference coordinate system, but can also directly operate on our coordinates as long as we update our representation at every layer in a way that properly leverages these symmetries (e.g. by SE(3) invariance or equivariance).</p>

<p>Examples that leverage this approach include <a href="https://arxiv.org/abs/2009.01411">GVP-GNN</a> which defines equivariant update functions as well as <a href="https://arxiv.org/abs/1706.08566">SchNet</a> and <a href="https://arxiv.org/abs/2003.03123">DimeNet</a> that leverage invariant update functions (message passing functions in GNN-speak).</p>

<p class="note">To learn more about how these different approaches can be classified, I recommend both <a href="https://proceedings.mlr.press/v202/joshi23a.html">this paper</a> as well as the <a href="https://arxiv.org/abs/2312.07511">Hitchhiker’s guide to geometric GNNs</a>.</p>

<p>Leaving the GNN camp for a bit, <a href="https://arxiv.org/abs/2310.02508">Ophiuchus</a> showed that one can use hierarchical autoencoders to operate over protein structures which are represented by CA atoms and geometric features attached to them that describe the other atomic positions. They employ SE(3)-equivariant convolutions to operate on this representation and demonstrate its usage for compression and structure generation.</p>

<h3 id="screw-these-symmetries-data-augmentation-and-other-strategies">Screw these symmetries: data augmentation and other strategies</h3>

<p>Frame-based representations have been successful in AlphaFold and have since been used in many other models, both supervised and generative, for example <a href="https://www.nature.com/articles/s41586-023-06415-8">RFDiffusion</a> and <a href="https://www.nature.com/articles/s41586-023-06728-8">Chroma</a>. However, defining things like diffusion processes over these frames becomes <a href="https://arxiv.org/abs/2302.02277">quite a bit harder</a>, and if you additionally deal with sidechains and other details, frames might be too cumbersome for your use case.</p>

<p>Other models therefore do not use frames, but some kind of internal coordinates that can be used without explicitly considering these symmetry constraints. Some examples of this include <a href="https://www.cell.com/cell-systems/pdf/S2405-4712(19)30076-6.pdf">RGN</a> and <a href="https://www.nature.com/articles/s41467-024-45051-2">FoldingDiff</a> that leverage torsion angles or <a href="https://www.nature.com/articles/s43588-023-00440-3">ProteinSGM</a> that leverages a mixture of torsion angles and backbone distances.</p>

<p>Another strategy that does not involve dealing with symmetries is - well, not dealing with symmetries. <a href="https://www.biorxiv.org/content/10.1101/2023.05.24.542194v1.full">Protpardelle</a> is a protein diffusion model that operates on pure coordinate representations via a vision transformer and does some rotational and translational data augmentation to account for these symmetries. Finally, in the small molecule world, the <a href="https://arxiv.org/abs/2311.17932">Molecular Conformer Fields paper</a> showed that empirically, not enforcing these symmetry constraints explicitly can still lead to SOTA performance, sparking <a href="https://twitter.com/itsbautistam/status/1734929304440479791">quite a discussion on Twitter</a>.</p>
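
<p>To make the augmentation idea concrete, here is a small sketch that applies one random rotation and translation to a toy structure. It samples the rotation from a normalised random quaternion; the function names and the <code class="language-plaintext highlighter-rouge">max_shift</code> value are assumptions for illustration and are not taken from Protpardelle or any other model mentioned here.</p>

```python
import math
import random


def random_rotation(rng):
    """Sample a rotation matrix from a normalised random quaternion."""
    w, x, y, z = (rng.gauss(0.0, 1.0) for _ in range(4))
    norm = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / norm, x / norm, y / norm, z / norm
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def augment(coords, rng, max_shift=10.0):
    """Apply one random rotation and translation to all atoms of a structure."""
    rot = random_rotation(rng)
    shift = [rng.uniform(-max_shift, max_shift) for _ in range(3)]  # assumed range
    return [
        [sum(rot[i][k] * atom[k] for k in range(3)) + shift[i] for i in range(3)]
        for atom in coords
    ]


rng = random.Random(0)
coords = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.2, 0.3]]
augmented = augment(coords, rng)

# Interatomic distances are unchanged; only the global pose differs.
original_dist = math.dist(coords[0], coords[2])
augmented_dist = math.dist(augmented[0], augmented[2])
```

<p>Because the model sees the same structure in many poses during training, it can learn to treat them as equivalent without the architecture enforcing it.</p>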

<h2 id="batching-padded-versus-sparse">Batching: Padded versus sparse</h2>

<p>We’ve now covered the whole pipeline, from database formats through input formats to network-internal representations that properly handle symmetries. A final consideration comes into play when we think about <a href="https://machinelearningmastery.com/difference-between-a-batch-and-an-epoch/">batching</a>, a commonly used technique in machine learning where you do not pass your samples one by one into the network, but combine them into a bigger tensor to achieve better hardware utilisation and therefore training performance.</p>

<p class="note">There are many subtleties to choosing your batch size since generally we perform a gradient update step after each of these batches; therefore, the batch size not only influences training performance but also accuracy by changing the dynamics of our gradient descent procedure. I won’t go into detail here on that, but recommend <a href="https://karpathy.github.io/2019/04/25/recipe/">Andrej Karpathy’s blog</a> on general recipes for training neural networks.</p>

<h3 id="the-batching-pain-with-variable-length-input">The batching pain with variable-length input</h3>

<p>This batching of tensors is trivial in many computer vision use cases since often all your images are of the same size; you can therefore just stack them along a new dimension and your batch is ready.</p>

<p>For protein structures, it is a bit more complicated due to variable length. One strategy to deal with this involves <a href="https://huggingface.co/docs/transformers/main/en/pad_truncation">padding and truncation</a>. Here, we choose some maximum length for our batch, pad structures that are shorter than this with padding tokens (for coordinates this can be 0 or a small value that is unlikely to occur exactly in the data) and truncate structures that are longer (either randomly or at biologically defined domain boundaries). This solves our issue but introduces new ones: often, we do not want to truncate data since we may lose important information. If we instead always choose the longest structure in a batch as the maximum length, we may end up with very inefficient training if there are very short sequences in the batch and padding tokens begin to make up a significant part of it.</p>
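
<p>As a rough sketch of padding and truncation on toy data: the helper names, the padding value of 0 and the validity mask are assumptions for illustration (a mask is a common companion to padding so that later layers can ignore padded positions), and truncation here simply cuts from the end rather than randomly or at domain boundaries.</p>

```python
PAD_VALUE = 0.0  # assumed padding value for coordinates


def pad_or_truncate(structure, max_len):
    """Return a fixed-length copy of `structure` plus a 1/0 validity mask."""
    kept = structure[:max_len]  # truncation: drop residues beyond max_len
    n_pad = max_len - len(kept)
    padded = kept + [[PAD_VALUE] * 3 for _ in range(n_pad)]
    mask = [1] * len(kept) + [0] * n_pad
    return padded, mask


def collate(structures, max_len):
    """Stack variable-length structures into a dense batch with masks."""
    pairs = [pad_or_truncate(s, max_len) for s in structures]
    batch = [padded for padded, _ in pairs]
    masks = [mask for _, mask in pairs]
    return batch, masks


# Three toy "proteins" with 2, 4 and 6 residues (one 3D point each).
structures = [[[float(i)] * 3 for i in range(n)] for n in (2, 4, 6)]
batch, masks = collate(structures, max_len=4)

# Share of the dense batch that is padding rather than real residues.
padding_fraction = 1 - sum(map(sum, masks)) / (len(batch) * 4)
```

<p>In this toy batch the shortest structure is padded with two residues and the longest one loses its last two, which shows both costs of the approach at once.</p>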

<h3 id="efficient-padding-via-length-batching">Efficient padding via length batching</h3>

<p>To circumvent this, people took inspiration from NLP. In the transformer paper, for example, it is stated that to circumvent the inefficient padding issue, <a href="https://arxiv.org/abs/1706.03762"><em>sentence pairs were batched together by approximate sequence length</em></a>, resulting in more efficient padding. This has been replicated, for example, in <a href="https://github.com/microsoft/protein-frame-flow/blob/1c5ad9c28a1264e449d98c382123bb48227d9d97/data/pdb_dataloader.py#L162">generative models for protein structure</a>. This change might influence training dynamics since the model now sees similarly sized inputs inside every batch, but empirically it still seems to work fine.</p>
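
<p>The effect of length batching can be sketched with a few toy sequence lengths. The functions <code class="language-plaintext highlighter-rouge">length_batches</code> and <code class="language-plaintext highlighter-rouge">padded_size</code> are hypothetical helpers, not code from the linked repository; real samplers typically also shuffle the resulting batches.</p>

```python
def length_batches(lengths, batch_size):
    """Group sample indices so each batch holds similarly sized samples."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def padded_size(lengths, batches):
    """Total residue slots when each batch is padded to its longest member."""
    return sum(len(b) * max(lengths[i] for i in b) for b in batches)


lengths = [50, 400, 60, 380, 55, 390]
naive = [[0, 1, 2], [3, 4, 5]]                 # batches in dataset order
bucketed = length_batches(lengths, batch_size=3)

naive_slots = padded_size(lengths, naive)        # 3 * 400 + 3 * 390 = 2370
bucketed_slots = padded_size(lengths, bucketed)  # 3 * 60 + 3 * 400 = 1380
```

<p>The six sequences contain 1335 residues in total, so the length-sorted batches waste far fewer slots on padding than batches taken in dataset order.</p>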

<h3 id="sparse-batching">Sparse batching</h3>

<p>In the previous section we talked about the usage of GNNs (graph neural networks) for protein structures. A popular library in the field of GNNs is <a href="https://pytorch-geometric.readthedocs.io/en/latest/index.html#">PyG</a> (PyTorch Geometric) that can be used for all kinds of graph-structure data.</p>

<p>In contrast to the padding-and-truncation approach I mentioned before, they opt for a sparse batching procedure they term <a href="https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html">advanced mini-batching</a>.</p>

<p>Here, we treat the graph data points in a batch as <em>one single datapoint</em> and use pointers to tell us about the boundaries between them. In practice, we concatenate all our node features along an existing dimension instead of stacking them along a new dimension, making padding and truncation obsolete.</p>

<p><img src="/assets/img/blog/prot_representation/pyg_batching.png" alt="PyG batching" /></p>

<p class="figcaption">Advanced mini-batching in PyTorch Geometric. Source: <a href="https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html">PyG Docs</a></p>

<p>We do something similar for the adjacency matrix which indicates the connectivity in the graphs. Stacking these in a block-diagonal fashion allows us to reuse existing algorithms for GNNs such as <a href="https://danielegrattarola.github.io/posts/2021-03-12/gnn-lecture-part-2.html">message-passing</a> without having to change implementations. In addition, since the majority of elements in this matrix will be zero, we can use <a href="https://glaringlee.github.io/sparse.html">sparse representations</a> that allow us to deal with this in a memory-efficient way.</p>
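
<p>The concatenation of node features and the block-diagonal stacking of connectivity can be sketched in plain Python. This is a simplified illustration of the idea, not PyG's implementation; <code class="language-plaintext highlighter-rouge">sparse_collate</code> and the dictionary layout are assumptions, and edges are stored as index pairs rather than as a matrix.</p>

```python
def sparse_collate(graphs):
    """Merge graphs into one big disconnected graph, PyG-style.

    Each graph is a dict with `x` (one feature row per node) and `edges`
    (pairs of node indices local to that graph).
    """
    x, edges, batch, ptr = [], [], [], [0]
    for graph_id, graph in enumerate(graphs):
        offset = ptr[-1]  # index of this graph's first node in the merged graph
        x.extend(graph["x"])
        # Shifting indices by the offset is what makes the adjacency block-diagonal.
        edges.extend((src + offset, dst + offset) for src, dst in graph["edges"])
        batch.extend([graph_id] * len(graph["x"]))
        ptr.append(offset + len(graph["x"]))
    return {"x": x, "edges": edges, "batch": batch, "ptr": ptr}


graphs = [
    {"x": [[0.1], [0.2], [0.3]], "edges": [(0, 1), (1, 2)]},
    {"x": [[0.4], [0.5]], "edges": [(0, 1)]},
]
merged = sparse_collate(graphs)
# merged["batch"] == [0, 0, 0, 1, 1] and merged["ptr"] == [0, 3, 5]
# merged["edges"] == [(0, 1), (1, 2), (3, 4)]
```

<p>Since no edge ever connects nodes from different graphs, message passing on the merged graph gives the same result as running it on each graph separately.</p>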

<p>If you inspect protein structures represented in this PyG format (such as in the <a href="https://proteins.sh/">ProteinWorkshop project</a> we recently published), you can see that a graph will look like this:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>DataBatch(
  coords=[7241, 37, 3],
  residues=[32], 
  residue_id=[7241], 
  chains=[7241], 
  seq_pos=[7241, 1], 
  batch=[7241], 
  ptr=[33])
</code></pre></div></div>

<p>In contrast, this same batch in the “dense” format that uses padding would look like this:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>DataBatch(
  coords=[32, 385, 37, 3],
  residues=[32],
  residue_id=[32, 385],
  chains=[32, 385],
  seq_pos=[32, 385, 1])
</code></pre></div></div>

<p>We can notice several differences:</p>
<ul>
  <li>the dense format represents the batch as an explicit tensor dimension (first dimension of size 32) in all attributes. This dimension is not apparent in the PyG batch except for the attributes that are graph-level attributes and therefore do not change with the size of the graph (<code class="language-plaintext highlighter-rouge">residues</code> is an example here, for each graph it is a single list).</li>
  <li>we can see in the dense batch that the longest protein structure in this batch is 385 residues (apparent, for example, in the <code class="language-plaintext highlighter-rouge">residue_id</code> attribute, a numerical encoding of the amino acid type). In the PyG batch, we can see that all amino acids in the batch sum to 7241. If you compare 7241 to 32*385 = 12320, you can see that about 41% of the padded tensor is padding, which corresponds to roughly 70% memory overhead compared to the efficiently batched representation.</li>
  <li>the PyG batch stores the batching information not in a separate dimension, but in separate attributes: <code class="language-plaintext highlighter-rouge">batch</code> indicates for each node in the batch to which graph in the batch it belongs, and <code class="language-plaintext highlighter-rouge">ptr</code> contains pointers to the boundaries between all the graphs in the batch to enable efficient indexing and information retrieval.</li>
</ul>
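
<p>To illustrate how <code class="language-plaintext highlighter-rouge">ptr</code> lets us move between the two layouts, here is a small sketch on toy data, followed by the padding arithmetic for the batch shown above. The helpers <code class="language-plaintext highlighter-rouge">unbatch</code> and <code class="language-plaintext highlighter-rouge">to_dense</code> are hypothetical names written for this example, not PyG functions.</p>

```python
def unbatch(values, ptr):
    """Split a concatenated per-node list back into one list per graph."""
    return [values[start:end] for start, end in zip(ptr[:-1], ptr[1:])]


def to_dense(values, ptr, pad_value=0):
    """Pad every graph to the longest one, giving a [batch, max_len] layout."""
    per_graph = unbatch(values, ptr)
    max_len = max(len(g) for g in per_graph)
    return [g + [pad_value] * (max_len - len(g)) for g in per_graph]


residue_id = [7, 3, 9, 1, 4]   # toy sparse attribute for two proteins
ptr = [0, 3, 5]                # graph boundaries, as in the PyG batch
dense = to_dense(residue_id, ptr)   # [[7, 3, 9], [1, 4, 0]]

# Padding cost for the real batch above: 32 proteins, longest 385, 7241 residues.
padded_slots = 32 * 385
padding_fraction = 1 - 7241 / padded_slots   # share of the dense tensor that is padding
overhead = padded_slots / 7241 - 1           # extra memory relative to the sparse batch
```

<p>The two ratios answer slightly different questions: the first describes how much of the dense tensor is wasted, the second how much larger the dense tensor is than the sparse one.</p>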

<p>Interconversion from dense to PyG format and back is easy to do if all of the graphs are the same size: we can use the <a href="https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/dense_data_loader.html">PyG DenseDataLoader</a> for that.</p>

<p>In the padded case, there is no such functionality yet, but there might soon be a <a href="https://github.com/pyg-team/pytorch_geometric/pull/8518">DensePaddingDataLoader</a> that does exactly that.</p>

<h2 id="afdb-esmatlas--co-how-to-deal-with-large-databases">AFDB, ESMAtlas &amp; co: how to deal with large databases</h2>

<p>The PDB as a database of experimental protein structures keeps growing, currently standing at <a href="https://www.rcsb.org/">nearly 218k</a> entries. However, it seems small compared to the <a href="https://academic.oup.com/nar/article/50/D1/D439/6430488">AlphaFoldDB (&gt;200m)</a> and <a href="https://esmatlas.com/">ESMAtlas (772m structures)</a>, powered by the recent advances in protein structure prediction via methods like <a href="https://www.nature.com/articles/s41586-021-03819-2">AlphaFold2</a> and <a href="https://www.science.org/doi/10.1126/science.ade2574">ESMFold</a>.</p>

<p>This development changed the game in protein biology. While until recently the <a href="https://moalquraishi.wordpress.com/2019/04/01/the-future-of-protein-science-will-not-be-supervised/">gap between available protein sequences and structures widened further and further</a>, we suddenly have a wealth of structural information that was unimaginable a decade ago. This quote from Mohammed AlQuraishi (Columbia University) sums up this paradigm shift well:</p>

<blockquote class="lead">
  <p>Everything we did with protein sequences we can now do with protein structures</p>
</blockquote>

<p>While that is a theoretically true and very exciting prospect, there is one big problem: we do not have tools to deal with such amounts of structural data. Here is a visual comparison between the size of the PDB and the AFDB:</p>

<p><img src="/assets/img/blog/prot_representation/afdb_size.png" alt="afdb_size" /></p>

<p class="figcaption">Visual comparison of the size of the PDB vs the AFDB. Source: <a href="https://www.youtube.com/watch?v=IJtWTxhuunk">YouTube</a></p>

<p>You can see that we deal with a different order of magnitude in data here. This brings up a plethora of issues, starting from pure memory usage (the storage for AFDB is 23 TB) to questions of how we move these enormous amounts of data and also process them.</p>

<p>Many groups have developed tools in recent years to tackle this issue. The <a href="https://steineggerlab.com/en/">Steinegger lab</a> in particular has produced some fantastic tools in that space. If you want to read more about these tools, I have a <a href="">separate blogpost</a> describing three of them in detail: Foldcomp for structure compression, Foldseek for structure clustering and mmseqs for sequence clustering (also very important in that context for generating both input MSAs and training splits).</p>

<p><img src="/assets/img/blog/prot_representation/steinegger_tools.png" alt="steinegger_tools" /></p>

<p class="figcaption">Tools from the Steinegger Lab. Source: <a href="https://www.youtube.com/watch?v=IJtWTxhuunk">YouTube</a></p>

<h2 id="summary">Summary</h2>

<p>In this post we discussed four different levels of information representation:</p>
<ol>
  <li>We started with the data formats in which protein structures are stored and transmitted and the evolution they underwent in the last decades.</li>
  <li>After that we looked at how both sequence and structure information can be converted into a format that can be used by machine learning algorithms, specifically the <code class="language-plaintext highlighter-rouge">atom14</code> and <code class="language-plaintext highlighter-rouge">atom37</code> format.</li>
  <li>Once inside the network, we discussed how different methods leverage this information differently, either via reference-based or reference-free methods, looking at how we can deal with geometric information while respecting the symmetries inherent to it.</li>
  <li>Finally, we looked at how different frameworks deal with the variable length of protein structures and how this affects batching behaviour.</li>
</ol>

<p>I hope that this post can shine some light not only on which representations are used in which circumstances but also why. If you have feedback let me know!</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="proteins" /><summary type="html"><![CDATA[How algorithms such as AlphaFold turn PDB structures into a format that they can process]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/prot_representation/protein_bits.png" /><media:content medium="image" url="/assets/img/blog/prot_representation/protein_bits.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">How does Pytorch implement a linear layer?</title><link href="/blog/ml/2024-01-10-pytorch-whirlwind/" rel="alternate" type="text/html" title="How does Pytorch implement a linear layer?" /><published>2024-01-10T00:00:00+00:00</published><updated>2024-04-07T17:56:12+00:00</updated><id>/blog/ml/pytorch-whirlwind</id><content type="html" xml:base="/blog/ml/2024-01-10-pytorch-whirlwind/"><![CDATA[<p>PyTorch is <em>the</em> deep learning library. It is used by researchers and practitioners alike to build and train neural networks. It is also open source, which means that we can look at the source code to understand how it works. This is especially useful if we want to understand how a specific operation is implemented.</p>

<p>In my post about <a href="">GPU programming in PyTorch</a>, we saw that calling a linear layer in PyTorch via <code class="language-plaintext highlighter-rouge">torch.nn.Linear</code> results in a call to the <code class="language-plaintext highlighter-rouge">aten::addmm</code> function. The ATen library is part of the <a href="https://pytorch-dev-podcast.simplecast.com/episodes/c-frontend">PyTorch C++ API</a> and is responsible for the tensor operations in PyTorch. So if we want to understand how the linear layer is implemented in PyTorch, we need to dig into C++ code and understand how the <code class="language-plaintext highlighter-rouge">aten::addmm</code> function is implemented. This is a bit of a convoluted process, but I  hope that in the process you learn as much about the PyTorch codebase as I did when I went down this rabbit hole.</p>

<ul id="markdown-toc">
  <li><a href="#pytorch-docs-and-the-dispatcher" id="markdown-toc-pytorch-docs-and-the-dispatcher">PyTorch Docs and the Dispatcher</a></li>
  <li><a href="#native-functions-and-the-codegen-pipeline" id="markdown-toc-native-functions-and-the-codegen-pipeline">Native functions and the codegen pipeline</a></li>
  <li><a href="#navigating-the-atnative-namespace" id="markdown-toc-navigating-the-atnative-namespace">Navigating the <code class="language-plaintext highlighter-rouge">at::native</code> namespace</a></li>
  <li><a href="#structured-kernels" id="markdown-toc-structured-kernels">Structured Kernels</a></li>
  <li><a href="#where-are-the-actual-implementations" id="markdown-toc-where-are-the-actual-implementations">Where are the actual implementations?</a>    <ul>
      <li><a href="#shape-checking-torch_meta_funcaddmm" id="markdown-toc-shape-checking-torch_meta_funcaddmm">Shape checking: <code class="language-plaintext highlighter-rouge">TORCH_META_FUNC(addmm)</code></a></li>
      <li><a href="#cpu-implementation-torch_impl_funcaddmm_out_cpu" id="markdown-toc-cpu-implementation-torch_impl_funcaddmm_out_cpu">CPU implementation: <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cpu)</code></a></li>
      <li><a href="#cuda-implementation-torch_impl_funcaddmm_out_cuda" id="markdown-toc-cuda-implementation-torch_impl_funcaddmm_out_cuda">CUDA implementation: <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cuda)</code></a></li>
    </ul>
  </li>
  <li><a href="#conclusion" id="markdown-toc-conclusion">Conclusion</a></li>
  <li><a href="#credits" id="markdown-toc-credits">Credits</a></li>
</ul>

<h2 id="pytorch-docs-and-the-dispatcher">PyTorch Docs and the Dispatcher</h2>

<p>To get an idea of what these operations do, we can look at the <a href="https://pytorch.org/cppdocs/api/namespace_at.html#namespace-at">PyTorch at Namespace docs</a> and look for these functions. Via this we see that the <a href="https://pytorch.org/cppdocs/api/function_namespaceat_1a96bac9e697e177adb535c1330635be44.html#exhale-function-namespaceat-1a96bac9e697e177adb535c1330635be44">aten::addmm</a> function is defined in <code class="language-plaintext highlighter-rouge">build/aten/src/ATen/Functions.h</code>. Looking at the <a href="https://pytorch.org/cppdocs/api/program_listing_file_build_aten_src_ATen_Functions.h.html">program listing</a>, we can see that it calls <code class="language-plaintext highlighter-rouge">at::_ops::addmm_out::call(self, mat1, mat2, beta, alpha, out)</code>.</p>

<p>We can look at the respective <a href="https://pytorch.org/docs/stable/generated/torch.addmm.html"><code class="language-plaintext highlighter-rouge">Python API</code></a> to learn more about the different arguments of the <code class="language-plaintext highlighter-rouge">addmm</code> function. The <code class="language-plaintext highlighter-rouge">addmm</code> function is a matrix multiplication followed by a matrix addition of the following form:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mtext>out</mtext><mo>=</mo><mi>β</mi><mtext>input</mtext><mo>+</mo><mi>α</mi><mo stretchy="false">(</mo><mtext>mat1</mtext><mi mathvariant="normal">@</mi><mtext>mat2</mtext><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\text{out} = \beta \text{input} + \alpha (\text{mat1} @ \text{mat2})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord text"><span class="mord">out</span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="mord text"><span class="mord">input</span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="mopen">(</span><span class="mord text"><span class="mord">mat1</span></span><span class="mord">@</span><span class="mord text"><span class="mord">mat2</span></span><span class="mclose">)</span></span></span></span></span>

<p>The <code class="language-plaintext highlighter-rouge">mat1</code> and <code class="language-plaintext highlighter-rouge">mat2</code> arguments are the input matrices, <code class="language-plaintext highlighter-rouge">beta</code> is a scaling factor for the input matrix <code class="language-plaintext highlighter-rouge">input</code>, <code class="language-plaintext highlighter-rouge">alpha</code> is a scaling factor for the matrix multiplication and <code class="language-plaintext highlighter-rouge">out</code> is the output tensor.</p>
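
<p>As a sanity check of the formula, here is a plain-Python reference on nested lists. It only illustrates the arithmetic and says nothing about how PyTorch implements the operation; <code class="language-plaintext highlighter-rouge">addmm_reference</code> is a hypothetical name, and unlike the real function it does not broadcast <code class="language-plaintext highlighter-rouge">input</code>, so the bias is repeated per row by hand.</p>

```python
def matmul(a, b):
    """Multiply an (n x k) matrix with a (k x m) matrix given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def addmm_reference(inp, mat1, mat2, beta=1, alpha=1):
    """Compute beta * inp + alpha * (mat1 @ mat2) element by element."""
    product = matmul(mat1, mat2)
    return [
        [beta * i + alpha * p for i, p in zip(inp_row, prod_row)]
        for inp_row, prod_row in zip(inp, product)
    ]


# A linear layer y = x W^T + b maps onto this with inp = b (repeated per row),
# mat1 = x, mat2 = W^T and the default beta = alpha = 1.
x = [[1.0, 2.0], [3.0, 4.0]]               # batch of 2 inputs with 2 features
w_t = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]   # transposed weight, 2 x 3
bias = [[0.5, 0.5, 0.5]] * 2               # bias repeated for each row

y = addmm_reference(bias, x, w_t)                        # [[1.5, 2.5, 3.5], [3.5, 4.5, 7.5]]
scaled = addmm_reference(bias, x, w_t, beta=0, alpha=2)  # [[2.0, 4.0, 6.0], [6.0, 8.0, 14.0]]
```

<p>Setting <code class="language-plaintext highlighter-rouge">beta=0</code> drops the additive term entirely, which is why the second result is just twice the matrix product.</p>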

<p>Just looking through the <a href="https://github.com/pytorch/pytorch/tree/main">PyTorch GitHub repo</a> for the implementation of a function is unfortunately quite a pain. One of the main reasons is that, depending on your backend (CPU, NVIDIA GPU, Apple M-series chips, …), the <a href="http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/">PyTorch dispatcher</a> dynamically dispatches to the correct kernel for your setup.</p>

<h2 id="native-functions-and-the-codegen-pipeline">Native functions and the codegen pipeline</h2>

<p>Another complication is that many operations are not really fully implemented in the PyTorch codebase, but will get generated during the PyTorch build process via a <a href="https://github.com/pytorch/pytorch/wiki/Codegen-and-Structured-Kernels">code-generation pipeline</a> (more on this <a href="https://pytorch-dev-podcast.simplecast.com/episodes/code-generation">in this podcast episode</a>). This is sensible since while many operations in PyTorch are in principle quite simple (element-wise additions, activation functions, …), there is a lot of boilerplate that every operation has to implement (like bindings to python, autograd support, registering the kernel to the dispatcher, …). The codegen pipeline allows PyTorch to generate this boilerplate code automatically.</p>

<p>What we need to do therefore is to look at the <a href="https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native"><code class="language-plaintext highlighter-rouge">native_functions.yaml</code> file</a>, with “native” functions being the <a href="https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native">modern mechanism for adding operators and functions to ATen</a> (more details in <a href="https://pytorch-dev-podcast.simplecast.com/episodes/native-functions-yaml">this podcast episode</a>). This file describes metadata about each operator that gets consumed by the codegen (more details on the different fields in this yaml file <a href="https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md">here</a>).</p>

<p>If we search in the <code class="language-plaintext highlighter-rouge">native_functions.yaml</code> file for <code class="language-plaintext highlighter-rouge">addmm</code>, we find the following entry:</p>

<div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># file: "native_functions.yaml"</span>
<span class="pi">-</span> <span class="na">func</span><span class="pi">:</span> <span class="s">addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -&gt; Tensor</span>
  <span class="na">structured_delegate</span><span class="pi">:</span> <span class="s">addmm.out</span>
  <span class="na">variants</span><span class="pi">:</span> <span class="s">function, method</span>
  <span class="na">dispatch</span><span class="pi">:</span>
    <span class="na">SparseCPU</span><span class="pi">:</span> <span class="s">addmm_sparse_dense_cpu</span>
    <span class="na">SparseCUDA</span><span class="pi">:</span> <span class="s">addmm_sparse_dense_cuda</span>
    <span class="s">SparseCsrCPU, SparseCsrCUDA</span><span class="err">:</span> <span class="s">addmm_sparse_compressed_dense</span>
  <span class="na">tags</span><span class="pi">:</span> <span class="s">core</span>
</code></pre></div></div>

<p class="figcaption">Entry for the <a href="https://github.com/pytorch/pytorch/blob/34db6f1b13206d0b5cc3297e4a92dd0c4b8aea45/aten/src/ATen/native/native_functions.yaml#L6826"><code class="language-plaintext highlighter-rouge">addmm</code> function</a></p>

<p>We see the <code class="language-plaintext highlighter-rouge">structured_delegate</code> field, which tells us that the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function is delegated to the <code class="language-plaintext highlighter-rouge">addmm.out</code> function (more on this later). We can find the entry for this function in the same <code class="language-plaintext highlighter-rouge">native_functions.yaml</code> file:</p>

<div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># file: "native_functions.yaml"</span>
<span class="pi">-</span> <span class="na">func</span><span class="pi">:</span> <span class="s">addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -&gt; Tensor(a!)</span>
  <span class="na">structured</span><span class="pi">:</span> <span class="s">True</span>
  <span class="na">dispatch</span><span class="pi">:</span>
    <span class="na">CPU</span><span class="pi">:</span> <span class="s">addmm_out_cpu</span>
    <span class="na">CUDA</span><span class="pi">:</span> <span class="s">addmm_out_cuda</span>
    <span class="na">MPS</span><span class="pi">:</span> <span class="s">addmm_out_mps</span>
    <span class="na">SparseCPU</span><span class="pi">:</span> <span class="s">addmm_out_sparse_dense_cpu</span>
    <span class="na">SparseCUDA</span><span class="pi">:</span> <span class="s">addmm_out_sparse_dense_cuda</span>
    <span class="na">SparseCsrCPU</span><span class="pi">:</span> <span class="s">addmm_out_sparse_compressed_cpu</span>
    <span class="na">SparseCsrCUDA</span><span class="pi">:</span> <span class="s">addmm_out_sparse_compressed_cuda</span>
</code></pre></div></div>

<p class="figcaption">Entry for the <a href="https://github.com/pytorch/pytorch/blob/34db6f1b13206d0b5cc3297e4a92dd0c4b8aea45/aten/src/ATen/native/native_functions.yaml#L6815"><code class="language-plaintext highlighter-rouge">addmm.out</code> function</a></p>
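
<p>Conceptually, the <code class="language-plaintext highlighter-rouge">dispatch</code> field of this entry acts like a lookup table from backend to kernel name. The following toy model copies the table from the entry above; the key-building logic and the function names are simplifying assumptions and do not reflect how the real dispatcher computes dispatch keys.</p>

```python
# Dispatch table copied from the addmm.out entry above: backend key to kernel name.
ADDMM_OUT_DISPATCH = {
    "CPU": "addmm_out_cpu",
    "CUDA": "addmm_out_cuda",
    "MPS": "addmm_out_mps",
    "SparseCPU": "addmm_out_sparse_dense_cpu",
    "SparseCUDA": "addmm_out_sparse_dense_cuda",
    "SparseCsrCPU": "addmm_out_sparse_compressed_cpu",
    "SparseCsrCUDA": "addmm_out_sparse_compressed_cuda",
}


def dispatch_key(device, layout="strided"):
    """Build a simplified dispatch key from a tensor's device and layout."""
    prefix = {"strided": "", "sparse_coo": "Sparse", "sparse_csr": "SparseCsr"}[layout]
    return prefix + device.upper()


def lookup_kernel(device, layout="strided"):
    """Return the kernel name registered for this device/layout combination."""
    key = dispatch_key(device, layout)
    if key not in ADDMM_OUT_DISPATCH:
        raise NotImplementedError(f"no addmm.out kernel registered for {key}")
    return ADDMM_OUT_DISPATCH[key]


dense_gpu_kernel = lookup_kernel("cuda")                       # "addmm_out_cuda"
sparse_cpu_kernel = lookup_kernel("cpu", layout="sparse_csr")  # "addmm_out_sparse_compressed_cpu"
```

<p>A combination without an entry, such as a sparse tensor on MPS in this table, ends in an error here, which mirrors the fact that a kernel has to be registered for every backend that should be supported.</p>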

<p>Ignoring the <code class="language-plaintext highlighter-rouge">structured</code> field for now, we see multiple things:</p>

<ol>
  <li>We have multiple entries for the <code class="language-plaintext highlighter-rouge">addmm</code> function, <code class="language-plaintext highlighter-rouge">addmm</code> and <code class="language-plaintext highlighter-rouge">addmm_out</code>. There are in fact three different versions of most PyTorch operators (however, we only see the <code class="language-plaintext highlighter-rouge">addmm</code> and <code class="language-plaintext highlighter-rouge">addmm_out</code> functions in the codebase since the in-place version is generated automatically):
    <ul>
      <li><code class="language-plaintext highlighter-rouge">addmm</code>: the functional version that performs the operation without modifying the original tensor and returns a new tensor, for example <code class="language-plaintext highlighter-rouge">output = torch.add(input, other)</code></li>
      <li><code class="language-plaintext highlighter-rouge">addmm_</code>: the in-place version that modifies the original tensor, for example <code class="language-plaintext highlighter-rouge">input.add_(other)</code></li>
      <li><code class="language-plaintext highlighter-rouge">addmm_out</code>: the <code class="language-plaintext highlighter-rouge">out=</code> version that takes an additional tensor as an argument and writes the result to this tensor, for example <code class="language-plaintext highlighter-rouge">torch.add(input, other, out=output)</code></li>
    </ul>
  </li>
  <li>
    <p>We see that for each backend (CPU, CUDA, MPS, …) there is a separate implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function. This is because the implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function can be highly dependent on the specific hardware and memory layout of the input tensors. For example, the <code class="language-plaintext highlighter-rouge">addmm</code> function for sparse tensors is implemented differently than the <code class="language-plaintext highlighter-rouge">addmm</code> function for dense tensors.</p>
  </li>
  <li>
    <p>In summary, this means that we need to write (#variants * #backends) kernel implementations for each operator. This is a lot of boilerplate code that the codegen pipeline can generate for us.</p>
  </li>
  <li>The <code class="language-plaintext highlighter-rouge">variants</code> field tells us that the <code class="language-plaintext highlighter-rouge">addmm</code> function can be called as a namespace function (<code class="language-plaintext highlighter-rouge">at::addmm()</code>) or as a Tensor method (<code class="language-plaintext highlighter-rouge">t.addmm()</code>). This is because PyTorch supports both functional and method-based APIs. To qualify as a Tensor method, there must be a <code class="language-plaintext highlighter-rouge">Tensor self</code> argument in the function signature, since otherwise the function could not be called as a method on a tensor. In the method variant this argument will be removed from the function signature. A function variant is always generated by ATen, but when should you also generate a method variant? From the <a href="https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native">PyTorch native README</a>:</li>
</ol>

<blockquote class="lead">
  <p>Tensor operations as methods are appropriate for “core” Tensor operations (e.g., add, sub, etc.), but not for more complicated neural network layers (e.g., <code class="language-plaintext highlighter-rouge">conv2d</code>) and internal functions designed specifically for binding (e.g., <code class="language-plaintext highlighter-rouge">cudnn_convolution</code>).</p>
</blockquote>

<h2 id="navigating-the-atnative-namespace">Navigating the <code class="language-plaintext highlighter-rouge">at::native</code> namespace</h2>

<p>If we want to look for where a specific implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function is, we just need to look for the name of the function in the <code class="language-plaintext highlighter-rouge">at::native</code> namespace. This still does not bring us to the actual implementation of the function easily because there are <a href="https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-operators-come-from-more-than-you-wanted-to-know/373">more than 2000 PyTorch operators</a> which can be grouped into <a href="https://docs.google.com/spreadsheets/d/1Sp4HUjxwMifS5oDQg0yvjqk7hKOpCfKO4jWH4MTGP-k/edit#gid=0">various categories</a>. We can see in the post linked in the last sentence that <code class="language-plaintext highlighter-rouge">addmm</code> is counted as one of the 13 <em>composite matmul</em> operators. There are different ways to categorize the operators (for example by <a href="http://blog.ezyang.com/2020/05/a-brief-taxonomy-of-pytorch-operators-by-shape-behavior/">shape behavior</a>), but the point is that there are a lot of them.</p>

<p>To find our <code class="language-plaintext highlighter-rouge">addmm</code> needle in the <code class="language-plaintext highlighter-rouge">at::native</code> namespace haystack, we can either <a href="https://docs.github.com/en/codespaces/the-githubdev-web-based-editor">directly open a codespace on GitHub</a> or we can clone the PyTorch repo. Both options give us access to a terminal where we can find the implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function by running <code class="language-plaintext highlighter-rouge">git grep "addmm"</code>. This will give us a list of all files in the <a href="https://stackoverflow.com/questions/60843047/locating-a-function-in-a-git-repository">current folder of the PyTorch repo</a> that contain the string <code class="language-plaintext highlighter-rouge">addmm</code>. We can then look through these files to find the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function. So we do the following in summary:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>git clone https://github.com/pytorch/pytorch 
<span class="nb">cd </span>pytorch/aten/src/ATen/native
git <span class="nb">grep</span> <span class="s2">"addmm"</span>
</code></pre></div></div>

<p>This gives us a lot of output, but we can see that there are two kinds of function declarations in <code class="language-plaintext highlighter-rouge">LinearAlgebra.cpp</code> that look promising:</p>

<ol>
  <li>A meta function called <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L181"><code class="language-plaintext highlighter-rouge">TORCH_META_FUNC(addmm)</code></a></li>
  <li>Multiple implementation functions: <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L1621"><code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cpu)</code></a>, but also the CUDA implementation in the <code class="language-plaintext highlighter-rouge">cuda/Blas.cpp</code> file called <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/cuda/Blas.cpp#L505"><code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cuda)</code></a></li>
</ol>

<p>This insight leads us to another new term we have to understand in order to make sense of the codebase: Structured Kernels.</p>

<h2 id="structured-kernels">Structured Kernels</h2>

<p><a href="https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md">Structured Kernels</a> is a new (i.e. from 2021) way to define PyTorch operators. It abstracts away even more of the boilerplate code that has to be written for each operator and backend than native functions alone, to the extent that you only need to write a shape-checking function (meta function) and a kernel implementation function for the out-kernel and the structured kernel will take care of the rest.</p>

<p>This now explains the <code class="language-plaintext highlighter-rouge">structured</code> and <code class="language-plaintext highlighter-rouge">structured_delegate</code> fields in the <code class="language-plaintext highlighter-rouge">native_functions.yaml</code> file. The <code class="language-plaintext highlighter-rouge">structured</code> field tells us that the <code class="language-plaintext highlighter-rouge">addmm</code> function is a structured kernel, and the <code class="language-plaintext highlighter-rouge">structured_delegate</code> field tells us that the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function is in the <code class="language-plaintext highlighter-rouge">addmm.out</code> function.</p>

<p>Pre structured kernels, entries in the <code class="language-plaintext highlighter-rouge">native_functions.yaml</code> file looked like this:</p>

<div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># file: "native_functions.yaml"</span>
<span class="pi">-</span> <span class="na">func</span><span class="pi">:</span> <span class="s">addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -&gt; Tensor</span>
  <span class="c1"># structured_delegate: addmm.out removed!</span>
  <span class="na">variants</span><span class="pi">:</span> <span class="s">function, method</span>
  <span class="na">dispatch</span><span class="pi">:</span>
    <span class="c1">#CPU, CUDA and MPS kernels added!</span>
    <span class="na">CPU</span><span class="pi">:</span> <span class="s">addmm_cpu</span>
    <span class="na">CUDA</span><span class="pi">:</span> <span class="s">addmm_cuda</span>
    <span class="na">MPS</span><span class="pi">:</span> <span class="s">addmm_mps</span>
    <span class="na">SparseCPU</span><span class="pi">:</span> <span class="s">addmm_sparse_dense_cpu</span>
    <span class="na">SparseCUDA</span><span class="pi">:</span> <span class="s">addmm_sparse_dense_cuda</span>
    <span class="s">SparseCsrCPU, SparseCsrCUDA</span><span class="err">:</span> <span class="s">addmm_sparse_compressed_dense</span>
  <span class="na">tags</span><span class="pi">:</span> <span class="s">core</span>
</code></pre></div></div>

<div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># file: "native_functions.yaml"</span>
<span class="pi">-</span> <span class="na">func</span><span class="pi">:</span> <span class="s">addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -&gt; Tensor(a!)</span>
  <span class="c1"># structured: True removed!</span>
  <span class="na">dispatch</span><span class="pi">:</span>
    <span class="na">CPU</span><span class="pi">:</span> <span class="s">addmm_out_cpu</span>
    <span class="na">CUDA</span><span class="pi">:</span> <span class="s">addmm_out_cuda</span>
    <span class="na">MPS</span><span class="pi">:</span> <span class="s">addmm_out_mps</span>
    <span class="na">SparseCPU</span><span class="pi">:</span> <span class="s">addmm_out_sparse_dense_cpu</span>
    <span class="na">SparseCUDA</span><span class="pi">:</span> <span class="s">addmm_out_sparse_dense_cuda</span>
    <span class="na">SparseCsrCPU</span><span class="pi">:</span> <span class="s">addmm_out_sparse_compressed_cpu</span>
    <span class="na">SparseCsrCUDA</span><span class="pi">:</span> <span class="s">addmm_out_sparse_compressed_cuda</span>
</code></pre></div></div>

<p>You see that before structured kernels, both the <code class="language-plaintext highlighter-rouge">addmm</code> and <code class="language-plaintext highlighter-rouge">addmm.out</code> functions had a <code class="language-plaintext highlighter-rouge">dispatch</code> field that specified all the backends for which the function had to be implemented. The <code class="language-plaintext highlighter-rouge">CPU</code>, <code class="language-plaintext highlighter-rouge">CUDA</code> and <code class="language-plaintext highlighter-rouge">MPS</code> kernels had to be implemented separately for the <code class="language-plaintext highlighter-rouge">addmm</code> and <code class="language-plaintext highlighter-rouge">addmm.out</code> functions. This is a lot of boilerplate code that structured kernels can generate for us.</p>

<p>In the structured kernel yaml file, you see that the <code class="language-plaintext highlighter-rouge">addmm</code> function has a <code class="language-plaintext highlighter-rouge">structured_delegate</code> field that points to the <code class="language-plaintext highlighter-rouge">addmm.out</code> function. This is because the <code class="language-plaintext highlighter-rouge">addmm</code> function is a structured kernel, and the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function is in the <code class="language-plaintext highlighter-rouge">addmm.out</code> function. The <code class="language-plaintext highlighter-rouge">addmm.out</code> function is a structured kernel that is implemented in the <code class="language-plaintext highlighter-rouge">LinearAlgebra.cpp</code> file.</p>

<p>In the ideal case of a structured kernel, the <code class="language-plaintext highlighter-rouge">addmm</code> function would not need any <code class="language-plaintext highlighter-rouge">dispatch</code> field because <code class="language-plaintext highlighter-rouge">addmm.out</code>, as the structured delegate, would provide all the kernel implementations. This can be seen in the example from the <a href="https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md">RFC for structured kernels</a>:</p>

<p><img src="/assets/img/blog/pytorch_whirlwind/structured_kernel_readme.png" alt="Structured Kernel README" /></p>

<p>In the <code class="language-plaintext highlighter-rouge">addmm</code> function, however, we still see the <code class="language-plaintext highlighter-rouge">dispatch</code> field. This is because the <code class="language-plaintext highlighter-rouge">addmm</code> function is a composite matmul operator, and the implementation can be highly specific in the sparse case. Therefore we cannot rely on the structured kernel to generate the correct implementation for us, and we have to specify the dispatch field manually. If you want to learn more about how all this is implemented under the hood, check out <a href="https://drive.google.com/file/d/16qPvpCF4Jbh7ss2lCQMk5hmcyzJvUyQj/view">this slide deck</a>.</p>
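<p>To make the division of labour concrete, here is a minimal pure-Python sketch of the idea: we write one meta function and one out-kernel, and a helper derives both the functional and the <code class="language-plaintext highlighter-rouge">out=</code> variant from them. The names <code class="language-plaintext highlighter-rouge">make_structured_variants</code>, <code class="language-plaintext highlighter-rouge">add_meta</code> and <code class="language-plaintext highlighter-rouge">add_out_kernel</code> are hypothetical and use nested lists instead of tensors. PyTorch's real code generator emits C++ and handles far more (strides, dtypes, in-place variants), so treat this as an illustration of the concept only.</p>

```python
def make_structured_variants(meta, impl):
    """Derive a functional and an out= variant from one meta function and one out-kernel.

    Hypothetical helper for illustration; not part of PyTorch.
    """

    def functional(*args):
        rows, cols = meta(*args)  # shape checking and output shape
        out = [[0.0] * cols for _ in range(rows)]  # allocate a fresh output
        impl(*args, out)  # the kernel only ever writes into `out`
        return out

    def out_variant(*args, out):
        rows, cols = meta(*args)
        # Resize the caller's buffer to the shape declared by the meta function
        out[:] = [[0.0] * cols for _ in range(rows)]
        impl(*args, out)
        return out

    return functional, out_variant


def add_meta(x, y):
    """Toy meta function: check the shapes, return the output shape."""
    if len(x) != len(y) or len(x[0]) != len(y[0]):
        raise ValueError("shapes must match")
    return len(x), len(x[0])


def add_out_kernel(x, y, out):
    """Toy out-kernel: assumes `out` already has the right shape."""
    for i, row in enumerate(x):
        for j, value in enumerate(row):
            out[i][j] = value + y[i][j]


add, add_out = make_structured_variants(add_meta, add_out_kernel)
```

<p>Calling <code class="language-plaintext highlighter-rouge">add(x, y)</code> allocates the result, while <code class="language-plaintext highlighter-rouge">add_out(x, y, out=buf)</code> reuses a buffer we pass in. Both share the same shape check and the same kernel.</p>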

<h2 id="where-are-the-actual-implementations">Where are the actual implementations?</h2>

<p>We are already quite deep down in the rabbit hole and have tracked down the <code class="language-plaintext highlighter-rouge">addmm</code> function to the <code class="language-plaintext highlighter-rouge">LinearAlgebra.cpp</code> and the <code class="language-plaintext highlighter-rouge">cuda/Blas.cpp</code> files. These files contain the meta function <code class="language-plaintext highlighter-rouge">TORCH_META_FUNC(addmm)</code> and the implementation functions <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cpu)</code> and <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cuda)</code>. The <code class="language-plaintext highlighter-rouge">TORCH_META_FUNC</code> function is a meta function that checks the shapes of the input tensors and sets the shape of the output tensor. The <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC</code> functions are the actual implementations of the <code class="language-plaintext highlighter-rouge">addmm</code> function for the CPU and CUDA backends.</p>

<p>Let us look at these in turn now.</p>

<h3 id="shape-checking-torch_meta_funcaddmm">Shape checking: <code class="language-plaintext highlighter-rouge">TORCH_META_FUNC(addmm)</code></h3>

<p>The <code class="language-plaintext highlighter-rouge">TORCH_META_FUNC(addmm)</code> function is a wrapper around <code class="language-plaintext highlighter-rouge">ADDMM_META()</code>. Why another wrapper, you may ask? Well, the shape checking done in this function is transferable to other cases such as <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L185C1-L185C35"><code class="language-plaintext highlighter-rouge">TORCH_META_FUNC(_addmm_activation)</code></a>, so the wrapper promotes reusability.</p>

<p>Looking at the <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L169C1-L179C92">implementation of <code class="language-plaintext highlighter-rouge">ADDMM_META()</code></a>, we see that it is actually not a function but a preprocessor macro:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#define ADDMM_META() \
  TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type()); \
  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type()); \
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); \
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); \
  TORCH_CHECK( \
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", \
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); \
 \
  auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self); \
  set_output_raw_strided(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, mat1.options(), names);
</span></code></pre></div></div>

<p>As expected, it performs a lot of checks on the input tensors. It checks that the input tensors have the same data type, that <code class="language-plaintext highlighter-rouge">mat1</code> and <code class="language-plaintext highlighter-rouge">mat2</code> are both 2D tensors (i.e. matrices), and that the shapes of <code class="language-plaintext highlighter-rouge">mat1</code> and <code class="language-plaintext highlighter-rouge">mat2</code> are compatible for matrix multiplication. It then calls the <code class="language-plaintext highlighter-rouge">at::namedinference::propagate_names_for_addmm</code> function to propagate the names of the input tensors to the output tensor. Finally, it sets the output tensor to the correct shape.</p>
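<p>The same sequence of checks can be sketched in a few lines of plain Python. The function <code class="language-plaintext highlighter-rouge">addmm_meta</code> below is a hypothetical stand-in that works on shape tuples and dtype strings instead of tensors, and it skips name propagation. It is meant to mirror the order of the checks in the macro, not to reproduce PyTorch's behaviour exactly.</p>

```python
def addmm_meta(self_dtype, mat1_shape, mat1_dtype, mat2_shape, mat2_dtype):
    """Mirror the checks of ADDMM_META on plain tuples and dtype strings.

    Hypothetical helper for illustration; returns the output shape.
    """
    if self_dtype != mat2_dtype:
        raise TypeError(
            f"self and mat2 must have the same dtype, but got {self_dtype} and {mat2_dtype}"
        )
    if mat1_dtype != mat2_dtype:
        raise TypeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1_dtype} and {mat2_dtype}"
        )
    if len(mat1_shape) != 2:
        raise ValueError(f"mat1 must be a matrix, got {len(mat1_shape)}-D tensor")
    if len(mat2_shape) != 2:
        raise ValueError(f"mat2 must be a matrix, got {len(mat2_shape)}-D tensor")
    if mat1_shape[1] != mat2_shape[0]:
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({mat1_shape[0]}x{mat1_shape[1]} and {mat2_shape[0]}x{mat2_shape[1]})"
        )
    # Like set_output_raw_strided: only the output shape is decided here,
    # no data is read or written.
    return (mat1_shape[0], mat2_shape[1])
```

<p>Note that the shape of <code class="language-plaintext highlighter-rouge">self</code> is not checked at this stage, just as in the macro above.</p>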

<h3 id="cpu-implementation-torch_impl_funcaddmm_out_cpu">CPU implementation: <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cpu)</code></h3>

<p>If we look at the <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cpu)</code> function, we see that it is again a wrapper! It first expands the <code class="language-plaintext highlighter-rouge">self</code> tensor to the output shape (rows = <code class="language-plaintext highlighter-rouge">mat1.sizes()[0]</code>, columns = <code class="language-plaintext highlighter-rouge">mat2.sizes()[1]</code>) and then calls the <code class="language-plaintext highlighter-rouge">addmm_impl_cpu_()</code> function.</p>

<p>Fortunately, this time we do not have to search long for the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm_impl_cpu_()</code> function. It is in the <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L1405">same file</a> and longer than the previous wrapper function (which makes sense since it is the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function).</p>

<p>Looking at the function signature, we see the following:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">static</span> <span class="kt">void</span> <span class="n">addmm_impl_cpu_</span><span class="p">(</span>
    <span class="n">Tensor</span> <span class="o">&amp;</span><span class="n">result</span><span class="p">,</span> <span class="k">const</span> <span class="n">Tensor</span> <span class="o">&amp;</span><span class="n">self</span><span class="p">,</span> <span class="n">Tensor</span> <span class="n">m1</span><span class="p">,</span> <span class="n">Tensor</span> <span class="n">m2</span><span class="p">,</span> <span class="k">const</span> <span class="n">Scalar</span><span class="o">&amp;</span> <span class="n">beta</span><span class="p">,</span> <span class="k">const</span> <span class="n">Scalar</span><span class="o">&amp;</span> <span class="n">alpha</span><span class="p">)</span>
</code></pre></div></div>

<p>We see that the function does not return anything, but takes a reference to the output tensor <code class="language-plaintext highlighter-rouge">result</code> and the input tensors <code class="language-plaintext highlighter-rouge">self</code>, <code class="language-plaintext highlighter-rouge">m1</code> and <code class="language-plaintext highlighter-rouge">m2</code> as well as the scaling factors <code class="language-plaintext highlighter-rouge">beta</code> and <code class="language-plaintext highlighter-rouge">alpha</code>. It starts with some shape asserts and data type checks. It then stores the sizes of the different matrices in <code class="language-plaintext highlighter-rouge">auto</code> variables, since accessing these arrays is faster than calling the <code class="language-plaintext highlighter-rouge">size()</code> method multiple times (we will need these sizes for the matrix multiplication). After some additional checks and resizings we get to the core of the function.</p>
<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Some paths in the code below do not handle multiplications of the form [a, 0] x [0, b]</span>
  <span class="k">if</span> <span class="p">(</span><span class="n">m1_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">beta</span><span class="p">.</span><span class="n">toComplexDouble</span><span class="p">()</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">)</span> <span class="p">{</span>
      <span class="n">result</span><span class="p">.</span><span class="n">zero_</span><span class="p">();</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
      <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="n">self</span><span class="p">.</span><span class="n">is_same</span><span class="p">(</span><span class="n">result</span><span class="p">))</span> <span class="p">{</span>
        <span class="n">result</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">self</span><span class="p">);</span>
      <span class="p">}</span>
      <span class="n">result</span><span class="p">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">beta</span><span class="p">);</span>
    <span class="p">}</span>
    <span class="k">return</span><span class="p">;</span>
  <span class="p">}</span>
</code></pre></div></div>

<p class="figcaption">Checks for the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> value. <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L1435C1-L1445C4">Link</a></p>

<p>As the comment tells us, the code after the excerpt cannot handle multiplications of the form <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">[</mo><mi>a</mi><mo separator="true">,</mo><mn>0</mn><mo stretchy="false">]</mo><mo>×</mo><mo stretchy="false">[</mo><mn>0</mn><mo separator="true">,</mo><mi>b</mi><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">[a, 0] \times [0, b]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">0</span><span class="mclose">]</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">b</span><span class="mclose">]</span></span></span></span>, so it checks for this case and handles it separately. We can see that if the input scaling factor <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> is zero, the output tensor is zeroed out. 
If the input scaling factor <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> is not zero, the output tensor copies the entries from the <code class="language-plaintext highlighter-rouge">self</code> tensor and is scaled by <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span>. The function then returns.</p>
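<p>As a minimal sketch of this early return, the hypothetical function <code class="language-plaintext highlighter-rouge">addmm_empty_inner</code> below works on nested Python lists and assumes <code class="language-plaintext highlighter-rouge">self</code> has already been expanded to the output shape. It only covers the branch where the inner dimension is zero.</p>

```python
def addmm_empty_inner(self_rows, beta):
    """Result of addmm when mat1 is [a, 0] and mat2 is [0, b].

    Hypothetical helper for illustration: the matrix product contributes
    nothing, so only beta * self remains.
    """
    if beta == 0:
        # Corresponds to result.zero_(): self is never read, so NaN or inf
        # values stored in it cannot leak into the result.
        return [[0.0 for _ in row] for row in self_rows]
    # Corresponds to result.copy_(self) followed by result.mul_(beta)
    return [[beta * value for value in row] for row in self_rows]
```

<p>Zeroing the output explicitly matters: multiplying a NaN entry of <code class="language-plaintext highlighter-rouge">self</code> by a zero <code class="language-plaintext highlighter-rouge">beta</code> would still give NaN, whereas this branch returns clean zeros.</p>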

<p>After that, we cast the tensor <code class="language-plaintext highlighter-rouge">result</code> as matrix <code class="language-plaintext highlighter-rouge">c</code>, <code class="language-plaintext highlighter-rouge">m1</code> as matrix <code class="language-plaintext highlighter-rouge">a</code> and <code class="language-plaintext highlighter-rouge">m2</code> as matrix <code class="language-plaintext highlighter-rouge">b</code>. We do this to prepare the shapes and memory layouts correctly for the matrix multiplication.</p>

<p>Finally, we get to the matrix multiplication itself. Depending on which CPU hardware we have, we can still dispatch to two different implementations.</p>

<ol>
  <li>On AArch64 we can call the <code class="language-plaintext highlighter-rouge">mkldnn_matmul</code> function that is faster in case certain shape considerations are fulfilled:</li>
</ol>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="kt">bool</span> <span class="n">dispatched</span> <span class="o">=</span> <span class="nb">false</span><span class="p">;</span>
<span class="cp">#if defined(__aarch64__) &amp;&amp; AT_MKLDNN_ACL_ENABLED()
</span>  <span class="c1">// On AArch64 if LHS matrix in BLAS routine is transposed but RHS is not then</span>
  <span class="c1">// it is faster to call oneDNN matrix multiplication primitive with RHS*LHS</span>
  <span class="c1">// that will call then into Arm® Compute Library (ACL) GEMM kernel and also</span>
  <span class="c1">// additionally have support for running kernel with BF16 instructions</span>
  <span class="k">if</span> <span class="p">(</span><span class="n">transpose_c</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">bool</span> <span class="n">apply_heur</span> <span class="o">=</span> <span class="n">apply_mkldnn_matmul_heur</span><span class="p">(</span><span class="n">b</span><span class="p">.</span><span class="n">sizes</span><span class="p">()[</span><span class="mi">0</span><span class="p">],</span> <span class="n">b</span><span class="p">.</span><span class="n">sizes</span><span class="p">()[</span><span class="mi">1</span><span class="p">],</span> <span class="n">a</span><span class="p">.</span><span class="n">sizes</span><span class="p">()[</span><span class="mi">1</span><span class="p">]);</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">apply_heur</span> <span class="o">&amp;&amp;</span> <span class="n">transpose_a</span> <span class="o">&amp;&amp;</span> <span class="o">!</span><span class="n">transpose_b</span> <span class="o">&amp;&amp;</span> <span class="n">result</span><span class="p">.</span><span class="n">scalar_type</span><span class="p">()</span> <span class="o">==</span> <span class="n">at</span><span class="o">::</span><span class="n">ScalarType</span><span class="o">::</span><span class="n">Float</span><span class="p">)</span> <span class="p">{</span>
      <span class="k">try</span> <span class="p">{</span>
        <span class="n">mkldnn_matmul</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">beta</span><span class="p">.</span><span class="n">to</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">(),</span> <span class="n">alpha</span><span class="p">.</span><span class="n">to</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">());</span>
        <span class="c1">// We have dispatched to ACL GEMM for single precision float</span>
        <span class="c1">// so do not need to dispatch to BLAS GEMM below</span>
        <span class="n">dispatched</span> <span class="o">=</span> <span class="nb">true</span><span class="p">;</span>
      <span class="p">}</span> <span class="k">catch</span> <span class="p">(</span><span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">exception</span><span class="o">&amp;</span> <span class="n">e</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">TORCH_WARN</span><span class="p">(</span><span class="s">"mkldnn_matmul failed, switching to BLAS gemm:"</span><span class="p">,</span> <span class="n">e</span><span class="p">.</span><span class="n">what</span><span class="p">());</span>
        <span class="n">at</span><span class="o">::</span><span class="n">globalContext</span><span class="p">().</span><span class="n">setUserEnabledMkldnn</span><span class="p">(</span><span class="nb">false</span><span class="p">);</span>
      <span class="p">}</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="cp">#endif
</span></code></pre></div></div>
<p class="figcaption">AArch64 matrix multiplication dispatch. <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L1517C1-L1537C7">Link</a></p>

<ol start="2">
  <li>If this option is not enabled (or if the heuristic check for the matrix shapes fails), we fall back to the <code class="language-plaintext highlighter-rouge">gemm</code> function from the BLAS library:</li>
</ol>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="k">if</span><span class="p">(</span><span class="o">!</span><span class="n">dispatched</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// Apply BLAS routine</span>
    <span class="n">_AT_DISPATCH_ADDMM_TYPES</span><span class="p">(</span><span class="n">result</span><span class="p">.</span><span class="n">scalar_type</span><span class="p">(),</span> <span class="s">"addmm_impl_cpu_"</span><span class="p">,</span> <span class="p">[</span><span class="o">&amp;</span><span class="p">]{</span>
          <span class="k">using</span> <span class="n">opmath_t</span> <span class="o">=</span> <span class="n">at</span><span class="o">::</span><span class="n">opmath_type</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">;</span>
          <span class="n">at</span><span class="o">::</span><span class="n">native</span><span class="o">::</span><span class="n">cpublas</span><span class="o">::</span><span class="n">gemm</span><span class="p">(</span>
              <span class="n">transpose_a</span> <span class="o">?</span> <span class="n">a</span><span class="p">.</span><span class="n">is_conj</span><span class="p">()</span> <span class="o">?</span> <span class="n">TransposeType</span><span class="o">::</span><span class="n">ConjTranspose</span> <span class="o">:</span> <span class="n">TransposeType</span><span class="o">::</span><span class="n">Transpose</span> <span class="o">:</span> <span class="n">TransposeType</span><span class="o">::</span><span class="n">NoTranspose</span><span class="p">,</span>
              <span class="n">transpose_b</span> <span class="o">?</span> <span class="n">b</span><span class="p">.</span><span class="n">is_conj</span><span class="p">()</span> <span class="o">?</span> <span class="n">TransposeType</span><span class="o">::</span><span class="n">ConjTranspose</span> <span class="o">:</span> <span class="n">TransposeType</span><span class="o">::</span><span class="n">Transpose</span> <span class="o">:</span> <span class="n">TransposeType</span><span class="o">::</span><span class="n">NoTranspose</span><span class="p">,</span>
              <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span>
              <span class="n">alpha</span><span class="p">.</span><span class="n">to</span><span class="o">&lt;</span><span class="n">opmath_t</span><span class="o">&gt;</span><span class="p">(),</span>
              <span class="n">a</span><span class="p">.</span><span class="n">const_data_ptr</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">(),</span> <span class="n">lda</span><span class="p">,</span>
              <span class="n">b</span><span class="p">.</span><span class="n">const_data_ptr</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">(),</span> <span class="n">ldb</span><span class="p">,</span>
              <span class="n">beta</span><span class="p">.</span><span class="n">to</span><span class="o">&lt;</span><span class="n">opmath_t</span><span class="o">&gt;</span><span class="p">(),</span>
              <span class="n">c</span><span class="p">.</span><span class="n">mutable_data_ptr</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">(),</span> <span class="n">ldc</span><span class="p">);</span>
        <span class="p">});</span>
  <span class="p">}</span>
</code></pre></div></div>

<p class="figcaption">CPU BLAS dispatch to the GEMM function. <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/LinearAlgebra.cpp#L1539C1-L1558C2">Link</a></p>

<p>With this, we have the actual implementation of the <code class="language-plaintext highlighter-rouge">addmm</code> function for the CPU backend.</p>
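<p>To see what the <code class="language-plaintext highlighter-rouge">gemm</code> call computes, here is a naive reference version in plain Python. The function <code class="language-plaintext highlighter-rouge">gemm_reference</code> is a hypothetical sketch: it uses row-major nested lists instead of raw pointers with leading dimensions, ignores conjugation, and is of course nowhere near as fast as an optimized BLAS routine. The argument order loosely follows the call above.</p>

```python
def gemm_reference(transpose_a, transpose_b, m, n, k, alpha, a, b, beta, c):
    """Compute c = alpha * op(a) @ op(b) + beta * c in place.

    Hypothetical reference for illustration. op(x) is x or its transpose,
    op(a) has shape m x k, op(b) has shape k x n and c has shape m x n.
    """
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                a_ip = a[p][i] if transpose_a else a[i][p]
                b_pj = b[j][p] if transpose_b else b[p][j]
                acc += a_ip * b_pj
            if beta == 0:
                # BLAS convention: with beta == 0 the old content of c is
                # overwritten without being read.
                c[i][j] = alpha * acc
            else:
                c[i][j] = alpha * acc + beta * c[i][j]
    return c
```

<p>In the <code class="language-plaintext highlighter-rouge">addmm</code> setting, <code class="language-plaintext highlighter-rouge">c</code> plays the role of <code class="language-plaintext highlighter-rouge">result</code>, which is why a single call yields the scaled input plus the scaled matrix product.</p>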

<h3 id="cuda-implementation-torch_impl_funcaddmm_out_cuda">CUDA implementation: <code class="language-plaintext highlighter-rouge">TORCH_IMPL_FUNC(addmm_out_cuda)</code></h3>

<p>The <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/cuda/Blas.cpp#L505">CUDA implementation</a> is quite similar at first sight: we again call the actual implementation function <code class="language-plaintext highlighter-rouge">addmm_out_cuda_impl()</code>, which is reused in multiple other functions.</p>

<p>The actual implementation of the <code class="language-plaintext highlighter-rouge">addmm_out_cuda_impl()</code> function is in the <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/cuda/Blas.cpp#L208">same file</a> and again starts with some shape asserts and data type checks. We again have a check for the case where the input scaling factor <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> is zero, which is handled separately when <code class="language-plaintext highlighter-rouge">mat1</code> is empty:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="k">if</span> <span class="p">(</span><span class="n">mat1</span><span class="p">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// By definition, when beta==0, values in self should be ignored. nans and infs</span>
    <span class="c1">// should not propagate</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">beta</span><span class="p">.</span><span class="n">toComplexDouble</span><span class="p">()</span> <span class="o">==</span> <span class="mf">0.</span><span class="p">)</span> <span class="p">{</span>
      <span class="k">return</span> <span class="n">result</span><span class="p">.</span><span class="n">zero_</span><span class="p">();</span>
    <span class="p">}</span>
    <span class="c1">// TODO: We could squeeze some perf by calling at::cuda::mul_out here instead, to bypass the dispatcher.</span>
    <span class="c1">// That requires some fixing some internal build dependencies though.</span>
    <span class="k">return</span> <span class="n">at</span><span class="o">::</span><span class="n">mul_out</span><span class="p">(</span>
        <span class="n">result</span><span class="p">,</span>
        <span class="n">self</span><span class="p">.</span><span class="n">expand</span><span class="p">(</span><span class="n">result</span><span class="p">.</span><span class="n">sizes</span><span class="p">()),</span>
        <span class="n">at</span><span class="o">::</span><span class="n">native</span><span class="o">::</span><span class="n">scalar_tensor</span><span class="p">(</span>
            <span class="n">beta</span><span class="p">,</span>
            <span class="n">self</span><span class="p">.</span><span class="n">scalar_type</span><span class="p">(),</span>
            <span class="n">c10</span><span class="o">::</span><span class="n">nullopt</span> <span class="cm">/* layout */</span><span class="p">,</span>
            <span class="n">at</span><span class="o">::</span><span class="n">kCPU</span><span class="p">,</span>
            <span class="n">c10</span><span class="o">::</span><span class="n">nullopt</span> <span class="cm">/* pin_memory */</span><span class="p">));</span>
  <span class="p">}</span>
</code></pre></div></div>

<p class="figcaption"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> checks of the <code class="language-plaintext highlighter-rouge">addmm</code> CUDA implementation. <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/cuda/Blas.cpp#L302C1-L319C4">Link</a></p>

<p>After that, we again dispatch to different kernels (this time CUDA kernels) depending on the hardware we have. The CUDA implementation is more complex than the CPU implementation since we have to take into account the different CUDA hardware architectures and the different CUDA libraries that are available. Here is one example:</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="c1">// If batch is 1 call gemm rather than bgemm</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">num_batches</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span> <span class="p">{</span>
      <span class="n">at</span><span class="o">::</span><span class="n">cuda</span><span class="o">::</span><span class="n">blas</span><span class="o">::</span><span class="n">gemm</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">(</span>
          <span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span>
          <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span>
          <span class="n">alpha_val</span><span class="p">,</span>
          <span class="n">batch1_ptr</span><span class="p">,</span> <span class="n">lda</span><span class="p">,</span>
          <span class="n">batch2_ptr</span><span class="p">,</span> <span class="n">ldb</span><span class="p">,</span>
          <span class="n">beta_val</span><span class="p">,</span>
          <span class="n">result_ptr</span><span class="p">,</span> <span class="n">ldc</span><span class="p">);</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
      <span class="n">at</span><span class="o">::</span><span class="n">cuda</span><span class="o">::</span><span class="n">blas</span><span class="o">::</span><span class="n">bgemm</span><span class="o">&lt;</span><span class="n">scalar_t</span><span class="o">&gt;</span><span class="p">(</span>
        <span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span>
        <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span>
        <span class="n">alpha_val</span><span class="p">,</span>
        <span class="n">batch1_ptr</span><span class="p">,</span> <span class="n">lda</span><span class="p">,</span> <span class="n">batch1_</span><span class="o">-&gt;</span><span class="n">strides</span><span class="p">()[</span><span class="mi">0</span><span class="p">],</span>
        <span class="n">batch2_ptr</span><span class="p">,</span> <span class="n">ldb</span><span class="p">,</span> <span class="n">batch2_</span><span class="o">-&gt;</span><span class="n">strides</span><span class="p">()[</span><span class="mi">0</span><span class="p">],</span>
        <span class="n">beta_val</span><span class="p">,</span>
        <span class="n">result_ptr</span><span class="p">,</span> <span class="n">ldc</span><span class="p">,</span> <span class="n">result_</span><span class="o">-&gt;</span><span class="n">strides</span><span class="p">()[</span><span class="mi">0</span><span class="p">],</span>
        <span class="n">num_batches</span>
      <span class="p">);</span>
   <span class="p">}</span>
</code></pre></div></div>

<p class="figcaption">CUDA GEMM dispatch. <a href="https://github.com/pytorch/pytorch/blob/c5116d9e44f7a0ab40d26e47077ecdd15693e9dd/aten/src/ATen/native/cuda/Blas.cpp#L474">Link</a></p>

<p>You can see that depending on the number of batches, we call either the <code class="language-plaintext highlighter-rouge">gemm</code> or the <code class="language-plaintext highlighter-rouge">bgemm</code> function from PyTorch's <code class="language-plaintext highlighter-rouge">at::cuda::blas</code> wrappers around the CUDA BLAS library. The <code class="language-plaintext highlighter-rouge">bgemm</code> function is a batched version of the <code class="language-plaintext highlighter-rouge">gemm</code> function that can perform multiple matrix multiplications in parallel. This is useful if we have a batch of matrices to multiply, since a single call then covers the whole batch instead of one call per matrix. To learn more about the different CUDA BLAS functions, you can look at the <a href="https://docs.nvidia.com/cuda/cublas/index.html">cuBLAS documentation</a> and the <a href="https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html">matrix multiplication user guide</a>.</p>
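<p>Conceptually, a strided batched GEMM behaves like a loop over single GEMM calls in which each batch entry starts a fixed stride after the previous one. The sketch below illustrates that relationship under simplifying assumptions (column-major storage, no transposition). The names <code class="language-plaintext highlighter-rouge">gemm_ref</code> and <code class="language-plaintext highlighter-rouge">bgemm_ref</code> are hypothetical and are not the PyTorch or cuBLAS functions; a real GPU library processes the batch entries in parallel rather than one after another.</p>

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical single GEMM: C = alpha * A * B + beta * C (column-major).
void gemm_ref(int64_t m, int64_t n, int64_t k, float alpha,
              const float* a, int64_t lda, const float* b, int64_t ldb,
              float beta, float* c, int64_t ldc) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) acc += a[i + p * lda] * b[p + j * ldb];
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}

// Hypothetical batched GEMM: entry `batch` of each operand starts
// `batch * stride` elements after the corresponding base pointer.
void bgemm_ref(int64_t m, int64_t n, int64_t k, float alpha,
               const float* a, int64_t lda, int64_t stride_a,
               const float* b, int64_t ldb, int64_t stride_b,
               float beta, float* c, int64_t ldc, int64_t stride_c,
               int64_t num_batches) {
  assert(num_batches >= 0);
  for (int64_t batch = 0; batch < num_batches; ++batch) {
    gemm_ref(m, n, k, alpha,
             a + batch * stride_a, lda,
             b + batch * stride_b, ldb,
             beta,
             c + batch * stride_c, ldc);
  }
}
```

<p>In this sketch, passing a stride of 0 for one operand reuses the same matrix for every batch entry, which covers the case of multiplying a whole batch with one shared matrix.</p>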

<h2 id="conclusion">Conclusion</h2>

<p>In this post, we went on a whirlwind tour of the PyTorch codebase to understand how the <code class="language-plaintext highlighter-rouge">addmm</code> function is implemented. We saw that <code class="language-plaintext highlighter-rouge">addmm</code> is not only a PyTorch native function specified in the <code class="language-plaintext highlighter-rouge">native_functions.yaml</code> file, but also a structured kernel, and that its actual implementation lives in the <code class="language-plaintext highlighter-rouge">addmm.out</code> function. We then looked at the <code class="language-plaintext highlighter-rouge">addmm.out</code> function and realised that it is a wrapper around the <code class="language-plaintext highlighter-rouge">addmm_impl_cpu_()</code> and <code class="language-plaintext highlighter-rouge">addmm_out_cuda_impl()</code> functions. Inspecting these two functions made it clear that they are the actual implementations of <code class="language-plaintext highlighter-rouge">addmm</code> for the CPU and CUDA backends. They look quite complicated due to the different dispatch conditions, shape checks and data type checks, but the core of each function (the matrix multiplication) is in the end again a call to a <code class="language-plaintext highlighter-rouge">kernel</code> from a library.</p>

<p>I hope that this post gave you a good overview of how to find the implementation of a PyTorch operator and how to navigate the PyTorch codebase. If you have a better way to do that, let me know!</p>

<h2 id="credits">Credits</h2>

<p>There is an amazing blog post about <a href="http://blog.ezyang.com/2019/05/pytorch-internals/">PyTorch internals</a> by Edward Z. Yang, as well as his <a href="https://pytorch-dev-podcast.simplecast.com/episodes">PyTorch developer podcast</a>, both of which helped me immensely in understanding the PyTorch codebase. Also a shoutout to Christian Perone for his <a href="https://blog.christianperone.com/2023/12/pytorch-2-intern">slides on PyTorch 2 internals</a>, which shed some light on the recent developments connected with the PyTorch 2 release.</p>

<p>PyTorch Logo taken from <a href="https://www.google.com/url?sa=i&amp;url=https%3A%2F%2Fabout.fb.com%2Fnews%2F2022%2F09%2Fpytorch-foundation-to-accelerate-progress-in-ai-research%2F&amp;psig=AOvVaw3jeDHL-YCCnCHsFYMy-iUn&amp;ust=1707675842545000&amp;source=images&amp;cd=vfe&amp;opi=89978449&amp;ved=0CBQQjhxqFwoTCLiSrLGyoYQDFQAAAAAdAAAAABAE">this post</a>.</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="ml" /><summary type="html"><![CDATA[How to find the implementation of a PyTorch operator - A whirlwind tour]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/pytorch.jpeg" /><media:content medium="image" url="/assets/img/blog/pytorch.jpeg" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">(GER) Was sind Diffusion Models</title><link href="/blog/ml/2023-05-15-diffusion-models/" rel="alternate" type="text/html" title="(GER) Was sind Diffusion Models" /><published>2023-05-15T00:00:00+00:00</published><updated>2023-08-26T03:04:10+00:00</updated><id>/blog/ml/diffusion-models</id><content type="html" xml:base="/blog/ml/2023-05-15-diffusion-models/"><![CDATA[<p>(Die deutsche Version beginn unten!)</p>

<p>This post is a rather unusual one since it is in German. I have always been involved in making content available in other languages to allow more people to enjoy it, such as when I did translations for Khan Academy. After translating the posts on normalising flows by Eric Jang, I have the pleasure of now translating <a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#classifier-free-guidance">Lilian Weng’s excellent post</a> on diffusion models. I hope you enjoy it!</p>

<ul id="markdown-toc">
  <li><a href="#einführung" id="markdown-toc-einführung">Einführung</a></li>
  <li><a href="#was-sind-diffusion-models" id="markdown-toc-was-sind-diffusion-models">Was sind Diffusion Models</a></li>
  <li><a href="#forward-diffusion-process" id="markdown-toc-forward-diffusion-process">Forward Diffusion Process</a></li>
  <li><a href="#verbindung-zu-stochastic-gradient-langevin-dynamics" id="markdown-toc-verbindung-zu-stochastic-gradient-langevin-dynamics">Verbindung zu Stochastic Gradient Langevin Dynamics</a></li>
  <li><a href="#reverse-diffusion-process" id="markdown-toc-reverse-diffusion-process">Reverse Diffusion Process</a></li>
  <li><a href="#credits" id="markdown-toc-credits">Credits</a></li>
</ul>

<h2 id="einführung">Einführung</h2>

<p>GANs, VAEs and normalising flows are three types of machine learning models for generative purposes. All three have been very successful at generating high-quality samples, but each of the three families has its own problems. GANs are known for unstable training and for reduced diversity in the generated samples as a result of how they are trained. VAEs rely on a so-called “surrogate loss”. Normalising flows have to use specialised architectures to construct reversible transformations.</p>

<p>Diffusion models are inspired by non-equilibrium thermodynamics. They define a Markov chain of diffusion steps that slowly adds random noise to the data, and then learn to reverse the diffusion process in order to construct desired data samples from the noise. In contrast to VAEs or normalising flows, diffusion models are learned with a fixed procedure, and the latent variable has high dimensionality (the same as the original data).</p>

<p><img src="/assets/img/blog/diffusion_models/gen_model_overview.png" alt="gen_model_overview.png" /></p>

<h2 id="was-sind-diffusion-models">Was sind Diffusion Models</h2>

<p>Several diffusion-based generative models built on similar ideas have been proposed, including diffusion probabilistic models (<a href="https://arxiv.org/abs/1503.03585">Sohl-Dickstein et al., 2015</a>), noise-conditioned score networks (NCSN; <a href="https://arxiv.org/abs/1907.05600">Song &amp; Ermon, 2019</a>), and denoising diffusion probabilistic models (DDPM; <a href="https://arxiv.org/abs/2006.11239">Ho et al. 2020</a>).</p>

<h2 id="forward-diffusion-process">Forward Diffusion Process</h2>

<p>Suppose we have a data point drawn from a real data distribution, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mn>0</mn></msub><mo>∼</mo><mi>q</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">x_0 \sim q(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span>. 
We can then define a <em>forward diffusion process</em> in which we add small amounts of Gaussian noise to the data point in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi></mrow><annotation encoding="application/x-tex">T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span></span></span></span> steps and thereby produce a sequence <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mn>1</mn></msub><mo separator="true">,</mo><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mo separator="true">,</mo><msub><mi>x</mi><mi>T</mi></msub></mrow><annotation encoding="application/x-tex">x_1, ..., x_T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> of corrupted (so-called <em>noised</em>) data points. We control the step size between these data points with the so-called <em>variance schedule</em> <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">{</mo><msub><mi>β</mi><mi>t</mi></msub><mo>∈</mo><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mn>1</mn><mo stretchy="false">)</mo><msubsup><mo stretchy="false">}</mo><mrow><mi>t</mi><mo>=</mo><mn>1</mn></mrow><mi>T</mi></msubsup></mrow><annotation encoding="application/x-tex">\{\beta_t \in (0,1)\}^T_{t=1}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">{</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" 
style="margin-right:0.2778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0913em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">1</span><span class="mclose">)</span><span class="mclose"><span class="mclose">}</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413em;"><span style="top:-2.4519em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span></span></span></span>.</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right" columnspacing=""><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mo separator="true">;</mo><msqrt><mrow><mn>1</mn><mo>−</mo><msub><mi>β</mi><mi>t</mi></msub></mrow></msqrt><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo separator="true">,</mo><msub><mi>β</mi><mi>t</mi></msub><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo><mspace width="1.0037em"/><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mn>1</mn><mo>:</mo><mi>T</mi></mrow></msub><mo>=</mo><munderover><mo>∏</mo><mrow><mi>t</mi><mo>=</mo><mn>1</mn></mrow><mi>T</mi></munderover><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><msub><mi>t</mi><mn>1</mn></msub></msub><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{aligned}
    q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t \textbf{I}) \hspace{10px} q(x_{1:T} = \prod^T_{t=1} q(x_t | x_{t_1}))
\end{aligned}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:3.3954em;vertical-align:-1.4477em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.9477em;"><span style="top:-3.9477em;"><span class="pstrut" style="height:3.8283em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord mathcal" 
style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9839em;"><span class="svg-align" style="top:-3.2em;"><span class="pstrut" style="height:3.2em;"></span><span class="mord" style="padding-left:1em;"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.9439em;"><span class="pstrut" style="height:3.2em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.28em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.28em' viewBox='0 0 400000 1296' 
preserveAspectRatio='xMinYMin slice'><path d='M263,681c0.7,0,18,39.7,52,119
c34,79.3,68.167,158.7,102.5,238c34.3,79.3,51.8,119.3,52.5,120
c340,-704.7,510.7,-1060.3,512,-1067
l0 -0
c4.7,-7.3,11,-11,19,-11
H40000v40H1012.3
s-271.3,567,-271.3,567c-38.7,80.7,-84,175,-136,283c-52,108,-89.167,185.3,-111.5,232
c-22.3,46.7,-33.8,70.3,-34.5,71c-4.7,4.7,-12.3,7,-23,7s-12,-1,-12,-1
s-109,-253,-109,-253c-72.7,-168,-109.3,-252,-110,-252c-10.7,8,-22,16.7,-34,26
c-22,17.3,-33.3,26,-34,26s-26,-26,-26,-26s76,-59,76,-59s76,-60,76,-60z
M1001 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2561em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:1.0037em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span><span class="mrel mtight">:</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8283em;"><span style="top:-1.8829em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∏</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.2671em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.4477em;"><span></span></span></span></span></span></span></span></span></span></span></span>

<p>Our data point <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mn>0</mn></msub></mrow><annotation encoding="application/x-tex">x_0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> thereby gradually loses its recognizable features as <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex">t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span> grows. 
As <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>→</mo><mi mathvariant="normal">∞</mi></mrow><annotation encoding="application/x-tex">T \to \infty</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord">∞</span></span></span></span>, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mi>T</mi></msub></mrow><annotation encoding="application/x-tex">x_T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> becomes equivalent to an isotropic Gaussian distribution.</p>

<p><img src="/assets/img/blog/diffusion_models/diffusion_process.png" alt="diffusion_process" /></p>

<p class="figcaption">Fig. 2. The Markov chain of the <em>forward (reverse) diffusion process</em>, in which a sample is generated by slowly adding (removing) noise. Source: <a href="https://arxiv.org/abs/2006.11239">Ho et al. 2020</a>, with a few additional annotations.</p>
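
<p>The forward process described above can be simulated directly. The following is a minimal sketch in plain Python, not code taken from Ho et al.: the function names (<code>linear_beta_schedule</code>, <code>forward_step</code>, <code>forward_diffusion</code>) and the toy input are assumptions made for illustration, and the linear schedule from 1e-4 to 0.02 over 1000 steps is only one common choice. It applies the Gaussian transition step by step and then inspects the mean and variance of the final state.</p>

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced variances beta_1..beta_T (illustrative default values)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def forward_step(x_prev, beta_t, rng):
    """Draw x_t from N(sqrt(1 - beta_t) * x_prev, beta_t * I), one coordinate at a time."""
    scale = math.sqrt(1.0 - beta_t)
    sigma = math.sqrt(beta_t)
    return [scale * x + sigma * rng.gauss(0.0, 1.0) for x in x_prev]


def forward_diffusion(x0, betas, rng):
    """Run the whole Markov chain and return the final state x_T."""
    x = list(x0)
    for beta_t in betas:
        x = forward_step(x, beta_t, rng)
    return x


rng = random.Random(0)
betas = linear_beta_schedule(1000)
x0 = [5.0] * 1000  # toy "data point": every coordinate equals 5, so no spread at all
xT = forward_diffusion(x0, betas, rng)

mean = sum(xT) / len(xT)
var = sum((v - mean) ** 2 for v in xT) / len(xT)
print(f"mean={mean:.3f}, var={var:.3f}")  # expect roughly mean 0 and variance 1
```

<p>Although the toy input starts far from the origin and has no spread, the empirical mean and variance of the final state end up near 0 and 1, which is the behaviour the limit statement above describes.</p>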

<p>A useful property of this process is that we can sample <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">x_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> at an arbitrary time step <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex">t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span> in closed form, using the <a href="https://lilianweng.github.io/posts/2018-08-12-vae/#reparameterization-trick">reparameterization trick</a>. 
Let <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>α</mi><mi>t</mi></msub><mo>=</mo><mn>1</mn><mo>−</mo><msub><mi>β</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\alpha_t = 1 - \beta_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal 
mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover><mo>=</mo><msubsup><mo>∏</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>t</mi></msubsup><msub><mi>α</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">\overline{\alpha_t} = \prod^t_{i=1} \alpha_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7806em;vertical-align:-0.15em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span 
class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.2332em;vertical-align:-0.2997em;"></span><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:0em;">∏</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9335em;"><span style="top:-2.4003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left" columnspacing="0em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><msub><mi>x</mi><mi>t</mi></msub></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><msqrt><msub><mi>α</mi><mi>t</mi></msub></msqrt><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>+</mo><msqrt><mrow><mn>1</mn><mo>−</mo><msub><mi>α</mi><mi>t</mi></msub></mrow></msqrt><msub><mi>ϵ</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo separator="true">;</mo><mspace width="1.0037em"/><msub><mi>ϵ</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo separator="true">,</mo><msub><mi>ϵ</mi><mrow><mi>t</mi><mo>−</mo><mn>2</mn></mrow></msub><mo separator="true">,</mo><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mo>∼</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><msqrt><mrow><msub><mi>α</mi><mi>t</mi></msub><msub><mi>α</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msqrt><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>2</mn></mrow></msub><mo>+</mo><msqrt><mrow><mn>1</mn><mo>−</mo><msub><mi>α</mi><mi>t</mi></msub><msub><mi>α</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msqrt><mover accent="true"><msub><mi>ϵ</mi><mrow><mi>t</mi><mo>−</mo><mn>2</mn></mrow></msub><mo stretchy="true">‾</mo></mover><mo stretchy="false">(</mo><mo>∗</mo><mo stretchy="false">)</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" 
displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo><msqrt><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover></msqrt><msub><mi>x</mi><mn>0</mn></msub><mo>+</mo><msqrt><mrow><mn>1</mn><mo>−</mo><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover></mrow></msqrt><mi>ϵ</mi></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mn>0</mn></msub><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mo separator="true">;</mo><msqrt><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover></msqrt><msub><mi>x</mi><mn>0</mn></msub><mo separator="true">,</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover><mo stretchy="false">)</mo><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{aligned}%!!15
    x_t &amp;= \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t} \epsilon_{t-1}; \hspace{10px} \epsilon_{t-1}, \epsilon_{t-2}, ... \sim \mathcal{N}(0,\textbf{I}) \\[1em]
        &amp;= \sqrt{\alpha_t \alpha_{t-1}}x_{t-2}  + \sqrt{1-\alpha_t \alpha_{t-1}} \overline{\epsilon_{t-2}} (*) \\[1em]
        &amp;= ... \\[1em]
        &amp;= \sqrt{\overline{\alpha_t}}x_0 + \sqrt{1-\overline{\alpha_t}}\epsilon \\[1em]

    q(x_t | x_{0}) = \mathcal{N}(x_t; \sqrt{\overline{\alpha_t}} x_{0}, (1-\overline{\alpha_t}) \textbf{I})
\end{aligned}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:11.7283em;vertical-align:-5.6141em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:6.1141em;"><span style="top:-8.233em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-5.6211em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span><span style="top:-3.1211em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span><span style="top:-0.58em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span><span style="top:1.9541em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8742em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-2.8342em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1658em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" 
style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mclose">)</span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:5.6141em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:6.1141em;"><span style="top:-8.233em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7742em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.7342em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' 
preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2658em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8811em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8411em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1589em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:1.0037em;"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">...</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span></span></span><span style="top:-5.6211em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.745em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.705em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.295em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9519em;"><span class="svg-align" style="top:-3.2em;"><span class="pstrut" style="height:3.2em;"></span><span class="mord" style="padding-left:1em;"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.9119em;"><span class="pstrut" style="height:3.2em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.28em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.28em' viewBox='0 0 400000 1296' preserveAspectRatio='xMinYMin slice'><path d='M263,681c0.7,0,18,39.7,52,119
c34,79.3,68.167,158.7,102.5,238c34.3,79.3,51.8,119.3,52.5,120
c340,-704.7,510.7,-1060.3,512,-1067
l0 -0
c4.7,-7.3,11,-11,19,-11
H40000v40H1012.3
s-271.3,567,-271.3,567c-38.7,80.7,-84,175,-136,283c-52,108,-89.167,185.3,-111.5,232
c-22.3,46.7,-33.8,70.3,-34.5,71c-4.7,4.7,-12.3,7,-23,7s-12,-1,-12,-1
s-109,-253,-109,-253c-72.7,-168,-109.3,-252,-110,-252c-10.7,8,-22,16.7,-34,26
c-22,17.3,-33.3,26,-34,26s-26,-26,-26,-26s76,-59,76,-59s76,-60,76,-60z
M1001 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2881em;"><span></span></span></span></span></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span><span class="mopen">(</span><span class="mord">∗</span><span class="mclose">)</span></span></span><span style="top:-3.1211em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord">...</span></span></span><span style="top:-0.58em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span 
class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8742em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-2.8342em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1658em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8811em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-2.8411em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1589em;"><span></span></span></span></span></span><span class="mord mathnormal">ϵ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:4.08em;"><span></span></span></span></span></span></span></span></span></span></span></span>

<p>(*) When we add two independent zero-mean Gaussians with different variances, the result is again Gaussian, and its variance is the sum of the two variances: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><msubsup><mi>σ</mi><mn>1</mn><mn>2</mn></msubsup><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo><mo>+</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><msubsup><mi>σ</mi><mn>2</mn><mn>2</mn></msubsup><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mo stretchy="false">(</mo><msubsup><mi>σ</mi><mn>1</mn><mn>2</mn></msubsup><mo>+</mo><msubsup><mi>σ</mi><mn>2</mn><mn>2</mn></msubsup><mo stretchy="false">)</mo><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mathcal{N}(0, \sigma^2_1\textbf{I}) + \mathcal{N}(0, \sigma^2_2\textbf{I}) = \mathcal{N}(0, (\sigma^2_1 + \sigma^2_2)\textbf{I})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0641em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-2.4519em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord 
mtight">1</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.0641em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-2.4519em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" 
style="height:1.0641em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-2.4519em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.0641em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8141em;"><span style="top:-2.4519em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.2481em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span></span></span></span>. In our case, the combined standard deviation is <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msqrt><mrow><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="false">)</mo><mo>+</mo><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><msub><mi>α</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo stretchy="false">)</mo></mrow></msqrt><mo>=</mo><msqrt><mrow><mn>1</mn><mo>−</mo><msub><mi>α</mi><mi>t</mi></msub><msub><mi>α</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msqrt></mrow><annotation encoding="application/x-tex">\sqrt{(1-\alpha_t) + \alpha_t(1-\alpha_{t-1})} = \sqrt{1-\alpha_t \alpha_{t-1}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.24em;vertical-align:-0.305em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.935em;"><span class="svg-align" style="top:-3.2em;"><span class="pstrut" style="height:3.2em;"></span><span class="mord" style="padding-left:1em;"><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-2.895em;"><span 
class="pstrut" style="height:3.2em;"></span><span class="hide-tail" style="min-width:1.02em;height:1.28em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.28em' viewBox='0 0 400000 1296' preserveAspectRatio='xMinYMin slice'><path d='M263,681c0.7,0,18,39.7,52,119
c34,79.3,68.167,158.7,102.5,238c34.3,79.3,51.8,119.3,52.5,120
c340,-704.7,510.7,-1060.3,512,-1067
l0 -0
c4.7,-7.3,11,-11,19,-11
H40000v40H1012.3
s-271.3,567,-271.3,567c-38.7,80.7,-84,175,-136,283c-52,108,-89.167,185.3,-111.5,232
c-22.3,46.7,-33.8,70.3,-34.5,71c-4.7,4.7,-12.3,7,-23,7s-12,-1,-12,-1
s-109,-253,-109,-253c-72.7,-168,-109.3,-252,-110,-252c-10.7,8,-22,16.7,-34,26
c-22,17.3,-33.3,26,-34,26s-26,-26,-26,-26s76,-59,76,-59s76,-60,76,-60z
M1001 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.305em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.04em;vertical-align:-0.2369em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8031em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord 
mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.7631em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2369em;"><span></span></span></span></span></span></span></span></span>.</p>
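<p>The merging step in (*) can be checked numerically. The following is a minimal NumPy sketch, not part of the original derivation: the function names and the two example values of α are illustrative assumptions. It draws the two independent noise terms, scales them as in the two-step expansion, and compares the empirical standard deviation of their sum with the closed form.</p>

```python
import numpy as np


def merged_noise_std(alpha_t, alpha_prev):
    """Closed-form std of the merged noise: sqrt(1 - alpha_t * alpha_{t-1})."""
    return np.sqrt(1.0 - alpha_t * alpha_prev)


def empirical_noise_std(alpha_t, alpha_prev, n_samples=1_000_000, seed=0):
    """Monte Carlo estimate of the std of the two-step noise term."""
    rng = np.random.default_rng(seed)
    eps_older = rng.standard_normal(n_samples)  # plays the role of epsilon_{t-2}
    eps_newer = rng.standard_normal(n_samples)  # plays the role of epsilon_{t-1}
    # Noise carried over from the previous step, plus the noise of the current step.
    noise = (
        np.sqrt(alpha_t * (1.0 - alpha_prev)) * eps_older
        + np.sqrt(1.0 - alpha_t) * eps_newer
    )
    return noise.std()


# Hypothetical example values for alpha_t and alpha_{t-1} (not from the text).
alpha_t, alpha_prev = 0.9, 0.95
print("closed form:", merged_noise_std(alpha_t, alpha_prev))
print("empirical:  ", empirical_noise_std(alpha_t, alpha_prev))
```

<p>The two printed values should agree up to sampling error, which shrinks as <code>n_samples</code> grows.</p>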

<p>We can usually afford larger update steps when the sample is noisier, so we choose the variance schedule such that <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>β</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\beta_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> grows with <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex">t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span>: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>β</mi><mn>1</mn></msub><mo>&lt;</mo><msub><mi>β</mi><mn>2</mn></msub><mo>&lt;</mo><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mo>&lt;</mo><msub><mi>β</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\beta_1 &lt; \beta_2 &lt; ... 
&lt; \beta_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&lt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&lt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.5782em;vertical-align:-0.0391em;"></span><span class="mord">...</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&lt;</span><span class="mspace" 
style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and therefore <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>α</mi><mn>1</mn></msub><mo stretchy="true">‾</mo></mover><mo>&gt;</mo><mover accent="true"><msub><mi>α</mi><mn>2</mn></msub><mo stretchy="true">‾</mo></mover><mo>&gt;</mo><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mo>&gt;</mo><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover></mrow><annotation encoding="application/x-tex">\overline{\alpha_1} &gt; \overline{\alpha_2} &gt; ... 
&gt; \overline{\alpha_t}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7806em;vertical-align:-0.15em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&gt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7806em;vertical-align:-0.15em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span 
style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&gt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.5782em;vertical-align:-0.0391em;"></span><span class="mord">...</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">&gt;</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.7806em;vertical-align:-0.15em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span>.</p>
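<p>The ordering of the variances and of the cumulative products can be checked numerically. The following is a minimal sketch, assuming the common definition in which each cumulative product is the running product of <code>1 - beta_s</code>; the linear schedule and its endpoints (<code>1e-4</code>, <code>0.02</code>) are illustrative choices, not values taken from this post.</p>

```python
import numpy as np

def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly increasing variances beta_1, ..., beta_T (endpoints are illustrative)."""
    return np.linspace(beta_start, beta_end, num_steps)

def alpha_bar(betas):
    """Running product of alpha_s = 1 - beta_s (assumed definition of alpha-bar)."""
    return np.cumprod(1.0 - betas)

betas = linear_beta_schedule(1000)
alpha_bars = alpha_bar(betas)

# Every factor 1 - beta_s lies strictly between 0 and 1,
# so the running product shrinks at every step.
print("betas increasing:     ", bool(np.all(np.diff(betas) > 0)))
print("alpha-bars decreasing:", bool(np.all(np.diff(alpha_bars) < 0)))
```

<p>Because every factor is below one, the cumulative product decreases for any schedule with variances strictly between 0 and 1; the increasing schedule controls how quickly it decays.</p>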

<h2 id="verbindung-zu-stochastic-gradient-langevin-dynamics">Connection to Stochastic Gradient Langevin Dynamics</h2>

<p>Langevin dynamics is a concept from physics that was developed for the statistical modelling of molecular systems. Combining it with stochastic gradient descent yields <em>stochastic gradient Langevin dynamics</em> (<a href="https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf">Welling &amp; Teh 2011</a>). This method can draw samples from a probability distribution <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> and needs only the gradients of the log-probability <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi mathvariant="normal">∇</mi><mi>x</mi></msub><mi>log</mi><mo>⁡</mo><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\nabla_x \log p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord">∇</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">x</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> to do so. Each gradient is combined with a noise term to produce the next sample; the gradient is then evaluated at that new sample, and the process repeats. This iterative process can be described as a Markov chain of updates:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>x</mi><mi>t</mi></msub><mo>=</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>+</mo><mfrac><mi>δ</mi><mn>2</mn></mfrac><msub><mi mathvariant="normal">∇</mi><mi>x</mi></msub><mi>log</mi><mo>⁡</mo><mi>p</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo stretchy="false">)</mo><mo>+</mo><msqrt><mi>δ</mi></msqrt><msub><mi>ϵ</mi><mi>t</mi></msub><mo separator="true">;</mo><mspace width="1.0037em"/><msub><mi>ϵ</mi><mi>t</mi></msub><mo>∼</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">x_t = x_{t-1} + \frac{\delta}{2} \nabla_x \log p(x_{t-1}) + \sqrt{\delta} \epsilon_t; \hspace{10px} \epsilon_t \sim \mathcal{N}(0, \textbf{I})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" 
style="height:0.7917em;vertical-align:-0.2083em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:2.0574em;vertical-align:-0.686em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3714em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">2</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="mord">∇</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">x</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.1755em;vertical-align:-0.1944em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9811em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span></span></span><span style="top:-2.9411em;"><span class="pstrut" style="height:3em;"></span><span 
class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.0589em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:1.0037em;"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord text"><span class="mord textbf">I</span></span><span 
class="mclose">)</span></span></span></span></span>

<p>where <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>δ</mi></mrow><annotation encoding="application/x-tex">\delta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span></span></span></span> is the step size of the updates. As <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>→</mo><mi mathvariant="normal">∞</mi></mrow><annotation encoding="application/x-tex">T \to \infty</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord">∞</span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>δ</mi><mo>→</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\delta \to 0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0</span></span></span></span>, the samples are distributed according to the true probability distribution <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span>.</p>

<p>Compared with standard gradient descent methods, which use only the gradients of the log-probability, we add a noise term here. This keeps the iterates from collapsing into local optima of the probability distribution.</p>
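<p>The update rule and the role of the noise term can be illustrated on a toy target. The following is a minimal sketch, assuming a one-dimensional Gaussian target whose score is known in closed form; the function names, the target parameters and the step size are hypothetical choices for illustration only.</p>

```python
import numpy as np

def score_gaussian(x, mean=2.0, std=0.5):
    """Score of the toy target N(mean, std^2): d/dx log p(x) = -(x - mean) / std^2."""
    return -(x - mean) / std**2

def langevin_sample(score, x0, step_size=1e-3, num_steps=5000, add_noise=True, seed=0):
    """Iterate x_t = x_{t-1} + (delta / 2) * score(x_{t-1}) + sqrt(delta) * eps_t."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(num_steps):
        x = x + 0.5 * step_size * score(x)
        if add_noise:
            x = x + np.sqrt(step_size) * rng.standard_normal(x.shape)
    return x

# 5000 independent chains, all started at x = 0.
samples = langevin_sample(score_gaussian, x0=np.zeros(5000))
# Same update without the noise term: plain gradient ascent on log p(x).
ascent = langevin_sample(score_gaussian, x0=np.zeros(5000), add_noise=False)

print("Langevin mean/std:       ", samples.mean(), samples.std())
print("gradient ascent mean/std:", ascent.mean(), ascent.std())
```

<p>With noise, the chains spread out to match the mean and standard deviation of the target; without it, every chain ends up at the same point, the mode of the distribution.</p>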

<h2 id="reverse-diffusion-process">Reverse Diffusion Process</h2>

<p>If we could reverse the <em>forward diffusion process</em> described above and thus draw samples from <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">q(x_{t-1} \vert x_{t})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span>, we could turn Gaussian noise <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mi>T</mi></msub><mo>∼</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">x_T \sim \mathcal{N}(0, \textbf{I})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span></span></span></span> into samples from <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span>. This process is called the <em>reverse diffusion process</em>.</p>

<p>If <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>β</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">\beta_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> is small enough, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">q(x_{t-1} \vert x_{t})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> will also follow a normal distribution.</p>
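<p>There is one special case in which the reverse conditional can be written down exactly: if the data itself is Gaussian, every marginal of the forward process is Gaussian and Bayes' rule gives the reverse step in closed form. The following is a minimal sketch of the reverse chain for that toy case, assuming the usual forward kernel with mean <code>sqrt(1 - beta_t) * x_{t-1}</code> and variance <code>beta_t</code>; the function name, the data parameters and the schedule are hypothetical choices for illustration.</p>

```python
import numpy as np

def reverse_diffusion_gaussian(data_mean, data_std, betas, num_samples, seed=0):
    """Exact reverse chain for 1-D Gaussian data x_0 ~ N(data_mean, data_std^2).

    Assumes the forward kernel q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) * x_{t-1}, beta_t).
    """
    rng = np.random.default_rng(seed)
    num_steps = len(betas)
    # Marginals q(x_t) = N(mean[t], var[t]); index 0 is the data distribution itself.
    alpha_bar = np.concatenate(([1.0], np.cumprod(1.0 - betas)))
    mean = np.sqrt(alpha_bar) * data_mean
    var = alpha_bar * data_std**2 + (1.0 - alpha_bar)

    x = rng.standard_normal(num_samples)  # start from pure noise, x_T ~ N(0, 1)
    for t in range(num_steps, 0, -1):
        beta = betas[t - 1]
        # Bayes' rule for Gaussians: prior q(x_{t-1}) times likelihood q(x_t | x_{t-1}).
        precision = 1.0 / var[t - 1] + (1.0 - beta) / beta
        post_mean = (mean[t - 1] / var[t - 1] + np.sqrt(1.0 - beta) * x / beta) / precision
        x = post_mean + np.sqrt(1.0 / precision) * rng.standard_normal(num_samples)
    return x

betas = np.linspace(1e-4, 0.02, 1000)  # illustrative schedule, not taken from this post
x0 = reverse_diffusion_gaussian(data_mean=3.0, data_std=0.5, betas=betas, num_samples=20000)
print("recovered mean/std:", x0.mean(), x0.std())
```

<p>Starting from standard normal noise, the chain ends with samples whose mean and standard deviation match those of the data distribution. The closed form relies entirely on knowing that distribution, which is exactly what is missing for real data.</p>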

<p>Unfortunately, computing <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">q(x_{t-1} \vert x_{t})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> would require knowing the entire data distribution <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span>, which is not possible in practice. We can, however, learn a model <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>p</mi><mi>θ</mi></msub></mrow><annotation encoding="application/x-tex">p_{\theta}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> that approximates these conditional probabilities. With this model we can then run the <em>reverse diffusion process</em> and draw approximate samples from <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span>:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mn>0</mn><mo>:</mo><mi>T</mi></mrow></msub><mo stretchy="false">)</mo><mo>=</mo><mi>p</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>T</mi></msub><mo stretchy="false">)</mo><munderover><mo>∏</mo><mrow><mi>t</mi><mo>=</mo><mn>1</mn></mrow><mi>T</mi></munderover><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo><mspace width="1.0037em"/><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo separator="true">;</mo><msub><mi>μ</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mo separator="true">,</mo><mi>t</mi><mo stretchy="false">)</mo><mo separator="true">,</mo><msub><mi mathvariant="normal">Σ</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mo separator="true">,</mo><mi>t</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p_{\theta}(x_{0:T}) = p(x_T) \prod^T_{t=1} p_{\theta}(x_{t-1} | x_{t}) \hspace{10px} p_{\theta}(x_{t-1} | x_{t}) = \mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \Sigma_{\theta}(x_t,t))</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">p</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">0</span><span class="mrel mtight">:</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:3.0954em;vertical-align:-1.2671em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8283em;"><span style="top:-1.8829em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∏</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.2671em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord 
mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:1.0037em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord 
mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal 
mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">t</span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord">Σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal">t</span><span class="mclose">))</span></span></span></span></span>
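
<p>The factorisation above can be read as a sampling loop: start from noise and repeatedly draw from the learned Gaussian. The following is a minimal sketch, not a reference implementation. It assumes that the prior is a standard normal and that the covariance is diagonal; <code class="language-plaintext highlighter-rouge">reverse_sample</code>, <code class="language-plaintext highlighter-rouge">toy_mu</code> and <code class="language-plaintext highlighter-rouge">toy_sigma2</code> are hypothetical names, and the two toy functions merely stand in for a trained network.</p>

```python
import numpy as np


def reverse_sample(mu_theta, sigma2_theta, shape, T, rng):
    """Ancestral sampling: draw x_T, then x_{t-1} ~ N(mu_theta(x_t, t), Sigma_theta(x_t, t))."""
    # Assumption: the prior p(x_T) is a standard normal N(0, I).
    x = rng.standard_normal(shape)
    for t in range(T, 0, -1):
        mean = mu_theta(x, t)
        var = sigma2_theta(x, t)  # assumption: diagonal covariance, given as a variance per entry
        x = mean + np.sqrt(var) * rng.standard_normal(shape)
    return x  # approximate sample x_0


def toy_mu(x, t):
    """Hypothetical stand-in for a trained mean network: shrinks x towards zero."""
    return 0.9 * x


def toy_sigma2(x, t):
    """Hypothetical stand-in for a trained variance network: constant small variance."""
    return np.full_like(x, 0.01)


rng = np.random.default_rng(0)
samples = reverse_sample(toy_mu, toy_sigma2, shape=(1000, 2), T=50, rng=rng)
```

<p>In a real model, the two toy functions would be replaced by the outputs of the trained network; the loop itself stays the same.</p>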

<p><img src="/assets/img/blog/diffusion_models/diffusion-example.png" alt="Reverse Diffusion Process" /></p>

<p class="figcaption">Fig. 3. Example of training a diffusion model to model 2D swiss roll data. (Source: <a href="https://arxiv.org/abs/1503.03585">Sohl-Dickstein et al. 2015</a>)</p>

<p>It is noteworthy that the reverse conditional probability can be computed when it is conditioned on <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mn>0</mn></msub></mrow><annotation encoding="application/x-tex">x_0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∥</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo separator="true">;</mo><mstyle mathcolor="blue"><mover accent="true"><mi>μ</mi><mo stretchy="true">~</mo></mover><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mo separator="true">,</mo><mi>t</mi><mo stretchy="false">)</mo><mo separator="true">,</mo><mstyle mathcolor="red"><msub><mover accent="true"><mi>β</mi><mo>~</mo></mover><mi>t</mi></msub><mi mathvariant="bold">I</mi><mo stretchy="false">)</mo></mstyle></mstyle></mrow><annotation encoding="application/x-tex">q(x_{t-1} \| x_{t}) = \mathcal{N}(x_{t-1}; \color{blue} \widetilde{\mu}(x_t, t), \color{red} \tilde{\beta}_t \mathbf{I})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mord">∥</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.1813em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord accent" style="color:blue;"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6906em;"><span style="top:-3em;"><span class="pstrut" 
style="height:3em;"></span><span class="mord mathnormal" style="color:blue;">μ</span></span><span class="svg-align" style="width:calc(100% - 0.0556em);margin-left:0.0556em;top:-3.4306em;"><span class="pstrut" style="height:3em;"></span><span style="color:blue;height:0.26em;"><svg xmlns="http://www.w3.org/2000/svg" width='100%' height='0.26em' viewBox='0 0 600 260' preserveAspectRatio='none'><path d='M200 55.538c-77 0-168 73.953-177 73.953-3 0-7
-2.175-9-5.437L2 97c-1-2-2-4-2-6 0-4 2-7 5-9l20-12C116 12 171 0 207 0c86 0
 114 68 191 68 78 0 168-68 177-68 4 0 7 2 9 5l12 19c1 2.175 2 4.35 2 6.525 0
 4.35-2 7.613-5 9.788l-19 13.05c-92 63.077-116.937 75.308-183 76.128
-68.267.847-113-73.952-191-73.952z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1944em;"><span></span></span></span></span></span><span class="mopen" style="color:blue;">(</span><span class="mord" style="color:blue;"><span class="mord mathnormal" style="color:blue;">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style="color:blue;"><span class="mord mathnormal mtight" style="color:blue;">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct" style="color:blue;">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="color:blue;">t</span><span class="mclose" style="color:blue;">)</span><span class="mpunct" style="color:blue;">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord" style="color:red;"><span class="mord accent" style="color:red;"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9313em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;color:red;">β</span></span><span style="top:-3.6134em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.1667em;"><span class="mord" style="color:red;">~</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1944em;"><span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style="color:red;"><span class="mord mathnormal mtight" style="color:red;">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord mathbf" style="color:red;">I</span><span class="mclose" style="color:red;">)</span></span></span></span></span>

<p>Using Bayes' rule, we obtain the following:</p>

<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left" columnspacing="0em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><msub><mi>x</mi><mi>t</mi></msub></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mo>=</mo></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mn>0</mn></msub><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mo separator="true">;</mo><msqrt><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover></msqrt><msub><mi>x</mi><mn>0</mn></msub><mo separator="true">,</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mover accent="true"><msub><mi>α</mi><mi>t</mi></msub><mo stretchy="true">‾</mo></mover><mo stretchy="false">)</mo><mtext mathvariant="bold">I</mtext><mo stretchy="false">)</mo></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{aligned}%!!15
    x_t &amp;=  \\[1em]
        &amp;=  \\[1em]
        &amp;=  \\[1em]
        &amp;=  \\[1em]

    q(x_t \vert x_{0}) = \mathcal{N}(x_t; \sqrt{\overline{\alpha_t}} x_{0}, (1-\overline{\alpha_t}) \textbf{I})
\end{aligned}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:11.5342em;vertical-align:-5.5171em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:6.0171em;"><span style="top:-8.1771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-5.6771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span><span style="top:-3.1771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span><span style="top:-0.6771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"></span></span><span style="top:1.8571em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8742em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-2.8342em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1658em;"><span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">0</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord overline"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6306em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0037em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.5506em;"><span class="pstrut" style="height:3em;"></span><span class="overline-line" 
style="border-bottom-width:0.04em;"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mclose">)</span><span class="mord text"><span class="mord textbf">I</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:5.5171em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:6.0171em;"><span style="top:-8.1771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span></span></span><span style="top:-5.6771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span></span></span><span style="top:-3.1771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span></span></span><span style="top:-0.6771em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:3.9829em;"><span></span></span></span></span></span></span></span></span></span></span></span>
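
<p>The closed-form marginal in the last line of the block above means that a noised sample at any step can be drawn directly from the original data point, without simulating the intermediate steps. Here is a minimal sketch under stated assumptions: the linear schedule for the betas and the helper name <code class="language-plaintext highlighter-rouge">q_sample</code> are illustrative choices, and the cumulative product follows the usual definition of alpha-bar as the product of (1 - beta) up to step t.</p>

```python
import numpy as np


def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps


T = 1000
# Assumption: a linear variance schedule; index 0 corresponds to diffusion step 1.
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # alpha_bar_t = product of (1 - beta_s) for s <= t

rng = np.random.default_rng(0)
x0 = np.ones((10000, 2))  # toy "data": every point sits at (1, 1)

x_mid = q_sample(x0, 99, alpha_bar, rng)     # still close to the data
x_end = q_sample(x0, T - 1, alpha_bar, rng)  # almost pure Gaussian noise
```

<p>With this schedule, the sample mean of <code class="language-plaintext highlighter-rouge">x_mid</code> stays close to the data, while <code class="language-plaintext highlighter-rouge">x_end</code> is close to standard Gaussian noise.</p>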

<h2 id="credits">Credits</h2>

<p>Many thanks to Lilian Weng for her great blog posts and for the opportunity to translate this one!</p>

<p><span>Photo by <a href="https://unsplash.com/@jjying?utm_source=unsplash&amp;utm_medium=referral&amp;utm_content=creditCopyText">JJ Ying</a> on <a href="https://unsplash.com/?utm_source=unsplash&amp;utm_medium=referral&amp;utm_content=creditCopyText">Unsplash</a></span></p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="ml" /><summary type="html"><![CDATA[What is behind the hype]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/diffusion_models/diffusion_models_cover.jpg" /><media:content medium="image" url="/assets/img/blog/diffusion_models/diffusion_models_cover.jpg" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Annoying Errors and how to fix them</title><link href="/blog/programming/2023-03-23-troubleshoot/" rel="alternate" type="text/html" title="Annoying Errors and how to fix them" /><published>2023-03-23T00:00:00+00:00</published><updated>2024-04-07T17:41:32+00:00</updated><id>/blog/programming/troubleshoot</id><content type="html" xml:base="/blog/programming/2023-03-23-troubleshoot/"><![CDATA[<p>We all know how annoying fixing bugs can be. One of the few things more annoying, though, is facing an error you have encountered before but no longer remember the solution to. Therefore, this post is a running log of errors I have encountered at least twice, so that I do not have to repeat the bug-fixing process from scratch.</p>

<ul id="markdown-toc">
  <li><a href="#cuda-error-failed-to-initialize-nvml-driverlibrary-version-mismatch" id="markdown-toc-cuda-error-failed-to-initialize-nvml-driverlibrary-version-mismatch">CUDA Error: <code class="language-plaintext highlighter-rouge">Failed to initialize NVML: Driver/library version mismatch</code></a></li>
  <li><a href="#pdb-debugging-error-if-selfquitting-raise-bdbquit-bdbbdbquit" id="markdown-toc-pdb-debugging-error-if-selfquitting-raise-bdbquit-bdbbdbquit">pdb debugging error: <code class="language-plaintext highlighter-rouge">if self.quitting: raise BdbQuit (bdb.BdbQuit)</code></a></li>
</ul>

<h2 id="cuda-error-failed-to-initialize-nvml-driverlibrary-version-mismatch">CUDA Error: <code class="language-plaintext highlighter-rouge">Failed to initialize NVML: Driver/library version mismatch</code></h2>

<p>I encountered this one after having some PyTorch/CUDA errors, trying to reinstall some GPU drivers and failing miserably. Fortunately, after some digging I found <a href="https://forums.developer.nvidia.com/t/failed-to-initialize-nvml-driver-library-version-mismatch/190421/2">this discussion on the NVIDIA forum</a> which cleared things up.</p>

<p>In summary, this error means that the NVIDIA kernel module that is currently loaded and the user-space driver libraries installed on your host machine have different versions. Sometimes the error message already contains the versions and you can rectify the mismatch based on that. If not, follow these steps:</p>

<ol>
  <li>Run <code class="language-plaintext highlighter-rouge">sudo nvidia-bug-report.sh</code>.</li>
  <li>Extract the bug report via <code class="language-plaintext highlighter-rouge">gzip -d nvidia-bug-report.log.gz</code>.</li>
  <li>Open the extracted <code class="language-plaintext highlighter-rouge">nvidia-bug-report.log</code> and search for “API mismatch” (ignoring case). Note down which version the client (i.e. the user-space driver library) reports.</li>
  <li>Run <code class="language-plaintext highlighter-rouge">sudo apt install nvidia-driver-470</code>, replacing <code class="language-plaintext highlighter-rouge">470</code> with the major version that the client reported in the bug report.</li>
</ol>
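
<p>Steps 3 and 4 above can be sketched in a few lines of shell. The sample line below is only modelled on the kind of message the log contains; its exact wording and the version numbers are assumptions made for illustration, so adapt the pattern to what your own <code class="language-plaintext highlighter-rouge">nvidia-bug-report.log</code> actually says.</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Illustrative sample line (assumed format); the real entry lives in nvidia-bug-report.log.
sample_line='NVRM: API mismatch: the client has the version 470.57.02, but this kernel module has the version 470.42.01.'

# Pull out the client version, i.e. the number after "client has the version".
client_version=$(printf '%s\n' "$sample_line" | sed -n 's/.*client has the version \([0-9.]*[0-9]\).*/\1/p')

# The apt package name only uses the major version (the part before the first dot).
major_version=${client_version%%.*}

echo "client version: $client_version"
echo "suggested command: sudo apt install nvidia-driver-$major_version"
</code></pre></div></div>

<p>On a real log you would feed the matching line in with something like <code class="language-plaintext highlighter-rouge">grep -i "API mismatch" nvidia-bug-report.log</code> instead of the hard-coded sample.</p>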

<h2 id="pdb-debugging-error-if-selfquitting-raise-bdbquit-bdbbdbquit">pdb debugging error: <code class="language-plaintext highlighter-rouge">if self.quitting: raise BdbQuit (bdb.BdbQuit)</code></h2>

<p>This one I got when I was debugging an ML program. I made a local editable install of my ML repo via <code class="language-plaintext highlighter-rouge">pip install -e .</code>, had a <code class="language-plaintext highlighter-rouge">breakpoint()</code> in my dataloader and ran my model with <code class="language-plaintext highlighter-rouge">WandB</code> logging.</p>

<p>What happened is that the dataloader was then executed in the background, where the debugger waited for a command such as continue or step. Since that command never arrived, it waited to no avail and finally quit with <code class="language-plaintext highlighter-rouge">BdbQuit</code>. After reading up on this <a href="https://stackoverflow.com/questions/34914704/bdbquit-raised-when-debugging-python-with-pdb">here</a>, I managed to fix it by just running the dataloader itself locally in debug mode.</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="programming" /><summary type="html"><![CDATA[Demystifying some of the errors out there for my future self]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/bugfixes/bugfixes.jpeg" /><media:content medium="image" url="/assets/img/blog/bugfixes/bugfixes.jpeg" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">De novo proteins and where to find them - RosettaCon 2022</title><link href="/blog/proteins/2022-11-17-rosettacon/" rel="alternate" type="text/html" title="De novo proteins and where to find them - RosettaCon 2022" /><published>2022-11-17T00:00:00+00:00</published><updated>2022-11-17T15:49:10+00:00</updated><id>/blog/proteins/rosettacon</id><content type="html" xml:base="/blog/proteins/2022-11-17-rosettacon/"><![CDATA[<p>There has been a lot happening recently in protein design, and it is easy to get lost in the daily flood of new papers and exciting ideas (for an overview of current models and approaches see <a href="https://www.biorxiv.org/content/10.1101/2022.08.31.505981v1">this review</a> which includes a great <a href="https://github.com/hefeda/design_tools">model table</a>).</p>

<p>In such situations, it often helps to step back for a second, look at the bigger picture and chat to people about what is happening and what the future might hold. Luckily, RosettaCon was happening this August and provided a venue for exactly that: chatting to people about protein design and fascinating ideas! In this post I want to highlight some presentations from the conference that I think are representative of some broader directions in the field.</p>
<ul id="markdown-toc">
  <li><a href="#scaffolding-protein-motifs-using-deep-learning---baker-lab-ipd-uw" id="markdown-toc-scaffolding-protein-motifs-using-deep-learning---baker-lab-ipd-uw">Scaffolding protein motifs using Deep Learning - (Baker Lab, IPD, UW)</a></li>
  <li><a href="#manifold-sampling-for-function-guided-antibody-design---vladimir-gligorijevic-prescient-design" id="markdown-toc-manifold-sampling-for-function-guided-antibody-design---vladimir-gligorijevic-prescient-design">Manifold Sampling for function-guided antibody design - Vladimir Gligorijevic (Prescient Design)</a></li>
  <li><a href="#designing-epitope-specific-binders-in-silico---possu-huang-stanford" id="markdown-toc-designing-epitope-specific-binders-in-silico---possu-huang-stanford">Designing epitope-specific binders in silico - Possu Huang (Stanford)</a></li>
  <li><a href="#bringing-de-novo-proteins-into-the-clinic---javier-castellanos-neoleukin" id="markdown-toc-bringing-de-novo-proteins-into-the-clinic---javier-castellanos-neoleukin">Bringing de novo proteins into the clinic - Javier Castellanos (Neoleukin)</a></li>
  <li><a href="#closing-thoughts" id="markdown-toc-closing-thoughts">Closing thoughts</a></li>
  <li><a href="#credits" id="markdown-toc-credits">Credits</a></li>
</ul>

<h2 id="scaffolding-protein-motifs-using-deep-learning---baker-lab-ipd-uw">Scaffolding protein motifs using Deep Learning - (Baker Lab, IPD, UW)</h2>

<p>While complete de novo design of proteins is still the holy grail of the field, in practice you often want to incorporate a motif known from nature into a new custom-made scaffold. This motif-scaffolding problem can be quite tricky, since the desired motif is often energetically unfavourable, which means that the scaffold is required to balance this out in order to form a stable folded protein. In <a href="https://www.science.org/doi/10.1126/science.abn2100">this Science publication</a>, a team from the IPD in Seattle approached this problem via two different deep learning methods: hallucination and inpainting.</p>

<p>Both rely on the impressive advances in protein structure prediction in recent years. More explicitly, hallucination describes an iterative procedure in which a protein structure prediction network repeatedly predicts the structure for a given input sequence. After each prediction, a loss is calculated based on both the general quality of the structure and how well the desired motif is recapitulated. This loss is then used to update the input sequence (the network itself stays fixed), and the process is repeated until a desired performance threshold is reached.</p>

<p>On the other hand, inpainting treats the scaffolding problem as an information recovery task. Here, part of the sequence input is masked and the model is asked to predict these missing residues, generating novel stable proteins.</p>

<p>This paper is not the only one to tackle this problem. There have been efforts to use <a href="https://arxiv.org/abs/2206.04119">diffusion models for the scaffolding problem</a> as well as <a href="https://openreview.net/forum?id=ZTsoE8G3GG">graph neural networks</a>. All these approaches sound promising; time will tell which of these will be applicable to which design goal.</p>
<h2 id="manifold-sampling-for-function-guided-antibody-design---vladimir-gligorijevic-prescient-design">Manifold Sampling for function-guided antibody design - Vladimir Gligorijevic (Prescient Design)</h2>

<p>Another exciting talk by Vladimir Gligorijevic showcased some of the work that has been going on at Prescient Design and that has been published in this <a href="https://www.biorxiv.org/content/10.1101/2021.12.22.473759v1.full">Manifold Design paper</a>. The general idea of the approach (as you can see from the name) builds upon the <a href="https://www.lcayton.com/resexam.pdf">manifold hypothesis</a>, i.e. your high-dimensional data is normally not widely spread out in these dimensions, but is often restricted to a lower-dimensional manifold.</p>

<p>Since functional proteins only occupy a small fraction of overall sequence space, thinking about protein design in terms of manifold sampling sounds like a reasonable idea and is what drove the recent advances in protein language models, which basically learn to generate sequences that lie on this manifold of natural sequences (one of the early examples of explicitly formulating the problem in this way was <a href="https://www.biorxiv.org/content/10.1101/2022.04.10.487811v1.full">this paper</a> by Hie et al. from Stanford).</p>

<p>But this team tackled the problem via a different approach: they built a Denoising Auto-Encoder (DAE) that takes as input a protein sequence, perturbs it and then generates a new protein sequence from there. The cool thing about this unsupervised approach is a separate supervised function classifier that predicts the function of the newly generated sequence based on Gene Ontology (GO) terms and therefore serves as a guide for the DAE to generate sequences with the desired function.</p>

<p>The team shows some pretty cool applications of their approach, from generating new calcium-binding proteins to an ion transporter with a novel <em>all-alpha</em> fold (no pun intended), designed starting from a protein with an <em>all-beta</em> fold! In a <a href="https://arxiv.org/abs/2205.04259">follow-up paper</a> they describe an approach to only design certain regions of the sequence, enabling workflows similar to the RFDesign pipeline (the hallucination and inpainting work mentioned above).</p>

<p>All in all this seems like a promising way to generate very diverse sequences conditioned on function.</p>
<h2 id="designing-epitope-specific-binders-in-silico---possu-huang-stanford">Designing epitope-specific binders in silico - Possu Huang (Stanford)</h2>

<p>Generating custom antibodies binding to a specific target is already quite a feat, but doing it not only target- but even epitope-specific would be impressive. Nothing less is what Possu Huang presented at RosettaCon. His group has published a lot of work on protein design and specifically backbone generation over the last couple of years, using diverse approaches ranging from <a href="https://openreview.net/forum?id=SJxnVL8YOV">Generative Adversarial Networks (GANs)</a> to <a href="http://www.proteindesign.org/uploads/1/2/1/9/121933886/2020_madani_neurips.pdf">language models</a>.</p>

<p>In 2020 they published a preprint on Ig-VAE, a Variational Autoencoder for generating antibody structures, and published the final version in <a href="https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1010271">PLOS CompBio this June</a>.
This model is inspired by how nature uses a single antibody scaffold and adapts it to the problem at hand. They wanted to do something similar <em>in silico</em>, so they chose to create a model with three important properties:</p>

<ol>
  <li>rotational and translational invariance should be maintained</li>
  <li>the model should be aware of torsion angles since these are very important for protein structure and function</li>
  <li>the output should directly be 3D structures and not an intermediate output such as distance maps.</li>
</ol>

<p>With this new generative model, they were able to embed structures in a latent space and then sample from this latent space to generate novel structures. The direct output of structures made sure that there are no inconsistencies due to an overdetermined system, and the rotational and translational invariance meant that, via this inductive bias, they were able to circumvent the hassle of data augmentation for different rotational poses and learn a meaningful protein representation directly.</p>

<p>Part of the team behind the paper has continued their work on protein structure generation since then and recently published <a href="https://arxiv.org/abs/2205.15019">a preprint</a> in which they show that diffusion models can be used in a similar way to generate protein structures, showing again that novel advances in ML are often transferable to applications in biology. It will be exciting to see what the next breakthrough at this interface will be!</p>

<h2 id="bringing-de-novo-proteins-into-the-clinic---javier-castellanos-neoleukin">Bringing de novo proteins into the clinic - Javier Castellanos (Neoleukin)</h2>

<p><a href="https://www.neoleukin.com/">Neoleukin</a> is one of several protein design companies originating in the IPD in Seattle. Their particular focus is therapeutic design with applications in e.g. <a href="https://www.sciencedirect.com/science/article/pii/S1367593120300181?via%3Dihub">cancer immunotherapy</a>.</p>

<p>Their lead candidate NL-201 is based on <a href="https://eorder.sheridan.com/3_0/app/orders/8675/article.php">this publication in Nature</a>, in which they showed that this de novo protein is an effective IL-2 and IL-15 agonist. This means it can activate cells that express the receptors for these signalling molecules, for example NK or T cells, which are important in immunotherapy against cancer. Other IL-2-based therapeutic approaches have been challenging in the past since they bind strongly to CD25, a subunit of another subtype of the IL-2 receptor that is expressed in off-target cell types and greatly enhances both the affinity and the activity of IL-2. Therefore, these off-target cells experience an even higher activation than the intended target cells, leading to a plethora of potential side effects. Via computational protein design, they were able to create a protein that selectively binds to the intended subpopulation of IL-2 receptors, showing effective responses in immunotherapy studies.</p>

<p>They now use several of the openly available open-source tools to optimize their designed sequences, which shows that protein design with machine learning is not a far-fetched dream, but actually already here!</p>

<h2 id="closing-thoughts">Closing thoughts</h2>

<p>The talks mentioned above show the breadth and depth of work going on in the protein design field, from combining established methods with new methods to the development of new algorithmic ideas. It was especially fascinating to see that de novo proteins are moving into the clinic and to the patient now, since advancing disease therapy is the goal of many protein design projects.</p>

<p>Finally, the collaborative and friendly atmosphere at the conference was exceptional; everybody took time for answering questions, helping others out and explaining new advances in order to move the whole field forward. For me personally, attending RosettaCon remotely was a great experience, and I hope to repeat it in person soon!</p>

<h2 id="credits">Credits</h2>

<p>Thanks a lot to the organisers of the RosettaCon conference, both for making the conference a great experience and for allowing me to post this summary on their website and use their logo for the post on my website.</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="proteins" /><summary type="html"><![CDATA[About the breathtaking pace of innovation in the space and the amazing community that drives it]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/rosetta_logo.png" /><media:content medium="image" url="/assets/img/blog/rosetta_logo.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">sed, awk &amp;amp; co - master the shell</title><link href="/blog/programming/2022-11-08-shell-commands/" rel="alternate" type="text/html" title="sed, awk &amp;amp; co - master the shell" /><published>2022-11-08T00:00:00+00:00</published><updated>2022-11-08T21:29:21+00:00</updated><id>/blog/programming/shell-commands</id><content type="html" xml:base="/blog/programming/2022-11-08-shell-commands/"><![CDATA[<p>I recently saw two great videos regarding the command line tools <a href="https://www.youtube.com/watch?v=EACe7aiGczw">sed</a> and <a href="https://www.youtube.com/watch?v=9YOZmI-zWok">awk</a> and thought it might be a good idea to put the commands and varieties explained in these videos here in order to have a quick reference for myself and others in case one struggles again to find the right pattern or syntax for using one of these tools. If you do not know <code class="language-plaintext highlighter-rouge">awk</code> and <code class="language-plaintext highlighter-rouge">sed</code> yet, I highly recommend watching these videos and getting familiar with them; using them for text manipulation and quick processing is often way quicker than writing a Python or R script for this kind of job. 
For those who know the two tools already, I hope that this provides a good reference for their usage!</p>

<ul id="markdown-toc">
  <li><a href="#sed---your-search-and-replace-function" id="markdown-toc-sed---your-search-and-replace-function">sed - your search and replace function</a></li>
  <li><a href="#awk---the-allrounder" id="markdown-toc-awk---the-allrounder">awk - the allrounder</a></li>
  <li><a href="#getting-help---mantldr" id="markdown-toc-getting-help---mantldr">getting help - man/tldr</a></li>
  <li><a href="#closing-thoughts" id="markdown-toc-closing-thoughts">Closing thoughts</a></li>
</ul>

<h2 id="sed---your-search-and-replace-function">sed - your search and replace function</h2>

<p>sed stands for <em>stream editor</em> and you can imagine it as your automated search and replace function: with it you can look for patterns and replace them with other patterns. In this part of the post we will use the following text file as an example:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># file: balance.txt</span>
- 25,13 EUR Mon Supermarket <span class="nt">-------</span>

+ 13,40 EUR Tue Pizza/Drinks -
- 05,00 EUR Tue Bus <span class="nt">--</span>

+ 40,00 EUR Wed Refund <span class="nt">----</span>
</code></pre></div></div>

<p>Examples:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">sed s/,/./ &lt;balance.txt &gt;balance_int.txt</code>: read text from <code class="language-plaintext highlighter-rouge">balance.txt</code>, substitute the first comma on each line with a full stop and write the output to <code class="language-plaintext highlighter-rouge">balance_int.txt</code>.</li>
  <li><code class="language-plaintext highlighter-rouge">sed s/,/./g &lt;balance.txt &gt;balance_int.txt</code>: same as above, but with the global option <code class="language-plaintext highlighter-rouge">/g</code> sed substitutes every comma with a full stop, not just the first one in each line.</li>
  <li><code class="language-plaintext highlighter-rouge">echo "15,3" | sed s/,/./</code>: sed also accepts piped input from other commands. Important: sed searches for strings, not for words!</li>
  <li><code class="language-plaintext highlighter-rouge">sed -i s/,/./g balance.txt</code>: the <code class="language-plaintext highlighter-rouge">-i</code> flag makes sed edit the file in place, i.e. read from and write to the same file; the input/output redirections are not needed in this case.</li>
  <li><code class="language-plaintext highlighter-rouge">sed '/+/s/,/./g' balance.txt</code>: look for lines in balance.txt that contain a + and substitute , with . in these lines.</li>
  <li><code class="language-plaintext highlighter-rouge">sed '/-/d' balance.txt</code>: look for lines in balance.txt that contain a - and delete these lines.</li>
  <li><code class="language-plaintext highlighter-rouge">sed -e 's/Mon/Monday/g' -e 's/Tue/Tuesday/g' balance.txt</code>: normally, sed takes the first argument as the expression and the remaining arguments as input files. If we want to use multiple expressions, we can make this explicit by giving each one its own <code class="language-plaintext highlighter-rouge">-e</code> flag; the <code class="language-plaintext highlighter-rouge">-f</code> flag instead reads the expressions from a script file.</li>
  <li><code class="language-plaintext highlighter-rouge">sed 's/Pizza\/Drinks/Party/g'</code>: if the search pattern itself contains a /, we can escape it with a backslash (quote the expression so that the shell does not strip the backslash).</li>
  <li><code class="language-plaintext highlighter-rouge">sed 's#Pizza/Drinks#Party#g'</code>: another way around this problem is to just use a different separator! sed is not very picky about which separator you use for the substitute command.</li>
  <li><code class="language-plaintext highlighter-rouge">sed -n /-/p &lt;balance.txt</code>: print lines from balance.txt that have a - in them. By default <code class="language-plaintext highlighter-rouge">sed</code> prints all the input it processed except for deletions. <code class="language-plaintext highlighter-rouge">-n</code> suppresses this automatic output, and the print command <code class="language-plaintext highlighter-rouge">p</code> prints the lines that match our pattern.</li>
  <li><code class="language-plaintext highlighter-rouge">sed -i 's/-*$//' balance.txt</code>: find a regex pattern in each line (here dashes (<code class="language-plaintext highlighter-rouge">-</code>), an arbitrary number of them (<code class="language-plaintext highlighter-rouge">*</code>), at the end of the line (<code class="language-plaintext highlighter-rouge">$</code>)) and substitute it with nothing (<code class="language-plaintext highlighter-rouge">//</code>).</li>
  <li><code class="language-plaintext highlighter-rouge">sed '/^$/d'</code>: find every empty line (nothing in between start (<code class="language-plaintext highlighter-rouge">^</code>) and end (<code class="language-plaintext highlighter-rouge">$</code>) of line) and delete it.</li>
  <li><code class="language-plaintext highlighter-rouge">sed 's/[A-Z]/\L&amp;/g'</code>: find every uppercase letter and make it lowercase. To do it the other way around, replace <code class="language-plaintext highlighter-rouge">[A-Z]</code> with <code class="language-plaintext highlighter-rouge">[a-z]</code> and <code class="language-plaintext highlighter-rouge">\L</code> with <code class="language-plaintext highlighter-rouge">\U</code> (both are GNU sed extensions).</li>
  <li><code class="language-plaintext highlighter-rouge">sed 10q balance.txt</code>: use it as a replacement for the <code class="language-plaintext highlighter-rouge">head</code> command. Without any flags, <code class="language-plaintext highlighter-rouge">head balance.txt</code> gives you the first ten lines of a file.</li>
</ul>
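
<p>To see a few of these commands in action without touching any file, here is a small sketch that keeps three lines of the example data in a shell variable (the variable names are made up for this illustration). Note that the second command is a slight variation of the one in the list: it also removes the space in front of the trailing dashes.</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Recreate a few lines of balance.txt in a variable so nothing on disk is touched.
balance='- 25,13 EUR Mon Supermarket -------
+ 13,40 EUR Tue Pizza/Drinks -
- 05,00 EUR Tue Bus --'

# 1) Only on lines containing "+": replace commas with full stops and print those lines.
income=$(printf '%s\n' "$balance" | sed -n '/+/s/,/./gp')

# 2) Strip the trailing dashes (and the spaces in front of them) from every line.
cleaned=$(printf '%s\n' "$balance" | sed 's/ *-*$//')

# 3) Use "#" as separator so the "/" in the pattern needs no escaping.
renamed=$(printf '%s\n' "$balance" | sed -n 's#Pizza/Drinks#Party#p')

echo "$income"
echo "$cleaned"
echo "$renamed"
</code></pre></div></div>

<p>Piping through <code class="language-plaintext highlighter-rouge">printf</code> like this is a handy way to try out an expression before running it with <code class="language-plaintext highlighter-rouge">-i</code> on the real file.</p>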

<p>It is important to use single quotes for the sed pattern instead of double quotes. If you use single quotes, sed gets exactly the pattern that you write. But when you use double quotes, the shell first interprets the string itself (expanding variables, command substitutions and some backslash sequences) before passing it on to sed,
which can be problematic in case of special symbols and variable/command names. This can also be beneficial, but only if you know what you are doing; otherwise, stick to single quotes
(see <a href="https://askubuntu.com/questions/1146789/single-quote-and-double-quotes-in-sed#:~:text=%40DummyHead%20Here's%20another%20approach%3A%20if,eventually%20passed%20on%20to%20sed.">this thread</a> for a more detailed discussion).</p>
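
<p>The difference is easiest to see with a shell variable. The following is a minimal sketch; the variable <code class="language-plaintext highlighter-rouge">day</code> is hypothetical and only there to show what the shell does before sed ever sees the expression.</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical variable, only here to show what the shell does with $... inside quotes.
day="Tue"

# Double quotes: the shell expands $day first, so sed receives s/Tue/Tuesday/.
expanded=$(echo "Tue Bus" | sed "s/$day/Tuesday/")

# Single quotes: sed receives the text $day unchanged, which never matches this input.
literal=$(echo "Tue Bus" | sed 's/$day/Tuesday/')

echo "double quotes: $expanded"
echo "single quotes: $literal"
</code></pre></div></div>

<p>So double quotes are the way to go when you deliberately want to build a pattern from a variable, and single quotes when the pattern should reach sed untouched.</p>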

<h2 id="awk---the-allrounder">awk - the allrounder</h2>

<p><code class="language-plaintext highlighter-rouge">awk</code> is another very powerful command line tool. Most people use it for text manipulation (similar to <code class="language-plaintext highlighter-rouge">sed</code>), but being a full scripting language, it can do a whole bunch more! Fun fact: it got its name from its three creators, who wrote the tool at AT&amp;T Bell Labs in 1977: Alfred Aho, Peter Weinberger, and Brian Kernighan. It is especially useful if your text has some structure in it (like a tsv/csv file, for example). The list below collects some examples of what you can do with it.</p>
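
<p>To make the notion of fields concrete before going through individual commands, here is a small sketch on data in the style of <code class="language-plaintext highlighter-rouge">balance.txt</code>. It assumes that the decimal commas have already been replaced by full stops so that awk can treat the amounts as numbers; the variable names are made up for this illustration.</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A few lines in the style of balance.txt, with the decimal comma already replaced.
balance='- 25.13 EUR Mon Supermarket
+ 13.40 EUR Tue Pizza/Drinks
- 05.00 EUR Tue Bus'

# $2 is the second whitespace-separated field, here the amount.
amounts=$(printf '%s\n' "$balance" | awk '{print $2}')

# Pattern plus action: add up the amounts of all lines whose first field is "-".
spent=$(printf '%s\n' "$balance" | awk '$1 == "-" {sum += $2} END {printf "%.2f", sum}')

# NR is the current line number, $NF the last field on that line.
last_fields=$(printf '%s\n' "$balance" | awk '{print NR ": " $NF}')

echo "$amounts"
echo "spent: $spent EUR"
echo "$last_fields"
</code></pre></div></div>

<p>Every awk program follows this shape: an optional pattern that selects lines, followed by an action in curly braces that is run for each selected line.</p>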

<ul>
  <li><code class="language-plaintext highlighter-rouge">awk '{print $2}' balance.txt</code>: print the second field/column of each line. By default, whitespace separates columns in <code class="language-plaintext highlighter-rouge">awk</code> (can be customized).</li>
  <li><code class="language-plaintext highlighter-rouge">awk '{print $0}' balance.txt</code>: print whole lines (equivalent to <code class="language-plaintext highlighter-rouge">cat</code>); same output if you just use <code class="language-plaintext highlighter-rouge">'{print}'</code> as command for <code class="language-plaintext highlighter-rouge">awk</code>.</li>
  <li><code class="language-plaintext highlighter-rouge">awk -F ":" '{print $1}' /etc/passwd</code>: use colons instead of spaces as field separator to get all users on a Linux system.</li>
  <li><code class="language-plaintext highlighter-rouge">awk -F ":" '{print $1"\t"$6","$7}' /etc/passwd</code>: print several columns with a tab between the first and second column and a comma between the second and third.</li>
  <li><code class="language-plaintext highlighter-rouge">awk 'BEGIN{FS=":"; OFS="-"} {print $1,$6,$7}' /etc/passwd</code>: set the input field separator (<code class="language-plaintext highlighter-rouge">FS</code>) and the output field separator (<code class="language-plaintext highlighter-rouge">OFS</code>) inside the program instead of via a flag.</li>
  <li><code class="language-plaintext highlighter-rouge">awk -F "/" '/^\// {print $NF}' /etc/shells</code>: set <code class="language-plaintext highlighter-rouge">/</code> as the field separator for the contents of <code class="language-plaintext highlighter-rouge">/etc/shells</code>. Then, search for the regex pattern between slashes (<code class="language-plaintext highlighter-rouge">^\/</code>), which looks for lines that start with a slash (<code class="language-plaintext highlighter-rouge">\</code> is needed to escape <code class="language-plaintext highlighter-rouge">/</code> since it is normally recognised as a special character). Then, print the last field of each line (i.e. the name of the corresponding shell).</li>
  <li><code class="language-plaintext highlighter-rouge">awk -F "/" '/^\// {print $NF}' /etc/shells | sort | uniq</code>: output from above, just alphabetically sorted and with the duplicates removed (<code class="language-plaintext highlighter-rouge">uniq</code> only removes adjacent duplicates, so the sorting has to come first).</li>
  <li><code class="language-plaintext highlighter-rouge">df | awk '/\/dev\/loop/ {print $1"\t"$2+$3}'</code>: for every line of the <code class="language-plaintext highlighter-rouge">df</code> output that contains <code class="language-plaintext highlighter-rouge">/dev/loop</code>, print the first column followed by a tab and the sum of the second and third columns.</li>
  <li><code class="language-plaintext highlighter-rouge">awk 'length($0) &gt; 10' /etc/shells</code>: only print lines that are longer than 10 characters.</li>
  <li><code class="language-plaintext highlighter-rouge">ps -ef | awk '{ if($NF == "/bin/zsh") print $0}'</code>: print all processes that are currently running and have <code class="language-plaintext highlighter-rouge">/bin/zsh</code> as the last field of the line.</li>
  <li><code class="language-plaintext highlighter-rouge">ps -ef | awk 'BEGIN { for(i=1; i&lt;=10; i++) print "Process ", i, ": ", $0}'</code>: a <code class="language-plaintext highlighter-rouge">for</code> loop inside a <code class="language-plaintext highlighter-rouge">BEGIN</code> block. Note that <code class="language-plaintext highlighter-rouge">BEGIN</code> runs before any input is read, so <code class="language-plaintext highlighter-rouge">$0</code> is still empty here and the <code class="language-plaintext highlighter-rouge">ps</code> output itself is never printed.</li>
  <li><code class="language-plaintext highlighter-rouge">awk '$1 ~ /^[b,c]/ {print $0}' .bashrc</code>: look at the content of <code class="language-plaintext highlighter-rouge">.bashrc</code> and check if the first column matches the regular expression <code class="language-plaintext highlighter-rouge">^[b,c]</code> (i.e. does the first column start with b or c; the comma inside the brackets is matched literally as well, so <code class="language-plaintext highlighter-rouge">^[bc]</code> is the stricter version). If yes, print the line.</li>
  <li><code class="language-plaintext highlighter-rouge">awk '{print substr($0, 4)}' /etc/passwd</code>: look at the content of <code class="language-plaintext highlighter-rouge">passwd</code> and print every line from the fourth character on.</li>
  <li><code class="language-plaintext highlighter-rouge">awk 'match($0, /,/) {print $1 " has \",\" character at " RSTART}' file.txt</code>: look at the content of file.txt and look for all lines that match the pattern <code class="language-plaintext highlighter-rouge">,</code>. Then, print the first field of that line, followed by a string that contains the position at which <code class="language-plaintext highlighter-rouge">,</code> appeared in the line (<code class="language-plaintext highlighter-rouge">RSTART</code>).</li>
  <li><code class="language-plaintext highlighter-rouge">df | awk 'NR%2 == 0 {print "Even"}; NR%2 !=0 {print "Odd"}'</code>: NR gives you the line number. Here, take the output of df and print “Even” if the line number is even and “Odd” if the line number is odd.</li>
  <li><code class="language-plaintext highlighter-rouge">awk 'END {print NR}' /etc/shells /etc/passwd</code>: print the combined line count of the given files.</li>
</ul>

<h2 id="getting-help---mantldr">getting help - man/tldr</h2>

<p>It is often easy to get lost with all the varieties of tools out there, so here are some pointers to resources to look for help:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">man sed</code>: gives you the (long) manual page of sed, explaining the different options</li>
  <li><code class="language-plaintext highlighter-rouge">tldr sed</code>: gives a more concise summary of the sed command, similar to a cheat sheet</li>
  <li><a href="https://www.gnu.org/software/sed/manual/sed.html">online man page</a>: often a bit easier to read than the terminal version</li>
  <li>great YouTube channels such as <a href="https://www.youtube.com/c/DistroTube">DistroTube</a> explaining many of the tricks for shell commands; many of the example commands from this article are inspired by his videos!</li>
  <li>as always, <a href="https://stackoverflow.com/">StackOverflow</a> is often the best place to visit if you try to solve a specific problem and need inspiration for how to tackle it.</li>
</ul>

<h2 id="closing-thoughts">Closing thoughts</h2>

<p>As with many things, shell scripting feels very cumbersome and inefficient at the start. But once you pass this initial struggle, you will see how convenient these tools really are (especially since they are present on virtually any Linux machine) and how quickly you can get stuff done with them!</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="programming" /><summary type="html"><![CDATA[How to automate annoying tasks with shell scripts]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/terminal.png" /><media:content medium="image" url="/assets/img/blog/terminal.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Python for Data Science</title><link href="/blog/programming/2022-11-03-python-dsintro/" rel="alternate" type="text/html" title="Python for Data Science" /><published>2022-11-03T00:00:00+00:00</published><updated>2022-11-06T13:38:27+00:00</updated><id>/blog/programming/python-dsintro</id><content type="html" xml:base="/blog/programming/2022-11-03-python-dsintro/"><![CDATA[<p><em>This post is intended as recommended reading for the participants of the first part of the lecture series “Python for Data Science” at Heidelberg University which was conceptualised and organised by Lukas Jarosch and me, but should be interesting to anyone who wants to start working with Python.</em></p>

<p>Computers seem to be everywhere today: in our offices, our kitchens and increasingly in our labs as well. As a scientist in the natural sciences, you have better and better tools at your disposal that generate more and more data. And while back in the day an Excel table or even a lab notebook would have been sufficient, nowadays you often need software to process your data. While there is a growing amount of no-code software available that you can use without programming yourself, programming will probably form a growing part of your day-to-day job. This is why we are holding <a href="https://github.com/kierandidi/python_for_scientists">this course</a> to get you started, with this post as your initial overview of what we are going to cover!</p>

<ul id="markdown-toc">
  <li><a href="#python-the-swiss-army-knife" id="markdown-toc-python-the-swiss-army-knife">Python: the swiss army knife</a></li>
  <li><a href="#python-libraries-your-specialist-tools" id="markdown-toc-python-libraries-your-specialist-tools">Python libraries: your specialist tools</a>    <ul>
      <li><a href="#pandas-the-scissor-to-change-data-the-way-you-like" id="markdown-toc-pandas-the-scissor-to-change-data-the-way-you-like">Pandas: The scissor to change data the way you like</a></li>
      <li><a href="#seaborn-your-magnifying-glass" id="markdown-toc-seaborn-your-magnifying-glass">Seaborn: your magnifying glass</a></li>
    </ul>
  </li>
  <li><a href="#visual-studio-code-an-editor-you-will-learn-to-love" id="markdown-toc-visual-studio-code-an-editor-you-will-learn-to-love">Visual Studio Code: an editor you will learn to love</a></li>
  <li><a href="#notebooks-a-quick-way-to-get-started" id="markdown-toc-notebooks-a-quick-way-to-get-started">Notebooks: a quick way to get started</a></li>
  <li><a href="#github-collaboration-is-key" id="markdown-toc-github-collaboration-is-key">GitHub: collaboration is key</a></li>
  <li><a href="#stackoverflow-where-you-will-spend-most-of-your-days" id="markdown-toc-stackoverflow-where-you-will-spend-most-of-your-days">StackOverflow: where you will spend most of your days</a></li>
  <li><a href="#closing-thoughts" id="markdown-toc-closing-thoughts">Closing thoughts</a></li>
</ul>

<h2 id="python-the-swiss-army-knife">Python: the swiss army knife</h2>

<p>Many beginners who want to learn programming are confused about which language to start with: should I learn Python? R? One of the newest and coolest languages like Julia or Rust? Or a classic like Java or C++?</p>

<p>First of all, it is important to note that the choice of your first programming language is actually not hugely important. If you continue with coding you will learn more than one language anyway, and once you have mastered one language the ideas and concepts can often be easily transferred to another.</p>

<p>Nevertheless, I want to give you an intuition of what programming languages are out there and why we choose to teach you Python.</p>

<p>At the end of the day, programming languages are just another tool to help you get your work done, similar to programs like Excel or Word, or even physical tools like a hammer or a power drill. So, as in real life, it does not make sense to do everything with a hammer; otherwise, every problem will look like a nail to you. So you should choose a tool or a set of tools that can do multiple things for you.</p>

<p>In addition, you do not want to become a craftsperson and work with complicated and specialised equipment only to fix your new picture on the wall. So, the tools you choose should not only be flexible, but also easy to handle.</p>

<p>With this metaphor in mind, let’s transfer these insights to coding:</p>

<p>First, let’s talk about versatility. Although there are many ways in which you can classify programming languages, for the purpose of this post we will keep it simple: in general there are <em>general-purpose programming languages (GPL)</em> and <em>domain-specific programming languages (DSL)</em>. As the names suggest, the former are languages that are used for all kinds of applications, whereas the latter were designed with a specific application in mind. 
This does not mean that DSLs cannot perform calculations that GPLs can (most programming languages are <a href="">Turing-complete</a> anyway), but their syntax and structure are optimised for a specific purpose, which may make it harder to adapt them for others. 
The division is quite blurry in real life, but I like to keep it in the back of my head to keep my thoughts organised. <a href="">FORTRAN</a> for example was originally designed for numeric computation, and although some people used it for other purposes it stayed mostly in that area. <a href="">SQL</a> is another example of a language that was designed for querying databases and is nearly exclusively used for that purpose.</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/powerdrill.png" width="50%" height="50%" />
</p>

<p><em>A power drill is useful for drilling holes, but not very useful for anything else.</em></p>

<p>While these languages might be great for their respective domain, they are not suited as a first programming language since you want to learn a breadth of applications before specialising later in something you want to work on with more focus.</p>

<p>Therefore, we will teach you a GPL. There are many out there, from C++ and Java to Python or Julia. So, which one to choose?</p>

<p>Well, now our second consideration comes into play: ease of use.</p>

<p>People often refer to <em>high-level</em> and <em>low-level languages</em> in this area. What they mean by that is how close the language you write is to the machine code your computer reads in the end and how many of the steps in between are abstracted away by your programming language. <a href="https://en.wikipedia.org/wiki/Assembly_language">Assembler</a> is an example of a language that is nearly at machine level: it gives you a lot of power and insight into the machine, but this closeness to the hardware makes it impractical for day-to-day tasks.</p>

<p>C++ is an example of a fairly low-level language. While you can also work on a higher level with libraries that give you access to object-oriented programming and other abstractions, you can still mess up your programs by playing around with low-level constructs such as <a href="https://hackaday.com/2018/04/04/the-basics-and-pitfalls-of-pointers-in-c/">pointers</a>. In terms of our previous metaphor, you can think of it as a toolbox with an extensive number of complicated tools: sure, now you are not limited to one tool and are flexible, but each individual one of these is still quite difficult to work with.</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/toolbox.jpg" width="50%" height="50%" />
</p>

<p><em>A toolbox offers you a lot of flexibility, but requires quite some expertise to be used correctly.</em></p>

<p>C++ and Java are great languages, don’t get me wrong: while learning them I learnt a lot about programming itself and the different choices you have as a programmer in how to put an abstract project specification into practice. But the course we are teaching is not primarily for programmers; it is for scientists. You do not only want to write programs, but do a lot of other things as well like experimenting in the lab and generating the data that you will analyse via your code in the end. So although I would recommend also learning a lower-level language to anyone with a deeper interest in computer science, in our course we will focus on Python.</p>

<p>Python is what I would call the swiss army knife of programming languages. It is easy to learn, quick to prototype with and versatile in what it can be applied to.</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/swissknife.jpg" width="50%" height="50%" />
</p>

<p><em>A swiss army knife combines the best of both worlds: it is versatile and straightforward to use.</em></p>

<p>Similar to the Swiss army knife, there are situations in which Python is not the most efficient tool. If you want to write a program doing efficient numerical calculations, Python itself will not be your saviour (but maybe one of its libraries, as we will see later). In that case, a lower-level language like C++ might be more suitable. However, for the purposes of a scientist, Python is a great way into coding, both from a didactic and a practical point of view. Plus there is a large community using <a href="https://www.python.org/">Python</a> already out there, so if you get stuck, there is a high probability that someone out there had the problem before you and posted a solution!</p>

<h2 id="python-libraries-your-specialist-tools">Python libraries: your specialist tools</h2>

<p>In this course, we will teach you two Python libraries that will come in very handy when you analyse data: Pandas and Seaborn.</p>

<h3 id="pandas-the-scissor-to-change-data-the-way-you-like">Pandas: The scissor to change data the way you like</h3>

<p>As a scientist, you will often find yourself doing an experiment and generating large amounts of data from it that you want to analyse. Doing it by hand is impossible due to the number of data points, and tools like Excel already start to freeze when opening your data file. So, what do you do?</p>

<p>Enter <a href="https://pandas.pydata.org/">Pandas</a>, a library for data analysis in Python. With it, many of the tasks for which you would need to write your own functions in base Python (opening Excel sheets, calculating summary statistics, filtering/combining data) are just a one-liner. It comes with its own data structure called a <a href="https://realpython.com/pandas-dataframe/">Pandas DataFrame</a>, which you will learn a lot more about during the course. You can imagine it as a table storing your data in a convenient and efficient format.</p>
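
<p>To give a feeling for what such one-liners can look like, here is a minimal sketch. The table and its column names (<code class="language-plaintext highlighter-rouge">sample</code>, <code class="language-plaintext highlighter-rouge">condition</code>, <code class="language-plaintext highlighter-rouge">signal</code>) are made up for illustration; with real data you would typically start from a file, for example via <code class="language-plaintext highlighter-rouge">pd.read_csv</code> or <code class="language-plaintext highlighter-rouge">pd.read_excel</code>.</p>

```python
import pandas as pd

# Hypothetical measurements, invented for this example. In practice you
# would usually load a file instead, e.g. pd.read_csv("measurements.csv").
df = pd.DataFrame(
    {
        "sample": ["a", "a", "b", "b", "c", "c"],
        "condition": ["control", "treated"] * 3,
        "signal": [1.0, 2.5, 0.8, 3.1, 1.2, 2.9],
    }
)

# Summary statistics (count, mean, std, ...) for the numeric columns.
summary = df.describe()

# Filtering: keep only the rows belonging to the treated condition.
treated = df[df["condition"] == "treated"]

# Grouping and aggregating: mean signal per condition.
mean_per_condition = df.groupby("condition")["signal"].mean()

print(summary)
print(treated)
print(mean_per_condition)
```

<p>Each of the three operations is a single line, which is roughly the level of convenience you can expect for everyday tasks.</p>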

<p align="center">
  <img src="/assets/img/blog/python_intro/scissors.png" width="50%" height="50%" />
</p>

<p><em>Similar to a pair of scissors, Pandas can slice and dice your data the way you want it, reshape it and transform it so that it fits your needs.</em></p>

<p>The Pandas website has some great <a href="https://pandas.pydata.org/docs/getting_started/intro_tutorials/index.html">tutorials</a> on how to get started with Pandas and a <a href="https://pandas.pydata.org/pandas-docs/stable/user_guide/cookbook.html">cookbook</a> on how to use it for specific cases, so if you find a use case after the course that you did not encounter before or you do not remember how to handle, these are great places to start looking!</p>

<p>Pandas builds on another library called <a href="https://numpy.org/">NumPy</a>. NumPy is a library optimised for efficiency via <a href="https://en.wikipedia.org/wiki/Array_programming">vectorization</a> and is often used in scientific applications. However, often you do not need all the flexibility that NumPy offers you and want a more concise and pre-structured formulation of your code. Here Pandas can shine: it sits in between the simplicity of base Python and the efficiency of NumPy.</p>
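
<p>As a rough illustration of what vectorization means in practice, the following sketch squares a handful of numbers twice: once element by element in base Python and once on the whole NumPy array at once. The values are arbitrary and only serve as an example.</p>

```python
import numpy as np

# Five arbitrary example values: 1.0, 2.0, 3.0, 4.0, 5.0
values = np.arange(1, 6, dtype=float)

# Base Python style: visit the elements one by one.
squared_loop = [v ** 2 for v in values]

# NumPy style: one expression that operates on the whole array.
squared_vectorised = values ** 2

print(squared_loop)
print(squared_vectorised)
print(values.mean())
```

<p>Both versions compute the same numbers; the vectorised form is shorter and, for large arrays, usually much faster because the loop runs inside NumPy instead of in Python.</p>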

<h3 id="seaborn-your-magnifying-glass">Seaborn: your magnifying glass</h3>

<p>Often, transforming your data via Pandas is only the first step of your data analysis. Numbers can only show so much, and it is often more efficient to visualise your data via graphics, both for your own understanding of your data as well as for communicating your insights to others. Here <a href="https://seaborn.pydata.org/">Seaborn</a> comes to your rescue: it is a data visualisation library that allows you to directly take your data from a DataFrame and plot it easily via a myriad of formats, with all kinds of styles and customisations available. It integrates very well with Pandas and makes creating graphs dead-easy. In case you are looking for inspiration, there is an <a href="https://seaborn.pydata.org/examples/index.html">example gallery</a> showing plots with the corresponding code so that you can easily adapt them for your own purposes.</p>

<p align="center">
<img src="/assets/img/blog/python_intro/lupe.png" width="50%" height="50%" />
</p>

<p><em>Like a magnifying glass, Seaborn allows you to see things in your data that you cannot see by just staring at it, and it allows you to show these insights to others.</em></p>

<p>Similar to Pandas, Seaborn did not come from nowhere: it is based on the library <a href="https://matplotlib.org/">matplotlib</a>, which is the go-to library for visualisation in Python. Again, similar to NumPy it offers you a lot of flexibility, but often you will prefer readable over extremely flexible code when analysing your data. Especially the strong integration with Pandas gives you a good reason to use Seaborn. That being said, in many circumstances I find myself switching back and forth between Pandas/Seaborn and NumPy/matplotlib; since the former two are based on the latter two, using them together often works quite well!</p>
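
<p>A small sketch of this interplay, under the assumption of a made-up DataFrame with the columns <code class="language-plaintext highlighter-rouge">condition</code> and <code class="language-plaintext highlighter-rouge">signal</code>: Seaborn draws the plot directly from the DataFrame and hands back a regular matplotlib object that you can keep customising. The output file name is just a placeholder.</p>

```python
import matplotlib

matplotlib.use("Agg")  # draw off-screen so the sketch also runs without a display

import pandas as pd
import seaborn as sns

# Hypothetical data, invented for this example.
df = pd.DataFrame(
    {
        "condition": ["control", "treated"] * 3,
        "signal": [1.0, 2.5, 0.8, 3.1, 1.2, 2.9],
    }
)

# Seaborn takes the DataFrame as-is; columns are referred to by name.
ax = sns.barplot(data=df, x="condition", y="signal")

# The return value is a matplotlib Axes, so matplotlib methods work on it.
ax.set_title("Mean signal per condition")

# Placeholder file name for the saved figure.
ax.figure.savefig("signal_per_condition.png")
```

<p>In a notebook you would normally skip the backend line and the <code class="language-plaintext highlighter-rouge">savefig</code> call, since the plot appears directly below the cell.</p>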

<h2 id="visual-studio-code-an-editor-you-will-learn-to-love">Visual Studio Code: an editor you will learn to love</h2>

<p>Pandas and Seaborn and all the rest of it are great, but where do you actually write the code containing all these libraries? While you could do that just via the command line, there are way better tools available nowadays that make your life a lot easier. These are often called <a href="https://en.wikipedia.org/wiki/Integrated_development_environment">integrated development environments</a> (IDEs) and if you participated in the course <a href="https://jmbuhr.de/dataintro/">Data Analysis with R</a> by Jannik Buhr, you already met one of these: RStudio.</p>

<p>While RStudio is a nice IDE, it is very R-centric in many of its design decisions. For our purposes, we want a general-purpose IDE that can act as our workbench: we bring whatever tools we want to work with (programming language, packages, data etc.) and our IDE should support our work with these. That is why we decided to teach this course using <a href="https://code.visualstudio.com/">VS Code</a>. It is versatile, open-source and has a massive library of extensions that make your life as a programmer easier. To get started, see <a href="https://realpython.com/python-development-visual-studio-code/">this amazing guide</a> which will take you through the different steps of installing and setting up VS Code (there is also a short walkthrough-guide available on the <a href="https://code.visualstudio.com/docs/python/python-tutorial">VS Code website</a>).</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/workbench.png" width="50%" height="50%" />
</p>

<p><em>All your tools at the right place: VS Code is your workbench, making it easy to access everything that you need and navigate between different tasks.</em></p>

<p>To learn more about the cool things you can do with VS Code (support for R, connecting to remote machines, keyboard shortcuts etc.) you can have a look at <a href="https://kdidi.netlify.app/blog/tools/2022-09-17-mac-setup/">this post</a> which I published some time ago.</p>

<h2 id="notebooks-a-quick-way-to-get-started">Notebooks: a quick way to get started</h2>

<p>Using the tools we mentioned so far you can create great programs for analysing your data. But how to get started? And how to document what you have done? And how to put your work into a format that is suitable for presenting at e.g. a lab meeting?</p>

<p>This is where notebooks come in (more specifically, <a href="https://jupyter.org/">Jupyter notebooks</a> for Python, with the file ending <code class="language-plaintext highlighter-rouge">.ipynb</code>). You can just power up your notebook and get started coding; it shows you the output directly in the document, no matter if it is a number, a plot or an image.</p>

<p>And once you finish your analysis, you can just add some text cells in which you describe your code more eloquently via <a href="https://www.datacamp.com/tutorial/markdown-in-jupyter-notebook">Markdown syntax</a> than inline-comments in your code could ever do.</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/pencil.png" width="50%" height="50%" />
</p>

<p><em>Notebooks help you to get a quick draft of your program into code, similar to how a pencil lets you quickly draft something on paper which can be refined afterwards.</em></p>

<p>These notebooks can serve both as a quick start to some data exploration (e.g. visualising your data for quality control) and as a documentation of your work you can show to colleagues and collaborators. In fact, the lecture slides we will deliver as part of the Python course are made from notebooks!</p>

<h2 id="github-collaboration-is-key">GitHub: collaboration is key</h2>

<p>Notebooks are a great way to share your results once you are done with analysing your data, but what if you want to share it with your colleagues who want to work on the analysis together with you? The most straightforward way would be to just send them the <code class="language-plaintext highlighter-rouge">.ipynb</code> file which they can then use to work on the data. This however has some caveats: first of all, your colleague needs to have the exact same format of the data as you in order to make the analysis reproducible. Second, if your analysis becomes more complicated, you may want to split your analysis into different Python files and notebooks, making exchanging these more complicated. In addition, you cannot work on the analysis at the same time, since the changes you and your colleague introduce might not be compatible and therefore merging these changes into a consistent program in the end might just be impossible. Even worse, every time you change something in the code you have to send your colleague a new version of your code and vice versa, a huge waste of time and effort. That is why GitHub (and Git) were created.</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/github.png" width="50%" height="50%" />
</p>

<p><em>Coding is teamwork, and GitHub helps you discuss ideas with others and show your work to the world.</em></p>

<p><a href="https://github.com/">GitHub</a> is an online service for software development and <a href="https://en.wikipedia.org/wiki/Version_control">version control</a>. It uses Git, a system for local version control, and distributes it globally so that you can work together with people all over the world. In addition, it provides some nice features that make the software development process more structured and organised: wikis, pull requests, issues, task management, continuous integration, basically the whole software shebang you could wish for.</p>

<p>Git is a great system, though it takes some time to get used to. But I can assure you that this time will be well-spent since it is the de facto standard for publishing and sharing software with the world. In case you like gamified learning, <a href="https://ohmygit.org/">here</a> is the link to a game that makes the dive into Git a bit more fun.</p>
<h2 id="stackoverflow-where-you-will-spend-most-of-your-days">StackOverflow: where you will spend most of your days</h2>

<p>When you think about programming, you might still have the image of a hacker on your mind, sitting with his hoodie in a dark room, relentlessly maltreating his keyboard without rhyme or reason. That is far from the actual reality: you will spend most of your time looking up stuff which you either have no clue about or had a clue about at some point but lost it in the flood of other important information such as Shakira lyrics and the results of last weekend’s football match. Here, <a href="https://stackoverflow.com/tour">StackOverflow</a> is your friend.</p>

<p align="center">
  <img src="/assets/img/blog/python_intro/stackoverflow.png" width="50%" height="50%" />
</p>

<p><em>In case you are stuck on how to use a tool, nice people on StackOverflow can show you how to use it.</em></p>

<p>On StackOverflow, you can ask questions about basically everything related to programming and will often receive high-quality answers (given that you formulated your question well). But in most cases, you won’t even have to ask a question since another person had the same or a very similar problem before you and solved it with help from StackOverflow. Using this tool well is a skill whose importance cannot be overstated, since you will spend a significant amount of time using this and other similar resources. There are <a href="https://www.freecodecamp.org/news/5-steps-to-become-a-better-stack-overflow-user-4ce85711c0f9/">many guides out there</a> on how to use it best; for me, learning by doing has helped the most. And once you are advanced enough, answering questions there can be a great way to give back and stay sharp on your coding skills at the same time!</p>

<h2 id="closing-thoughts">Closing thoughts</h2>

<p>There are many great resources out there for scientists who want to learn coding for their research, such as <a href="https://education.molssi.org/python-scripting-biochemistry/chapters/setup.html">this workshop webpage</a>. We hope that with our lecture series we give you both the skills to apply coding for basic tasks in your research and the enthusiasm to continue learning new things in order to improve even more!</p>]]></content><author><name>Kieran Didi</name><email>kieran.didi@gmail.com</email></author><category term="programming" /><summary type="html"><![CDATA[How to get started with Python and what tools to use]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/assets/img/blog/python_intro/toolbox_front.png" /><media:content medium="image" url="/assets/img/blog/python_intro/toolbox_front.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry></feed>