Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f3d9750

Browse files
authored
Improve Mosaic memory profiling tutorial (#3753)
## Summary

This PR improves the Mosaic memory profiling tutorial with several enhancements:

- **Add GPT-2 memory profile images**: Visual comparison of memory usage with and without activation checkpointing
- **Add Google Colab download instructions**: Code block showing how to download generated snapshot/profile files
- **Fix subprocess output visibility**: Mosaic CLI output is now captured and printed so users can see the analysis results
- **Split Mosaic analysis into separate code blocks**: Better readability for the baseline vs buggy model comparison
- **Refactor GPT2WithDebugOverhead**: Changed from subclassing to wrapper pattern, fixing transformers version compatibility
- **Update section formatting**: Bold headers instead of RST section underlines for cleaner rendering

## Test plan

- [x] Build tutorial locally with `GALLERY_PATTERN="mosaic_memory_profiling_tutorial" make html-noplot`
- [x] Verify new images render correctly
- [ ] Verify subprocess output is visible in rendered tutorial
- [ ] Verify GPT2WithDebugOverhead works with current transformers version
1 parent b99753a commit f3d9750

4 files changed

Lines changed: 94 additions & 32 deletions

File tree

.ci/docker/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ onnxruntime
5050
evaluate
5151
accelerate>=0.20.1
5252
git+https://github.com/facebookresearch/mosaic.git
53+
altair
54+
omegaconf
5355

5456
importlib-metadata==6.8.0
5557

164 KB
Loading
278 KB
Loading

beginner_source/mosaic_memory_profiling_tutorial.py

Lines changed: 92 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,13 @@
175175

176176
# Install dependencies if needed
177177
try:
178-
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
178+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
179179
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
180180
except ImportError:
181181
subprocess.check_call(
182182
[sys.executable, "-m", "pip", "install", "-q", "transformers"]
183183
)
184-
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
184+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
185185
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
186186

187187
try:
@@ -441,6 +441,7 @@ def run_training_ac(
441441
print("=" * 60)
442442

443443
# Generate HTML profiles using subprocess
444+
print("\nGenerating baseline profile...")
444445
result1 = subprocess.run(
445446
[
446447
"mosaic_get_memory_profile",
@@ -454,10 +455,14 @@ def run_training_ac(
454455
"--plotter_sampling_rate",
455456
"20",
456457
],
458+
capture_output=True,
459+
text=True,
457460
)
461+
print(result1.stdout)
462+
if result1.stderr:
463+
print(result1.stderr)
458464

459-
print()
460-
465+
print("\nGenerating activation checkpointing profile...")
461466
result2 = subprocess.run(
462467
[
463468
"mosaic_get_memory_profile",
@@ -471,7 +476,12 @@ def run_training_ac(
471476
"--plotter_sampling_rate",
472477
"20",
473478
],
479+
capture_output=True,
480+
text=True,
474481
)
482+
print(result2.stdout)
483+
if result2.stderr:
484+
print(result2.stderr)
475485

476486
if result1.returncode == 0 and result2.returncode == 0:
477487
print("\nGenerated profile_baseline.html")
@@ -481,10 +491,46 @@ def run_training_ac(
481491
print("\nNote: Mosaic profile generation encountered issues.")
482492
print("This may happen if running in an environment without full Mosaic support.")
483493

494+
######################################################################
495+
# Download Generated Files (Google Colab)
496+
# ----------------------------------------
497+
#
498+
# If running in Google Colab, uncomment the following lines to download
499+
# the generated snapshot and profile files:
500+
501+
# from google.colab import files
502+
#
503+
# print("Downloading memory snapshots and profiles...")
504+
# files.download('snapshot_baseline.pickle')
505+
# files.download('snapshot_with_ac.pickle')
506+
# files.download('profile_baseline.html')
507+
# files.download('profile_with_ac.html')
508+
484509
######################################################################
485510
# Results Interpretation: Activation Checkpointing
486511
# -------------------------------------------------
487512
#
513+
# The generated HTML profiles visualize memory usage over time, with
514+
# allocations colored by category. Here's what the profiles look like:
515+
#
516+
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-without-ac.png
517+
# :alt: GPT-2 memory profile without activation checkpointing
518+
# :align: center
519+
# :width: 600px
520+
#
521+
# **Baseline (without activation checkpointing):** Notice the large
522+
# activation memory (shown in one color) that persists throughout
523+
# the forward pass.
524+
#
525+
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-with-ac.png
526+
# :alt: GPT-2 memory profile with activation checkpointing
527+
# :align: center
528+
# :width: 600px
529+
#
530+
# **With activation checkpointing:** Activation memory is significantly
531+
# reduced as intermediate activations are discarded and recomputed
532+
# during the backward pass.
533+
#
488534
# What We Observed
489535
# ~~~~~~~~~~~~~~~~
490536
#
@@ -580,11 +626,17 @@ def run_training_ac(
580626
# debugging, but forgot to remove them before training.
581627

582628

583-
class GPT2WithDebugOverhead(GPT2LMHeadModel):
584-
"""GPT2 with abandoned 'feature analysis' code that bloats peak memory."""
629+
class GPT2WithDebugOverhead(torch.nn.Module):
630+
"""GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.
631+
632+
This wrapper adds extra projection layers that consume memory but serve no
633+
purpose - simulating abandoned debug code that was never cleaned up.
634+
"""
585635

586-
def __init__(self, config):
587-
super().__init__(config)
636+
def __init__(self, base_model):
637+
super().__init__()
638+
self.base_model = base_model
639+
config = base_model.config
588640

589641
# BUG: Large projection layers from an abandoned experiment
590642
self.debug_projections = torch.nn.ModuleList(
@@ -600,7 +652,7 @@ def __init__(self, config):
600652

601653
def forward(self, input_ids=None, labels=None, **kwargs):
602654
# Run normal GPT-2 forward with hidden states
603-
outputs = super().forward(
655+
outputs = self.base_model(
604656
input_ids=input_ids,
605657
labels=labels,
606658
output_hidden_states=True,
@@ -680,14 +732,9 @@ def run_training_with_bug(snapshot_path, num_steps=3):
680732
device = torch.device("cuda")
681733

682734
print("Loading buggy model with debug overhead...")
683-
config = GPT2Config.from_pretrained("gpt2")
684-
model = GPT2WithDebugOverhead(config).to(device)
685-
686-
# Load pretrained weights
687-
pretrained = GPT2LMHeadModel.from_pretrained("gpt2")
688-
model.load_state_dict(pretrained.state_dict(), strict=False)
689-
del pretrained
690-
torch.cuda.empty_cache()
735+
# Load pretrained GPT-2 and wrap it with the debug overhead
736+
base_model = GPT2LMHeadModel.from_pretrained("gpt2")
737+
model = GPT2WithDebugOverhead(base_model).to(device)
691738

692739
model.train()
693740

@@ -745,35 +792,50 @@ def run_training_with_bug(snapshot_path, num_steps=3):
745792
print("Training with debug projection overhead (BUG)")
746793
print("=" * 60)
747794

748-
try:
749-
buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
750-
except (AttributeError, ValueError) as e:
751-
# Handle transformers version compatibility issues
752-
print(f"Note: Skipping buggy model demo due to transformers compatibility: {e}")
753-
buggy_memory = baseline_memory_debug
795+
buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
754796

755797
######################################################################
756798
# Use Mosaic to Find the Problem
757799
# -------------------------------
758800
#
759801
# Analyze both snapshots to identify the source of extra memory usage.
802+
# We'll run Mosaic's peak memory analysis on each snapshot separately.
803+
804+
######################################################################
805+
# Analyze the Baseline (Clean) Snapshot
806+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
760807

761808
if HAS_CUDA and HAS_MOSAIC_CLI:
762-
print("\n" + "=" * 60)
809+
print("=" * 60)
763810
print("MOSAIC: Analyzing the Baseline Snapshot")
764811
print("=" * 60)
765812

766-
subprocess.run(
813+
result = subprocess.run(
767814
["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_debug_baseline.pickle"],
815+
capture_output=True,
816+
text=True,
768817
)
818+
print(result.stdout)
819+
if result.stderr:
820+
print(result.stderr)
769821

770-
print("\n" + "=" * 60)
822+
######################################################################
823+
# Analyze the Buggy Snapshot
824+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
825+
826+
if HAS_CUDA and HAS_MOSAIC_CLI:
827+
print("=" * 60)
771828
print("MOSAIC: Analyzing the Buggy Snapshot")
772829
print("=" * 60)
773830

774-
subprocess.run(
831+
result = subprocess.run(
775832
["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_with_bug.pickle"],
833+
capture_output=True,
834+
text=True,
776835
)
836+
print(result.stdout)
837+
if result.stderr:
838+
print(result.stderr)
777839

778840
######################################################################
779841
# Analyzing The Mosaic Output
@@ -783,8 +845,7 @@ def run_training_with_bug(snapshot_path, num_steps=3):
783845
# memory allocation. Let's look at how to find abandoned or unnecessary code
784846
# that's bloating the memory.
785847
#
786-
# 1. Optimizer State Allocations Delta
787-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
848+
# **1. Optimizer State Allocations Delta**
788849
#
789850
# In the buggy snapshot output, we can see that the first two stack traces
790851
# represent the **optimizer state allocations** (like ``zeros_like`` for Adam
@@ -809,11 +870,10 @@ def run_training_with_bug(snapshot_path, num_steps=3):
809870
# - 148 calls
810871
# - 0.464 GB + 0.464 GB
811872
#
812-
# **What this tells us:** The optimizer is tracking more tensors! This is your
873+
# What this tells us: The optimizer is tracking more tensors! This is your
813874
# first clue that there are extra parameters or tensors in the computation graph.
814875
#
815-
# 2. Additional Activation Allocations
816-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
876+
# **2. Additional Activation Allocations**
817877
#
818878
# The buggy version shows **extra allocations** that don't appear in the
819879
# baseline model. Scrolling down the Mosaic output of the buggy model we can

0 commit comments

Comments
 (0)