Labels: type:feature (New feature or request)
Description
I'd like to propose adding a few popular self-supervised losses to `optax.losses._self_supervised`:

- `byol_loss`
- `simsiam_loss`
- `dino_loss`
- `barlow_twins_loss`
Motivation

These objectives are widely used in modern self-supervised representation learning pipelines, especially for vision, and having them in Optax would:

- Make it easier to prototype and compare self-supervised methods on top of JAX/Flax.
- Provide a single, well-tested implementation instead of many slightly different copies.
- Nicely complement the `ntxent` and triplet margin losses already present in `_self_supervised.py`.
Proposed API (high-level)
All functions follow the same style as existing Optax losses:

- Pure JAX functions, compatible with `jit`/`vmap`.
- `jax.typing.ArrayLike` arguments and `jax.Array` return types.
- Shape checks and `utils.check_subdtype` for float inputs.
- Docstrings with examples and references.
Rough signatures:
```python
def byol_loss(
    online_projection_1: jax.typing.ArrayLike,
    target_projection_2: jax.typing.ArrayLike,
    online_projection_2: jax.typing.ArrayLike,
    target_projection_1: jax.typing.ArrayLike,
    eps: jax.typing.ArrayLike = 1e-6,
) -> jax.Array:
    ...
```
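For illustration, here is a minimal sketch of how `byol_loss` could look, following the BYOL paper (symmetrized normalized-MSE with a stop-gradient on the target branch). This is only an assumption about the eventual implementation, not a final design:

```python
import jax
import jax.numpy as jnp


def byol_loss(online_projection_1, target_projection_2,
              online_projection_2, target_projection_1,
              eps=1e-6):
  """Sketch of the symmetrized BYOL regression loss."""
  def half(p, z):
    # No gradient flows through the target (momentum) branch.
    z = jax.lax.stop_gradient(z)
    # L2-normalize; the normalized MSE equals 2 - 2 * cosine similarity.
    p = p / (jnp.linalg.norm(p, axis=-1, keepdims=True) + eps)
    z = z / (jnp.linalg.norm(z, axis=-1, keepdims=True) + eps)
    return 2.0 - 2.0 * jnp.sum(p * z, axis=-1)

  return jnp.mean(half(online_projection_1, target_projection_2)
                  + half(online_projection_2, target_projection_1))
```

With identical online and target projections the loss is ~0; for orthogonal views each symmetrized half contributes 2.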
```python
def simsiam_loss(
    predictor_projection_1: jax.typing.ArrayLike,
    target_projection_2: jax.typing.ArrayLike,
    predictor_projection_2: jax.typing.ArrayLike,
    target_projection_1: jax.typing.ArrayLike,
    eps: jax.typing.ArrayLike = 1e-6,
) -> jax.Array:
    ...
```
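A possible sketch of `simsiam_loss`, following the SimSiam paper: a symmetrized negative cosine similarity with a stop-gradient on each target branch. Again, this is an assumption about how the function might be implemented:

```python
import jax
import jax.numpy as jnp


def simsiam_loss(predictor_projection_1, target_projection_2,
                 predictor_projection_2, target_projection_1,
                 eps=1e-6):
  """Sketch of the symmetrized SimSiam negative cosine similarity."""
  def neg_cos(p, z):
    # Stop-gradient is the key ingredient that prevents collapse.
    z = jax.lax.stop_gradient(z)
    p = p / (jnp.linalg.norm(p, axis=-1, keepdims=True) + eps)
    z = z / (jnp.linalg.norm(z, axis=-1, keepdims=True) + eps)
    return -jnp.sum(p * z, axis=-1)

  return jnp.mean(0.5 * neg_cos(predictor_projection_1, target_projection_2)
                  + 0.5 * neg_cos(predictor_projection_2, target_projection_1))
```

The loss is bounded in [-1, 1]; perfectly aligned projections yield -1.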
```python
def dino_loss(
    student_logits: jax.typing.ArrayLike,
    teacher_logits: jax.typing.ArrayLike,
    student_temperature: jax.typing.ArrayLike = 0.1,
    teacher_temperature: jax.typing.ArrayLike = 0.04,
    teacher_center: jax.typing.ArrayLike = 0.0,
) -> jax.Array:
    ...
```
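`dino_loss` might be sketched as the cross-entropy between a centered, sharpened teacher distribution and the student distribution, as in the DINO paper. The handling of `teacher_center` here (subtracted from the teacher logits before the softmax) is an assumption:

```python
import jax
import jax.numpy as jnp


def dino_loss(student_logits, teacher_logits,
              student_temperature=0.1, teacher_temperature=0.04,
              teacher_center=0.0):
  """Sketch of the DINO cross-entropy loss."""
  # Teacher: centered and sharpened with a low temperature, no gradient.
  teacher_probs = jax.nn.softmax(
      (jax.lax.stop_gradient(teacher_logits) - teacher_center)
      / teacher_temperature, axis=-1)
  # Student: plain temperature-scaled log-softmax.
  student_log_probs = jax.nn.log_softmax(
      student_logits / student_temperature, axis=-1)
  # Cross-entropy H(teacher, student), averaged over the batch.
  return jnp.mean(-jnp.sum(teacher_probs * student_log_probs, axis=-1))
```

As a sanity check, uniform logits on both sides give a loss of log(K) for K classes.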
```python
def barlow_twins_loss(
    projection_1: jax.typing.ArrayLike,
    projection_2: jax.typing.ArrayLike,
    off_diagonal_scale: jax.typing.ArrayLike = 5e-3,
    eps: jax.typing.ArrayLike = 1e-12,
) -> jax.Array:
    ...
```
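Finally, a sketch of `barlow_twins_loss` following the Barlow Twins paper: standardize each feature over the batch, form the cross-correlation matrix between the two views, and penalize its deviation from the identity. The details of the standardization are an assumption:

```python
import jax.numpy as jnp


def barlow_twins_loss(projection_1, projection_2,
                      off_diagonal_scale=5e-3, eps=1e-12):
  """Sketch of the Barlow Twins redundancy-reduction loss."""
  n = projection_1.shape[0]
  # Standardize each feature dimension over the batch.
  z1 = (projection_1 - projection_1.mean(0)) / (projection_1.std(0) + eps)
  z2 = (projection_2 - projection_2.mean(0)) / (projection_2.std(0) + eps)
  # Cross-correlation matrix between the two views, shape [d, d].
  c = (z1.T @ z2) / n
  # Push diagonal terms to 1 (invariance) and off-diagonal terms to 0
  # (redundancy reduction).
  on_diag = jnp.sum((jnp.diagonal(c) - 1.0) ** 2)
  off_diag = jnp.sum(c ** 2) - jnp.sum(jnp.diagonal(c) ** 2)
  return on_diag + off_diagonal_scale * off_diag
```

Identical views with already-decorrelated features give a loss of ~0, since the cross-correlation matrix is the identity.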
References
- BYOL – Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. https://arxiv.org/abs/2006.07733
- SimSiam – Exploring Simple Siamese Representation Learning. https://arxiv.org/abs/2011.10566
- DINO – Emerging Properties in Self-Supervised Vision Transformers. https://arxiv.org/abs/2104.14294
- Barlow Twins: Self-Supervised Learning via Redundancy Reduction. https://arxiv.org/abs/2103.03230