Great work — really nice and clean implementation!
Just a small note: in model_jit.py (around L348), in_context_start is compared directly with the 0-based block index i. As a result, setting in_context_start = 8 inserts the tokens at the 9th block (index 8), not the 8th block that one would intuitively expect.
Would it be more intuitive to use (i + 1) == self.in_context_start instead?
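To make the off-by-one concrete, here is a minimal standalone sketch. It does not reproduce the actual loop in model_jit.py. The function insertion_block and its parameters num_blocks and one_based are hypothetical, and the function only mimics the two comparisons to report which block (counted from 1) would receive the in-context tokens.

```python
from typing import Optional


def insertion_block(num_blocks: int, in_context_start: int, one_based: bool) -> Optional[int]:
    """Hypothetical helper: return the 1-based position of the block where
    in-context tokens would be inserted, or None if no block matches."""
    for i in range(num_blocks):  # i is the 0-based block index, as in a typical forward loop
        if one_based:
            # Proposed comparison: treat in_context_start as a 1-based block number
            hit = (i + 1) == in_context_start
        else:
            # Current comparison: in_context_start is matched against the 0-based index
            hit = i == in_context_start
        if hit:
            return i + 1  # report the position as "the k-th block"
    return None


# Current behaviour: in_context_start = 8 lands on the 9th block
print(insertion_block(12, 8, one_based=False))
# Proposed behaviour: in_context_start = 8 lands on the 8th block
print(insertion_block(12, 8, one_based=True))
```

One consequence of the (i + 1) convention is that in_context_start = 0 no longer matches any block, so "insert at the very first block" would have to be written as in_context_start = 1.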