3 changes: 1 addition & 2 deletions benchmarks/conftest.py
@@ -16,6 +16,5 @@ def sync(device):
def _sync():
if device == "cuda":
torch.cuda.synchronize()
return _sync


return _sync
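
For context, a minimal sketch of how the sync fixture above is typically wired up in benchmarks/conftest.py alongside a companion device fixture; the device-selection logic below is an assumption for illustration, not part of this diff:

import pytest
import torch


@pytest.fixture
def device():
    # Assumed: benchmark on CUDA when available, otherwise fall back to CPU.
    return "cuda" if torch.cuda.is_available() else "cpu"


@pytest.fixture
def sync(device):
    def _sync():
        # CUDA kernels launch asynchronously; synchronize so timings include them.
        if device == "cuda":
            torch.cuda.synchronize()

    return _sync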
55 changes: 32 additions & 23 deletions benchmarks/test_bspline_bench.py
@@ -5,11 +5,14 @@


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree,n_ctrl", [
(256, 32, 64, 3, 16),
(256, 64, 64, 3, 32),
(256, 64, 64, 5, 32),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree,n_ctrl",
[
(256, 32, 64, 3, 16),
(256, 64, 64, 3, 32),
(256, 64, 64, 5, 32),
],
)
def test_bspline_forward(benchmark, device, sync, batch, curves, dim, degree, n_ctrl):
"""Benchmark forward pass only (no gradients required)."""
torch.manual_seed(0)
@@ -29,11 +32,14 @@ def run():


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree,n_ctrl", [
(128, 32, 64, 3, 16),
(128, 64, 64, 3, 32),
(128, 64, 64, 5, 32),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree,n_ctrl",
[
(128, 32, 64, 3, 16),
(128, 64, 64, 3, 32),
(128, 64, 64, 5, 32),
],
)
def test_bspline_backward_params(benchmark, device, sync, batch, curves, dim, degree, n_ctrl):
"""Benchmark backward pass through parameters only (inputs don't require grad)."""
torch.manual_seed(0)
@@ -59,11 +65,14 @@ def run():


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree,n_ctrl", [
(128, 32, 64, 3, 16),
(128, 64, 64, 3, 32),
(128, 64, 64, 5, 32),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree,n_ctrl",
[
(128, 32, 64, 3, 16),
(128, 64, 64, 3, 32),
(128, 64, 64, 5, 32),
],
)
def test_bspline_backward_inputs(benchmark, device, sync, batch, curves, dim, degree, n_ctrl):
"""Benchmark backward pass through inputs only (parameters don't require grad)."""
torch.manual_seed(0)
@@ -91,11 +100,14 @@ def run():


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree,n_ctrl", [
(128, 32, 64, 3, 16),
(128, 64, 64, 3, 32),
(128, 64, 64, 5, 32),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree,n_ctrl",
[
(128, 32, 64, 3, 16),
(128, 64, 64, 3, 32),
(128, 64, 64, 5, 32),
],
)
def test_bspline_backward_both(benchmark, device, sync, batch, curves, dim, degree, n_ctrl):
"""Benchmark backward pass through both parameters and inputs."""
torch.manual_seed(0)
@@ -120,6 +132,3 @@ def run():
return loss

benchmark(run)
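
The elided benchmark bodies all follow the same pattern: build inputs and parameters up front, then time a closure that evaluates the curve and calls sync() so asynchronous CUDA work is counted. A minimal sketch of the forward-only case, assuming the file's existing imports; evaluate_bspline is a hypothetical stand-in for whatever torchcurves call the real bodies use, not a name taken from this diff:

def test_bspline_forward_sketch(benchmark, device, sync, batch, curves, dim, degree, n_ctrl):
    torch.manual_seed(0)
    t = torch.rand(batch, curves, device=device)  # evaluation parameters in [0, 1)
    control_points = torch.randn(curves, n_ctrl, dim, device=device)

    def run():
        out = evaluate_bspline(t, control_points, degree=degree)  # hypothetical API
        sync()  # wait for CUDA kernels before the timer stops
        return out

    benchmark(run)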



46 changes: 28 additions & 18 deletions benchmarks/test_legendre_bench.py
@@ -5,10 +5,13 @@


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree", [
(256, 32, 64, 8),
(256, 64, 64, 16),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree",
[
(256, 32, 64, 8),
(256, 64, 64, 16),
],
)
def test_legendre_forward(benchmark, device, sync, batch, curves, dim, degree):
"""Benchmark forward pass only (no gradients required)."""
torch.manual_seed(0)
@@ -28,10 +31,13 @@ def run():


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree", [
(128, 32, 64, 8),
(128, 64, 64, 16),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree",
[
(128, 32, 64, 8),
(128, 64, 64, 16),
],
)
def test_legendre_backward_params(benchmark, device, sync, batch, curves, dim, degree):
"""Benchmark backward pass through parameters only (inputs don't require grad)."""
torch.manual_seed(0)
@@ -57,10 +63,13 @@ def run():


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree", [
(128, 32, 64, 8),
(128, 64, 64, 16),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree",
[
(128, 32, 64, 8),
(128, 64, 64, 16),
],
)
def test_legendre_backward_inputs(benchmark, device, sync, batch, curves, dim, degree):
"""Benchmark backward pass through inputs only (parameters don't require grad)."""
torch.manual_seed(0)
@@ -88,10 +97,13 @@ def run():


@pytest.mark.perf
@pytest.mark.parametrize("batch,curves,dim,degree", [
(128, 32, 64, 8),
(128, 64, 64, 16),
])
@pytest.mark.parametrize(
"batch,curves,dim,degree",
[
(128, 32, 64, 8),
(128, 64, 64, 16),
],
)
def test_legendre_backward_both(benchmark, device, sync, batch, curves, dim, degree):
"""Benchmark backward pass through both parameters and inputs."""
torch.manual_seed(0)
@@ -116,5 +128,3 @@ def run():
return loss

benchmark(run)
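
The backward variants differ only in which tensors require grad and in resetting gradients inside the timed closure. A minimal sketch of the params-only case, again assuming the file's existing imports and using evaluate_legendre as a hypothetical stand-in for the real torchcurves call:

def test_legendre_backward_params_sketch(benchmark, device, sync, batch, curves, dim, degree):
    torch.manual_seed(0)
    t = torch.rand(batch, curves, device=device)  # inputs do not require grad
    coeffs = torch.randn(curves, degree + 1, dim, device=device, requires_grad=True)

    def run():
        coeffs.grad = None  # drop stale gradients so every iteration does the same work
        loss = evaluate_legendre(t, coeffs).sum()  # hypothetical API
        loss.backward()
        sync()
        return loss

    benchmark(run)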


36 changes: 16 additions & 20 deletions src/torchcurves/functional/_bspline.py
@@ -206,27 +206,23 @@ def basis_derivative_coefficients(

"""
num_samples_n, num_curves_m = spans.shape
device, dtype = spans.device, knots.dtype # Use knot's dtype for coeffs
device = spans.device  # coefficient dtype follows the knots through the division below

degrees_range = torch.arange(degree + 1, device=device).view(1, 1, -1)
knots_idx = spans.unsqueeze(-1) - degree + degrees_range # (N, M, degree+1)
degrees_range = torch.arange(-degree, 1, device=device).view(1, 1, -1)
knots_idx = spans.unsqueeze(-1) + degrees_range # (N, M, degree+1)
max_knot_idx = knots.shape[0] - 1

# Gather knot values - knots[knots_idx] will broadcast correctly
knots_k = knots[knots_idx.clamp(min=0, max=knots.shape[0] - 1)]
knots_k_plus_deg = knots[(knots_idx + degree).clamp(min=0, max=knots.shape[0] - 1)]
knots_k_plus_1 = knots[(knots_idx + 1).clamp(min=0, max=knots.shape[0] - 1)]
knots_k_plus_deg_plus_1 = knots[(knots_idx + degree + 1).clamp(min=0, max=knots.shape[0] - 1)]
knots_k = knots[knots_idx]
knots_k_plus_deg = knots[(knots_idx + degree).clamp(max=max_knot_idx)]
knots_k_plus_1 = knots[(knots_idx + 1).clamp(max=max_knot_idx)]
knots_k_plus_deg_plus_1 = knots[(knots_idx + (degree + 1)).clamp(max=max_knot_idx)]

alpha_coeffs_batch = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype)
beta_coeffs_batch = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype)
alpha_coeffs_batch = degree / (knots_k_plus_deg - knots_k)
alpha_coeffs_batch.nan_to_num_(0, 0, 0)

denom_alpha = knots_k_plus_deg - knots_k
mask_alpha = torch.abs(denom_alpha) > _BSplineFunction.ZERO_TOLERANCE
alpha_coeffs_batch[mask_alpha] = degree / denom_alpha[mask_alpha]

denom_beta = knots_k_plus_deg_plus_1 - knots_k_plus_1
mask_beta = torch.abs(denom_beta) > _BSplineFunction.ZERO_TOLERANCE
beta_coeffs_batch[mask_beta] = degree / denom_beta[mask_beta]
beta_coeffs_batch = degree / (knots_k_plus_deg_plus_1 - knots_k_plus_1)
beta_coeffs_batch.nan_to_num_(0, 0, 0)

return alpha_coeffs_batch, beta_coeffs_batch
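
The rewrite above trades the zero-initialized buffers plus boolean-mask scatter for a plain division followed by an in-place nan_to_num_, which zeroes the inf (x/0) and nan (0/0) entries produced by repeated knots. Behaviour only differs for denominators that are non-zero but below the old ZERO_TOLERANCE, which are no longer zeroed. A toy check of the exact-zero case:

import torch

degree = 3
denom = torch.tensor([0.0, 0.5, 2.0, 0.0])

# Old approach: write only where the denominator is safely non-zero.
masked = torch.zeros_like(denom)
ok = denom.abs() > 1e-12  # stand-in for _BSplineFunction.ZERO_TOLERANCE
masked[ok] = degree / denom[ok]

# New approach: divide everywhere, then map nan/inf/-inf to 0 in place.
direct = degree / denom
direct.nan_to_num_(0, 0, 0)

assert torch.equal(masked, direct)  # both give [0.0, 6.0, 1.5, 0.0]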

@@ -254,8 +250,8 @@ def compute_basis_derivatives(
# Pad (1,0) means add 1 zero to the left: [0, N0,...,N(deg-1)]
lower_pad_left = F.pad(lower_deg_basis, (1, 0))

basis_deriv = alpha * lower_pad_left - beta * lower_pad_right
return basis_deriv
# Compute the derivative without allocating a redundant temporary for beta * lower_pad_right.
return torch.addcmul(alpha * lower_pad_left, beta, lower_pad_right, value=-1)
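
torch.addcmul(input, tensor1, tensor2, value=v) computes input + v * tensor1 * tensor2, so with value=-1 it matches the explicit expression while skipping the separate beta * lower_pad_right temporary. A small equivalence check, assuming torch is imported:

alpha, beta = torch.randn(4, 3, 6), torch.randn(4, 3, 6)
lower_pad_left, lower_pad_right = torch.randn(4, 3, 6), torch.randn(4, 3, 6)

naive = alpha * lower_pad_left - beta * lower_pad_right
fused = torch.addcmul(alpha * lower_pad_left, beta, lower_pad_right, value=-1)

assert torch.allclose(naive, fused)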

@staticmethod
def forward(
@@ -280,8 +276,8 @@ def forward(
ctx.n_control_points_per_curve = n_control_points_per_curve # C

# For re-computing control_point_indices in backward
degrees_range = torch.arange(degree + 1, device=spans.device).view(1, 1, -1)
ctx.control_point_indices = spans.unsqueeze(-1) - degree + degrees_range # (N,M,degree+1)
degrees_range = torch.arange(-degree, 1, device=spans.device).view(1, 1, -1)
ctx.control_point_indices = spans.unsqueeze(-1) + degrees_range # (N,M,degree+1)

return points
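
Both here and in basis_derivative_coefficients, the offsets are now generated directly as [-degree, ..., 0] instead of [0, ..., degree] shifted by -degree; the resulting index tensors are identical. A quick check, assuming torch is imported:

degree = 3
spans = torch.tensor([[3, 5], [7, 4]])  # toy span indices, shape (N, M)

old_idx = spans.unsqueeze(-1) - degree + torch.arange(degree + 1).view(1, 1, -1)
new_idx = spans.unsqueeze(-1) + torch.arange(-degree, 1).view(1, 1, -1)

assert torch.equal(old_idx, new_idx)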
