From b57f40e607f10d282c0216359bdc0ce1e44219c6 Mon Sep 17 00:00:00 2001 From: "josemorales@meta.com" Date: Tue, 22 Apr 2025 17:05:02 -0700 Subject: [PATCH] Fix AO SAM2 issues (#2109) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/2109 SAM2 issues - Whenever ```clear_old_points``` was enabled, SAM2 would crash. AAS Track mult issues - Enables ```multimask``` flags. Root-caused the issues to a failed assertion in the following lines in ```sam2_base.py::_track_step:L788```: ``` if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) assert mask_inputs is None ``` Whenever ```prev_sam_mask_logits``` is not None, it is assigned to ```mask_inputs```, so the later unconditional ```assert mask_inputs is None``` always fails and the run crashes. Several situations are expected to hit this path, including streamed runs and clearing points. The fix moves ```assert mask_inputs is None``` into an ```else``` branch so that it is only checked when ```prev_sam_mask_logits``` is None. Test Plan: aistudio test local aas_track_mult ``` Retrieving package values for `fbcode//ai_demos/server_model_zoo/models/aas_track_mult`: buck2 audit package-values --reuse-current-config fbcode//ai_demos/server_model_zoo/models/aas_track_mult Buck command to find test owners: buck2 uquery --reuse-current-config owner(/data/sandcastle/boxes/fbsource/fbcode/ai_demos/server_model_zoo/models/aas_track_mult/test_aas_track_mult_model.py) -a labels Buck command to invoke a test: buck2 test --reuse-current-config --write-build-id /tmp/.tmpS35tJk --client-metadata language=python --client-metadata id=testify.codelens --client-metadata session_id=d0229502-10cc-45e7-a6f6-6c5c276c2e17 fbcode//ai_demos/server_model_zoo/models/aas_track_mult:tests -- --regex ai_demos/server_model_zoo/models/aas_track_mult:tests \- .*(?:\(.*TestAasTrackMultModel\)$|TestAasTrackMultModel: .*) --run-disabled Buck UI: https://www.internalfb.com/buck2/bf9cbfaa-ae6a-4568-876c-0b128dd474bd Test UI: https://www.internalfb.com/intern/testinfra/testrun/6473924727918606 Network: Up: 0B Down: 0B 
(reSessionID-8b7877b7-4cf8-4850-ac7b-ee84571b005d) Command: test. Time elapsed: 1:07.6s Tests finished: Pass 4. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D73460163 Pulled By: jlbmorales --- torchao/_models/sam2/modeling/sam2_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/torchao/_models/sam2/modeling/sam2_base.py index 4c2a24a0ef..a16fe8dd61 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/torchao/_models/sam2/modeling/sam2_base.py @@ -788,9 +788,10 @@ def _track_step( if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits + else: + assert mask_inputs is None multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - assert mask_inputs is None assert multimask_output if point_inputs is not None: point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs}