Feature request: add BYOL, SimSiam, DINO and Barlow Twins losses to optax.losses #1528

@surajyadav-research

Description

I’d like to propose adding a few popular self-supervised losses to
optax.losses._self_supervised:

  • byol_loss
  • simsiam_loss
  • dino_loss
  • barlow_twins_loss

Motivation

These objectives are widely used in modern self-supervised representation
learning pipelines, especially for vision, and having them in Optax would:

  • Make it easier to prototype and compare self-supervised methods on top of
    JAX/Flax.
  • Provide a single, well-tested implementation instead of many slightly
    different copies.
  • Nicely complement the existing ntxent and triplet margin losses already
    present in _self_supervised.py.

Proposed API (high-level)

All functions would follow the same conventions as the existing Optax losses:

  • Pure JAX functions, compatible with jit/vmap.
  • jax.typing.ArrayLike arguments and jax.Array return types.
  • Shape checks and utils.check_subdtype for float inputs.
  • Docstrings with examples and references.

Rough signatures:

```python
import jax


def byol_loss(
    online_projection_1: jax.typing.ArrayLike,
    target_projection_2: jax.typing.ArrayLike,
    online_projection_2: jax.typing.ArrayLike,
    target_projection_1: jax.typing.ArrayLike,
    eps: jax.typing.ArrayLike = 1e-6,
) -> jax.Array:
    ...
```

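One way byol_loss could be implemented, as a minimal sketch rather than a final version: it assumes BYOL's normalized mean squared error (2 - 2 * cosine similarity), symmetrized over the two views, with the "online" inputs being the predictor outputs as in the BYOL paper. Returning one loss per example (leaving the batch reduction to the caller), the explicit jax.lax.stop_gradient on the target branch, the hypothetical _l2_normalize helper, and the plain jnp.issubdtype check standing in for utils.check_subdtype are all assumptions of this sketch.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def _l2_normalize(x: jax.Array, eps: ArrayLike) -> jax.Array:
    # Hypothetical helper: unit-normalize along the feature (last) axis;
    # eps guards against division by zero for all-zero rows.
    return x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), eps)


def byol_loss(
    online_projection_1: ArrayLike,
    target_projection_2: ArrayLike,
    online_projection_2: ArrayLike,
    target_projection_1: ArrayLike,
    eps: ArrayLike = 1e-6,
) -> jax.Array:
    """Symmetrized BYOL loss (sketch); returns one value per example."""
    inputs = [jnp.asarray(x) for x in (online_projection_1, target_projection_2,
                                       online_projection_2, target_projection_1)]
    # Assumed stand-ins for Optax's shape checks and utils.check_subdtype.
    if any(x.shape != inputs[0].shape for x in inputs):
        raise ValueError("All projections must share the same [batch, dim] shape.")
    if not all(jnp.issubdtype(x.dtype, jnp.floating) for x in inputs):
        raise TypeError("Projections must be floating point.")
    online_1, target_2, online_2, target_1 = inputs

    def normalized_mse(online: jax.Array, target: jax.Array) -> jax.Array:
        online = _l2_normalize(online, eps)
        # Treat the target branch as a constant (no gradient flows into it).
        target = jax.lax.stop_gradient(_l2_normalize(target, eps))
        # For unit vectors a and b, ||a - b||^2 = 2 - 2 <a, b>.
        return 2.0 - 2.0 * jnp.sum(online * target, axis=-1)

    return normalized_mse(online_1, target_2) + normalized_mse(online_2, target_1)
```

In a training step this would typically be reduced with jnp.mean(byol_loss(...)); since all checks depend only on static shapes and dtypes, the function can be wrapped in jax.jit.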

```python
def simsiam_loss(
    predictor_projection_1: jax.typing.ArrayLike,
    target_projection_2: jax.typing.ArrayLike,
    predictor_projection_2: jax.typing.ArrayLike,
    target_projection_1: jax.typing.ArrayLike,
    eps: jax.typing.ArrayLike = 1e-6,
) -> jax.Array:
    ...
```

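A comparable sketch for simsiam_loss, assuming the symmetrized objective from the SimSiam paper, 0.5 * D(p1, z2) + 0.5 * D(p2, z1) with D(p, z) = -cos(p, stop_gradient(z)). SimSiam has no separate target network, and the paper reports that the stop-gradient is essential to avoid collapse, so here it is applied inside the loss. The _negative_cosine helper and the per-example return value are assumptions of this sketch.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def _negative_cosine(p: jax.Array, z: jax.Array, eps: ArrayLike) -> jax.Array:
    # Hypothetical helper for SimSiam's D(p, z): negative cosine similarity,
    # with a stop-gradient on the target z.
    z = jax.lax.stop_gradient(z)
    p = p / jnp.maximum(jnp.linalg.norm(p, axis=-1, keepdims=True), eps)
    z = z / jnp.maximum(jnp.linalg.norm(z, axis=-1, keepdims=True), eps)
    return -jnp.sum(p * z, axis=-1)


def simsiam_loss(
    predictor_projection_1: ArrayLike,
    target_projection_2: ArrayLike,
    predictor_projection_2: ArrayLike,
    target_projection_1: ArrayLike,
    eps: ArrayLike = 1e-6,
) -> jax.Array:
    """Symmetrized SimSiam loss (sketch), in [-1, 1]; one value per example."""
    p1, z2, p2, z1 = (jnp.asarray(x) for x in (
        predictor_projection_1, target_projection_2,
        predictor_projection_2, target_projection_1))
    if not p1.shape == z2.shape == p2.shape == z1.shape:
        raise ValueError("All inputs must share the same [batch, dim] shape.")
    if not all(jnp.issubdtype(x.dtype, jnp.floating) for x in (p1, z2, p2, z1)):
        raise TypeError("Inputs must be floating point.")
    # Each view's prediction is compared with the other view's target.
    return 0.5 * _negative_cosine(p1, z2, eps) + 0.5 * _negative_cosine(p2, z1, eps)
```

With this sketch, jax.grad taken with respect to the target arguments returns zeros, so only the predictor inputs receive gradients.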

```python
def dino_loss(
    student_logits: jax.typing.ArrayLike,
    teacher_logits: jax.typing.ArrayLike,
    student_temperature: jax.typing.ArrayLike = 0.1,
    teacher_temperature: jax.typing.ArrayLike = 0.04,
    teacher_center: jax.typing.ArrayLike = 0.0,
) -> jax.Array:
    ...
```

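A sketch of dino_loss, assuming the per-pair cross-entropy from the DINO paper: the teacher logits are centered, sharpened with the lower teacher temperature and passed through a softmax under stop_gradient, while the student uses a log-softmax at its own temperature. Summing over multi-crop pairs and updating the center are assumed to happen outside the loss; update_teacher_center is a hypothetical helper (not part of the proposed API) showing the exponential moving average, with momentum 0.9 as an assumed default.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def dino_loss(
    student_logits: ArrayLike,
    teacher_logits: ArrayLike,
    student_temperature: ArrayLike = 0.1,
    teacher_temperature: ArrayLike = 0.04,
    teacher_center: ArrayLike = 0.0,
) -> jax.Array:
    """Cross-entropy H(P_teacher, P_student) (sketch); one value per example."""
    student = jnp.asarray(student_logits)
    teacher = jnp.asarray(teacher_logits)
    if student.shape != teacher.shape:
        raise ValueError("Student and teacher logits must have the same shape.")
    if not all(jnp.issubdtype(x.dtype, jnp.floating) for x in (student, teacher)):
        raise TypeError("Logits must be floating point.")
    # Teacher: centered, sharpened by a low temperature, and not differentiated.
    teacher_probs = jax.nn.softmax(
        (teacher - teacher_center) / teacher_temperature, axis=-1)
    teacher_probs = jax.lax.stop_gradient(teacher_probs)
    student_log_probs = jax.nn.log_softmax(student / student_temperature, axis=-1)
    return -jnp.sum(teacher_probs * student_log_probs, axis=-1)


def update_teacher_center(
    center: jax.Array, teacher_logits: jax.Array, momentum: float = 0.9
) -> jax.Array:
    # Hypothetical helper: EMA of the batch-mean teacher logits, applied
    # between training steps rather than inside the loss.
    return momentum * center + (1.0 - momentum) * jnp.mean(teacher_logits, axis=0)
```

Keeping the center update outside the loss would keep dino_loss a pure function of its inputs, in line with the other Optax losses.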

```python
def barlow_twins_loss(
    projection_1: jax.typing.ArrayLike,
    projection_2: jax.typing.ArrayLike,
    off_diagonal_scale: jax.typing.ArrayLike = 5e-3,
    eps: jax.typing.ArrayLike = 1e-12,
) -> jax.Array:
    ...
```

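A sketch of barlow_twins_loss, assuming the formulation from the Barlow Twins paper: each projection is standardized along the batch axis, the cross-correlation matrix C between the two views is formed, and the loss is sum_i (1 - C_ii)^2 + lambda * sum_{i != j} C_ij^2 with lambda = off_diagonal_scale. Unlike the other three sketches it returns a single scalar, because C is a batch-level statistic; placing eps inside the square root is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def barlow_twins_loss(
    projection_1: ArrayLike,
    projection_2: ArrayLike,
    off_diagonal_scale: ArrayLike = 5e-3,
    eps: ArrayLike = 1e-12,
) -> jax.Array:
    """Barlow Twins redundancy-reduction loss (sketch); a batch-level scalar."""
    z1 = jnp.asarray(projection_1)
    z2 = jnp.asarray(projection_2)
    if z1.ndim != 2 or z1.shape != z2.shape:
        raise ValueError("Projections must both have shape [batch, dim].")
    if not all(jnp.issubdtype(x.dtype, jnp.floating) for x in (z1, z2)):
        raise TypeError("Projections must be floating point.")
    batch_size = z1.shape[0]

    def standardize(z: jax.Array) -> jax.Array:
        # Zero mean and unit variance per feature, computed over the batch.
        return (z - z.mean(axis=0)) / jnp.sqrt(z.var(axis=0) + eps)

    # Cross-correlation matrix C with shape [dim, dim].
    c = standardize(z1).T @ standardize(z2) / batch_size
    diag = jnp.diagonal(c)
    on_diagonal = jnp.sum((1.0 - diag) ** 2)  # invariance term
    off_diagonal = jnp.sum(c ** 2) - jnp.sum(diag ** 2)  # redundancy term
    return on_diagonal + off_diagonal_scale * off_diagonal
```

Because the batch mean and variance enter C, the value of this loss depends on how the batch is formed, unlike the per-example losses above.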