Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,11 +1146,32 @@ def _apply_keypoints_instance_binding(self, targets: frozenset[str]) -> None:
user_fields = list(kp_params.label_fields or [])
internal_fields = [f"_ibl_kp_{f}" for f in user_fields]
self._kp_label_map = dict(zip(internal_fields, user_fields, strict=True))
user_to_internal = {user_name: internal_name for internal_name, user_name in self._kp_label_map.items()}
kp_params.label_mapping = self._remap_label_mapping_fields(kp_params.label_mapping, user_to_internal)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep keypoint label swaps within each instance

When instance_binding is enabled, remapping the public label mapping to _ibl_kp_* makes the existing keypoint swap logic run on the flattened keypoint array. That swap uses global np.where matches over all rows and swaps entire rows, including _kp_instance_id, so inputs where different instances do not each contain both sides of a mapping (for example one instance has only left_eye and another only right_eye) will either swap keypoints between instances during repack or hit a shape mismatch. The mapping needs to be applied per _KP_INSTANCE_ID group, or otherwise avoid global row swaps for bound instances.

Useful? React with 👍 / 👎.

internal_fields.append(_KP_INSTANCE_ID)
kp_params.label_fields = internal_fields
kp_params.remove_invisible = False
kp_params.check_each_transform = False

@staticmethod
def _remap_label_mapping_fields(
label_mapping: dict[str, dict[str, dict[Any, Any]]] | None,
field_map: dict[str, str],
) -> dict[str, dict[str, dict[Any, Any]]]:
"""Return a label_mapping copy with label-field keys renamed through field_map while preserving
transform-specific mappings for each configured transform.

Instance binding temporarily rewrites user keypoint label fields such as "name" to
internal fields such as "_ibl_kp_name". The public label_mapping should still use the
user field names, so Compose translates them when entering and leaving the bound path.
"""
if not label_mapping:
return {}
return {
transform_name: {field_map.get(field_name, field_name): mapping for field_name, mapping in fields.items()}
for transform_name, fields in label_mapping.items()
}

def _set_processors_for_transforms(self, transforms: TransformsSeqType) -> None:
for transform in transforms:
if isinstance(transform, BasicTransform):
Expand Down Expand Up @@ -2078,6 +2099,11 @@ def _clean_params_dict(
if label_fields:
user_fields = [label_map.get(f, f) for f in label_fields if f not in _INSTANCE_ID_FERRY_KEYS]
params_dict = {**params_dict, "label_fields": user_fields}
if params_dict.get("label_mapping"):
params_dict = {
**params_dict,
"label_mapping": self._remap_label_mapping_fields(params_dict["label_mapping"], label_map),
}
return params_dict

def to_dict_private(self) -> dict[str, Any]:
Expand Down Expand Up @@ -2286,7 +2312,7 @@ def _get_init_params(self) -> dict[str, Any]:
remove_invisible=kp.remove_invisible,
angle_in_degrees=kp.angle_in_degrees,
check_each_transform=kp.check_each_transform,
label_mapping=kp.label_mapping or None,
label_mapping=self._remap_label_mapping_fields(kp.label_mapping, self._kp_label_map) or None,
)
else:
kp_params = kp
Expand Down
20 changes: 19 additions & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,11 @@ def _swap_keypoint_rows_by_labels(
"""
result = keypoints.copy()
label_col_start = 5 # After [x, y, z, angle, scale]
instance_id_col_idx = None
if "_kp_instance_id" in label_fields:
candidate_col_idx = label_col_start + label_fields.index("_kp_instance_id")
if candidate_col_idx < keypoints.shape[1]:
instance_id_col_idx = candidate_col_idx

# For each label field with mapping, perform row swapping
for i, label_field in enumerate(label_fields):
Expand All @@ -1096,7 +1101,7 @@ def _swap_keypoint_rows_by_labels(
if col_idx < keypoints.shape[1]:
mapping = field_mappings[label_field]
if mapping: # Only process if mapping is not empty
result = self._apply_single_field_mapping(result, col_idx, mapping)
result = self._apply_single_field_mapping(result, col_idx, mapping, instance_id_col_idx)
# Only apply mapping for the first label field that has mappings
break

Expand All @@ -1107,6 +1112,7 @@ def _apply_single_field_mapping(
keypoints: np.ndarray,
col_idx: int,
mapping: dict[int, int],
instance_id_col_idx: int | None = None,
) -> np.ndarray:
"""Apply label mapping to a single label column. Swaps rows for paired labels or updates
unpaired; used internally by _swap_keypoint_rows_by_labels.
Expand All @@ -1115,11 +1121,23 @@ def _apply_single_field_mapping(
keypoints (np.ndarray): Keypoints array
col_idx (int): Column index of the label field
mapping (dict[int, int]): Label swap mapping
instance_id_col_idx (int | None): Optional column that keeps bound instance ids. When provided,
row swaps are constrained to each instance-id group.

Returns:
np.ndarray: Keypoints array with rows swapped

"""
if instance_id_col_idx is not None:
for instance_id in np.unique(keypoints[:, instance_id_col_idx]):
instance_indices = np.where(keypoints[:, instance_id_col_idx] == instance_id)[0]
keypoints[instance_indices] = self._apply_single_field_mapping(
keypoints[instance_indices].copy(),
col_idx,
mapping,
)
return keypoints

col_data = keypoints[:, col_idx].astype(int)
processed_labels = set()

Expand Down
85 changes: 84 additions & 1 deletion tests/test_instance_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,81 @@ def test_out_of_bounds_keypoints_preserved(self) -> None:
assert len(result["instances"]) == 1
assert result["instances"][0]["keypoints"].shape[0] == 2

def test_label_mapping_uses_public_keypoint_label_field_names(self) -> None:
    """A label_mapping keyed by the public field name must still swap labels under instance binding."""
    flip = A.HorizontalFlip(p=1)
    bbox_params = A.BboxParams(coord_format="pascal_voc")
    eye_swap = {"left_eye": "right_eye", "right_eye": "left_eye"}
    keypoint_params = A.KeypointParams(
        coord_format="xy",
        label_fields=["name"],
        label_mapping={"HorizontalFlip": {"name": eye_swap}},
    )
    pipeline = A.Compose(
        [flip],
        bbox_params=bbox_params,
        keypoint_params=keypoint_params,
        instance_binding=["masks", "bboxes", "keypoints"],
    )

    image = _make_image()
    # A single instance that carries both sides of the swap pair.
    face = {
        "mask": _make_mask(region=(10, 50, 10, 50)),
        "bbox": np.array([10, 10, 50, 50], dtype=np.float32),
        "keypoints": np.array([[20.0, 20.0], [30.0, 30.0]], dtype=np.float32),
        "keypoint_labels": {"name": ["left_eye", "right_eye"]},
    }

    result = pipeline(image=image, instances=[face])

    swapped_names = result["instances"][0]["keypoint_labels"]["name"]
    assert swapped_names == ["right_eye", "left_eye"]

def test_label_mapping_does_not_swap_keypoints_between_instances(self) -> None:
    """Each instance holds only one side of the swap pair; labels flip but rows stay with their instance."""
    flip = A.HorizontalFlip(p=1)
    bbox_params = A.BboxParams(coord_format="pascal_voc")
    eye_swap = {"left_eye": "right_eye", "right_eye": "left_eye"}
    keypoint_params = A.KeypointParams(
        coord_format="xy",
        label_fields=["name"],
        label_mapping={"HorizontalFlip": {"name": eye_swap}},
    )
    pipeline = A.Compose(
        [flip],
        bbox_params=bbox_params,
        keypoint_params=keypoint_params,
        instance_binding=["masks", "bboxes", "keypoints"],
    )

    image = _make_image()
    # First instance: a lone left eye.
    left_only = {
        "mask": _make_mask(region=(10, 30, 10, 30)),
        "bbox": np.array([10, 10, 30, 30], dtype=np.float32),
        "keypoints": np.array([[20.0, 20.0]], dtype=np.float32),
        "keypoint_labels": {"name": ["left_eye"]},
    }
    # Second instance: a lone right eye.
    right_only = {
        "mask": _make_mask(region=(60, 80, 60, 80)),
        "bbox": np.array([60, 60, 80, 80], dtype=np.float32),
        "keypoints": np.array([[70.0, 70.0]], dtype=np.float32),
        "keypoint_labels": {"name": ["right_eye"]},
    }

    result = pipeline(image=image, instances=[left_only, right_only])
    out_instances = result["instances"]

    assert len(out_instances) == 2
    first, second = out_instances[0], out_instances[1]
    assert first["keypoint_labels"]["name"] == ["right_eye"]
    assert second["keypoint_labels"]["name"] == ["left_eye"]
    np.testing.assert_allclose(first["keypoints"][:, :2], np.array([[79.0, 20.0]]))
    np.testing.assert_allclose(second["keypoints"][:, :2], np.array([[29.0, 70.0]]))


class TestOverlappingLabelNames:
def test_same_label_name_bbox_and_keypoint(self) -> None:
Expand Down Expand Up @@ -497,18 +572,26 @@ def test_to_dict_excludes_hidden_fields(self) -> None:
transform = A.Compose(
[A.NoOp(p=1)],
bbox_params=A.BboxParams(coord_format="pascal_voc", label_fields=["class_id"]),
keypoint_params=A.KeypointParams(coord_format="xy"),
keypoint_params=A.KeypointParams(
coord_format="xy",
label_fields=["name"],
label_mapping={"HorizontalFlip": {"name": {"left_eye": "right_eye"}}},
),
instance_binding=["masks", "bboxes", "keypoints"],
)

d = transform.to_dict_private()
bbox_label_fields = d["bbox_params"]["label_fields"]
kp_label_fields = d["keypoint_params"]["label_fields"]
kp_label_mapping = d["keypoint_params"]["label_mapping"]

assert "_bbox_instance_id" not in bbox_label_fields
assert "_ibl_bbox_class_id" not in bbox_label_fields
assert "_kp_instance_id" not in kp_label_fields
assert "_ibl_kp_name" not in kp_label_mapping["HorizontalFlip"]
assert "class_id" in bbox_label_fields
assert "name" in kp_label_fields
assert "name" in kp_label_mapping["HorizontalFlip"]
assert d["instance_binding"] == ["bboxes", "keypoints", "masks"]

def test_to_dict_omits_binding_when_none(self) -> None:
Expand Down
Loading