175175
176176# Install dependencies if needed
177177try :
178- from transformers import GPT2Config , GPT2LMHeadModel , GPT2Tokenizer
178+ from transformers import GPT2LMHeadModel , GPT2Tokenizer
179179 from transformers .modeling_outputs import CausalLMOutputWithCrossAttentions
180180except ImportError :
181181 subprocess .check_call (
182182 [sys .executable , "-m" , "pip" , "install" , "-q" , "transformers" ]
183183 )
184- from transformers import GPT2Config , GPT2LMHeadModel , GPT2Tokenizer
184+ from transformers import GPT2LMHeadModel , GPT2Tokenizer
185185 from transformers .modeling_outputs import CausalLMOutputWithCrossAttentions
186186
187187try :
@@ -441,6 +441,7 @@ def run_training_ac(
441441 print ("=" * 60 )
442442
443443 # Generate HTML profiles using subprocess
444+ print ("\n Generating baseline profile..." )
444445 result1 = subprocess .run (
445446 [
446447 "mosaic_get_memory_profile" ,
@@ -454,10 +455,14 @@ def run_training_ac(
454455 "--plotter_sampling_rate" ,
455456 "20" ,
456457 ],
458+ capture_output = True ,
459+ text = True ,
457460 )
461+ print (result1 .stdout )
462+ if result1 .stderr :
463+ print (result1 .stderr )
458464
459- print ()
460-
465+ print ("\n Generating activation checkpointing profile..." )
461466 result2 = subprocess .run (
462467 [
463468 "mosaic_get_memory_profile" ,
@@ -471,7 +476,12 @@ def run_training_ac(
471476 "--plotter_sampling_rate" ,
472477 "20" ,
473478 ],
479+ capture_output = True ,
480+ text = True ,
474481 )
482+ print (result2 .stdout )
483+ if result2 .stderr :
484+ print (result2 .stderr )
475485
476486 if result1 .returncode == 0 and result2 .returncode == 0 :
477487 print ("\n Generated profile_baseline.html" )
@@ -481,10 +491,46 @@ def run_training_ac(
481491 print ("\n Note: Mosaic profile generation encountered issues." )
482492 print ("This may happen if running in an environment without full Mosaic support." )
483493
494+ ######################################################################
495+ # Download Generated Files (Google Colab)
496+ # ----------------------------------------
497+ #
498+ # If running in Google Colab, uncomment the following lines to download
499+ # the generated snapshot and profile files:
500+
501+ # from google.colab import files
502+ #
503+ # print("Downloading memory snapshots and profiles...")
504+ # files.download('snapshot_baseline.pickle')
505+ # files.download('snapshot_with_ac.pickle')
506+ # files.download('profile_baseline.html')
507+ # files.download('profile_with_ac.html')
508+
484509######################################################################
485510# Results Interpretation: Activation Checkpointing
486511# -------------------------------------------------
487512#
513+ # The generated HTML profiles visualize memory usage over time, with
514+ # allocations colored by category. Here's what the profiles look like:
515+ #
516+ # .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-without-ac.png
517+ # :alt: GPT-2 memory profile without activation checkpointing
518+ # :align: center
519+ # :width: 600px
520+ #
521+ # **Baseline (without activation checkpointing):** Notice the large
522+ # activation memory (shown in one color) that persists throughout
523+ # the forward pass.
524+ #
525+ # .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-with-ac.png
526+ # :alt: GPT-2 memory profile with activation checkpointing
527+ # :align: center
528+ # :width: 600px
529+ #
530+ # **With activation checkpointing:** Activation memory is significantly
531+ # reduced as intermediate activations are discarded and recomputed
532+ # during the backward pass.
533+ #
488534# What We Observed
489535# ~~~~~~~~~~~~~~~~
490536#
@@ -580,11 +626,17 @@ def run_training_ac(
580626# debugging, but forgot to remove them before training.
581627
582628
583- class GPT2WithDebugOverhead (GPT2LMHeadModel ):
584- """GPT2 with abandoned 'feature analysis' code that bloats peak memory."""
629+ class GPT2WithDebugOverhead (torch .nn .Module ):
630+ """GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.
631+
632+ This wrapper adds extra projection layers that consume memory but serve no
633+ purpose - simulating abandoned debug code that was never cleaned up.
634+ """
585635
586- def __init__ (self , config ):
587- super ().__init__ (config )
636+ def __init__ (self , base_model ):
637+ super ().__init__ ()
638+ self .base_model = base_model
639+ config = base_model .config
588640
589641 # BUG: Large projection layers from an abandoned experiment
590642 self .debug_projections = torch .nn .ModuleList (
@@ -600,7 +652,7 @@ def __init__(self, config):
600652
601653 def forward (self , input_ids = None , labels = None , ** kwargs ):
602654 # Run normal GPT-2 forward with hidden states
603- outputs = super (). forward (
655+ outputs = self . base_model (
604656 input_ids = input_ids ,
605657 labels = labels ,
606658 output_hidden_states = True ,
@@ -680,14 +732,9 @@ def run_training_with_bug(snapshot_path, num_steps=3):
680732 device = torch .device ("cuda" )
681733
682734 print ("Loading buggy model with debug overhead..." )
683- config = GPT2Config .from_pretrained ("gpt2" )
684- model = GPT2WithDebugOverhead (config ).to (device )
685-
686- # Load pretrained weights
687- pretrained = GPT2LMHeadModel .from_pretrained ("gpt2" )
688- model .load_state_dict (pretrained .state_dict (), strict = False )
689- del pretrained
690- torch .cuda .empty_cache ()
735+ # Load pretrained GPT-2 and wrap it with the debug overhead
736+ base_model = GPT2LMHeadModel .from_pretrained ("gpt2" )
737+ model = GPT2WithDebugOverhead (base_model ).to (device )
691738
692739 model .train ()
693740
@@ -745,35 +792,50 @@ def run_training_with_bug(snapshot_path, num_steps=3):
745792 print ("Training with debug projection overhead (BUG)" )
746793 print ("=" * 60 )
747794
748- try :
749- buggy_memory = run_training_with_bug ("snapshot_with_bug.pickle" , num_steps = 3 )
750- except (AttributeError , ValueError ) as e :
751- # Handle transformers version compatibility issues
752- print (f"Note: Skipping buggy model demo due to transformers compatibility: { e } " )
753- buggy_memory = baseline_memory_debug
795+ buggy_memory = run_training_with_bug ("snapshot_with_bug.pickle" , num_steps = 3 )
754796
755797######################################################################
756798# Use Mosaic to Find the Problem
757799# -------------------------------
758800#
759801# Analyze both snapshots to identify the source of extra memory usage.
802+ # We'll run Mosaic's peak memory analysis on each snapshot separately.
803+
804+ ######################################################################
805+ # Analyze the Baseline (Clean) Snapshot
806+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
760807
761808if HAS_CUDA and HAS_MOSAIC_CLI :
762- print ("\n " + " =" * 60 )
809+ print ("=" * 60 )
763810 print ("MOSAIC: Analyzing the Baseline Snapshot" )
764811 print ("=" * 60 )
765812
766- subprocess .run (
813+ result = subprocess .run (
767814 ["mosaic_get_memory_usage_peak" , "--snapshot" , "snapshot_debug_baseline.pickle" ],
815+ capture_output = True ,
816+ text = True ,
768817 )
818+ print (result .stdout )
819+ if result .stderr :
820+ print (result .stderr )
769821
770- print ("\n " + "=" * 60 )
822+ ######################################################################
823+ # Analyze the Buggy Snapshot
824+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
825+
826+ if HAS_CUDA and HAS_MOSAIC_CLI :
827+ print ("=" * 60 )
771828 print ("MOSAIC: Analyzing the Buggy Snapshot" )
772829 print ("=" * 60 )
773830
774- subprocess .run (
831+ result = subprocess .run (
775832 ["mosaic_get_memory_usage_peak" , "--snapshot" , "snapshot_with_bug.pickle" ],
833+ capture_output = True ,
834+ text = True ,
776835 )
836+ print (result .stdout )
837+ if result .stderr :
838+ print (result .stderr )
777839
778840######################################################################
779841# Analyzing The Mosaic Output
@@ -783,8 +845,7 @@ def run_training_with_bug(snapshot_path, num_steps=3):
783845# memory allocation. Let's look at how to find abandoned or unnecessary code
784846# that's bloating the memory.
785847#
786- # 1. Optimizer State Allocations Delta
787- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
848+ # **1. Optimizer State Allocations Delta**
788849#
789850# In the buggy snapshot output, we can see that the first two stack traces
790851# represent the **optimizer state allocations** (like ``zeros_like`` for Adam
@@ -809,11 +870,10 @@ def run_training_with_bug(snapshot_path, num_steps=3):
809870# - 148 calls
810871# - 0.464 GB + 0.464 GB
811872#
812- # ** What this tells us:** The optimizer is tracking more tensors! This is your
873+ # What this tells us: The optimizer is tracking more tensors! This is your
813874# first clue that there are extra parameters or tensors in the computation graph.
814875#
815- # 2. Additional Activation Allocations
816- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
876+ # **2. Additional Activation Allocations**
817877#
818878# The buggy version shows **extra allocations** that don't appear in the
819879# baseline model. Scrolling down the Mosaic output of the buggy model we can
0 commit comments