FEA Add array API support for LogisticRegression with LBFGS #32644
Conversation
ogrisel
left a comment
Could you also please profile a run on mps or cuda using py-spy?
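For reference, a typical py-spy invocation that records such an SVG flame graph (assuming the benchmark below is saved as a hypothetical bench_logreg.py) would be:

py-spy record -o profile_cuda.svg -- python bench_logreg.py

The --native flag can be added to also sample native (C/Cython) frames, which helps attribute time spent outside the Python interpreter.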
Some Benchmarks

from time import time

import numpy as np
import torch as xp
from tqdm import tqdm

from sklearn import config_context
from sklearn.linear_model import LogisticRegression

n_samples, n_features, n_classes = 1000000, 300, 20
device = "cuda"
n_iter = 10

X_np = np.random.rand(n_samples, n_features)
# Note: only 10 classes are generated here even though n_classes = 20 is
# defined above.
y_np = np.random.randint(0, 10, n_samples)

numpy_fit_times = []
numpy_predict_times = []
for _ in tqdm(range(n_iter), desc="Numpy"):
    lr = LogisticRegression(C=0.8, solver="lbfgs", max_iter=200)
    start = time()
    lr.fit(X_np, y_np)
    numpy_fit_times.append(round(time() - start, 3))
    start = time()
    pred = lr.predict_proba(X_np)
    numpy_predict_times.append(round(time() - start, 3))

avg_numpy_fit = round(sum(numpy_fit_times) / n_iter, 3)
avg_numpy_predict = round(sum(numpy_predict_times) / n_iter, 3)

torch_fit_times = []
torch_predict_times = []
X_xp = xp.asarray(X_np, device=device)
y_xp = xp.asarray(y_np, device=device)
for _ in tqdm(range(n_iter), desc=f"Torch {device}"):
    with config_context(array_api_dispatch=True):
        lr = LogisticRegression(C=0.8, solver="lbfgs", max_iter=200)
        start = time()
        lr.fit(X_xp, y_xp)
        torch_fit_times.append(round(time() - start, 3))
        start = time()
        pred = lr.predict_proba(X_xp)
        # Reading back one value forces the asynchronous GPU kernels to
        # finish before the clock is stopped.
        first = float(pred[0, 0])
        torch_predict_times.append(round(time() - start, 3))

avg_torch_fit = round(sum(torch_fit_times) / n_iter, 3)
avg_torch_predict = round(sum(torch_predict_times) / n_iter, 3)

print(f"Average fit time numpy: {avg_numpy_fit}")
print(f"Average fit time torch {device}: {avg_torch_fit}")
print(f"Torch {device} fit speedup: {round(avg_numpy_fit / avg_torch_fit, 2)}X")
print(f"Average predict time numpy: {avg_numpy_predict}")
print(f"Average predict time torch {device}: {avg_torch_predict}")
print(f"Torch {device} predict speedup: {round(avg_numpy_predict / avg_torch_predict, 2)}X")
Average fit time numpy: 23.526
Average predict time numpy: 1.133
[torch timings not captured in this export]
It's nice to get a speed-up with CUDA despite the conversion of the raw predictions and pointwise gradient values of the loss at each iteration. Can you post the SVG of the py-spy profiling results, both for the pytorch/CUDA and the numpy/CPU runs? If the conversion of the raw predictions / pointwise gradients is significant, I think we should try to implement an alternative to the Cython gradient function using the array API to skip those conversions directly as part of this PR.
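To make the last point concrete, here is a minimal sketch (not the PR's actual implementation) of a binary log-loss value-and-gradient written with generic array API operations, so the raw predictions and pointwise gradients stay on the device instead of round-tripping through NumPy for the Cython helper; get_namespace is scikit-learn's private array API helper:

from sklearn.utils._array_api import get_namespace

def binary_logloss_and_grad(coef, X, y, alpha):
    # z = X @ coef; the per-sample loss is log(1 + exp(z)) - y * z,
    # computed via logaddexp for numerical stability at large |z|.
    xp, _ = get_namespace(X)
    z = X @ coef
    y = xp.astype(y, X.dtype)
    loss = xp.mean(xp.logaddexp(xp.zeros_like(z), z) - y * z)
    p = 1.0 / (1.0 + xp.exp(-z))  # sigmoid(z), the raw predictions
    grad = (X.T @ (p - y)) / X.shape[0]
    # L2 penalty, up to scikit-learn's exact scaling conventions (alpha ~ 1/C).
    loss = loss + 0.5 * alpha * xp.sum(coef * coef)
    grad = grad + alpha * coef
    return loss, grad

A function of this shape could back the LBFGS objective so that only the coefficient vector and the scalar loss, not the per-sample arrays, would cross the device boundary at each iteration.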
ogrisel
left a comment
Some more feedback.
ogrisel
left a comment
Another pass of feedback:
For the record, I observe a 2x speed-up when using PyTorch/MPS vs NumPy/CPU (with OpenBLAS) on an Apple M4 laptop using a 50-class classification problem. I see no significant speed-up for binary classification. EDIT: I increased the dataset size and I now also observe a 2x speed-up when using PyTorch/MPS vs NumPy/CPU (with OpenBLAS) for binary classification.
While evaluating locally, I found a bug for a specific dataset size:

import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["SCIPY_ARRAY_API"] = "1"

import torch
import numpy as np

from sklearn import set_config
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import minmax_scale

# Uncomment when using the new Accelerate BLAS on macOS.
# import warnings
# warnings.filterwarnings("ignore", category=RuntimeWarning)

set_config(array_api_dispatch=True)

X, y = make_classification(
    n_samples=int(1e6), n_classes=2, n_features=100, n_informative=90, random_state=0
)
X = X.astype("float32")
# X = minmax_scale(X)

X_mps = torch.from_numpy(X).to("mps")
y_mps = torch.from_numpy(y).to("mps")

clf = LogisticRegression(max_iter=1000)
clf.fit(X, y).n_iter_, clf.score(X, y)

outputs:

(array([37], dtype=int32), 0.851278)

while:

clf.fit(X_mps, y_mps).n_iter_, clf.score(X_mps, y_mps)

outputs a bad model without any error message nor warning:

(tensor([1], device='mps:0', dtype=torch.int32), 0.4999749958515167)

@OmarManzoor can you reproduce? The fact that the model converges to a chance-level prediction function after only 1 iteration without any error message sounds like a bug to me. Note that the problem goes away when: […]
Since the problem happens at the first iteration, we could debug by printing (parts of) the sample-wise and parameter-wise gradient vectors for each run.
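For instance, a throwaway comparison along those lines (the helper names here are made up for illustration, and it requires a machine with MPS) could print the initial gradient computed on NumPy/CPU and torch/MPS side by side:

import numpy as np
import torch

def grad_numpy(X, y, w):
    # Gradient of the mean binary log-loss at coefficients w.
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    return X.T @ (p - y) / X.shape[0]

def grad_torch(X, y, w):
    p = torch.sigmoid(X @ w)
    return X.T @ (p - y) / X.shape[0]

rng = np.random.default_rng(0)
X_np = rng.standard_normal((1000, 10)).astype("float32")
y_np = rng.integers(0, 2, 1000).astype("float32")

print(grad_numpy(X_np, y_np, np.zeros(10, dtype="float32"))[:5])
print(grad_torch(
    torch.from_numpy(X_np).to("mps"),
    torch.from_numpy(y_np).to("mps"),
    torch.zeros(10, device="mps"),
)[:5].cpu())

If the two already disagree at w = 0, the bug is in the loss/gradient computation rather than in LBFGS itself.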
I updated the […]. This is the output of the above script which was causing issues:

[output not captured in this export]
I started to take a deeper look at the actual changes, and here is a first pass of feedback.
ogrisel
left a comment
Some more feedback:
I've not yet looked at the diff, but have run the benchmark script from #32644 (comment) with the following settings: […]

The results are: […]

I've not systematically tried different shapes/sizes, but from trying smaller […]. The GPU is an A6000, CPU details in the fold-out.

Details: [not captured in this export]
@betatim This seems to be a lot slower than what I observed on a Colab T4 GPU, maybe because of the higher n_samples.
It's also possible that the CPUs on the T4 instance of Colab are particularly slow and therefore inflate the impact of using CUDA on that machine.
sklearn/utils/validation.py (Outdated)

dtype_param = {}
# We need to force `float32` with the `mps` device.
if str(device).startswith("mps"):
    dtype_param["dtype"] = max_float_type
sample_weight = xp.asarray(sample_weight, device=device, **dtype_param)
The following should work instead:
-dtype_param = {}
-# We need to force `float32` with the `mps` device.
-if str(device).startswith("mps"):
-    dtype_param["dtype"] = max_float_type
-sample_weight = xp.asarray(sample_weight, device=device, **dtype_param)
+if isinstance(dtype, (list, tuple)) and len(dtype) == 1:
+    # If there is only one possible dtype, use it directly to avoid
+    # multiple conversions or attempts to convert to a dtype that is
+    # not supported by the device.
+    sample_weight = xp.asarray(sample_weight, device=device, dtype=dtype[0])
+else:
+    # Let check_array choose the best float dtype from the list.
+    sample_weight = xp.asarray(sample_weight, device=device)
This won't work in the case where sample_weight is a numpy array and dtype is just a single xp dtype like torch.float32.
I don't understand how this can happen. Why would we pass a torch dtype and a numpy input using the public API?
When fitting LogisticRegression with X a torch tensor with a float32 dtype and sample_weight a float64 numpy array?
But then, this is not a problem, right?
>>> import numpy as np
>>> import array_api_compat.torch as xp_torch
>>> xp_torch.asarray(np.ones(3, dtype=np.float64), dtype=xp_torch.float32, device="mps")
tensor([1., 1., 1.], device='mps:0')
Yes, but we are in this case:
sample_weight = xp.asarray(sample_weight, device=device)

because dtype=X.dtype (here torch.float32), due to this in LogisticRegression:
if sample_weight is not None or class_weight is not None:
    sample_weight = _check_sample_weight(
        sample_weight, X, dtype=X.dtype, copy=True, ensure_same_device=True
    )
I checked the CPU on Colab. I think the CPU might be the difference in these observed timings:

processor : 0
vendor_id : GenuineIntel
cpu family : 6
model : 85
model name : Intel(R) Xeon(R) CPU @ 2.00GHz
stepping : 3
microcode : 0xffffffff
cpu MHz : 2000.180
cache size : 39424 KB
physical id : 0
siblings : 2
core id : 0
cpu cores : 1
apicid : 0
initial apicid : 0
fpu : yes
fpu_exception : yes
cpuid level : 13
wp : yes
flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa mmio_stale_data retbleed bhi its
bogomips : 4000.36
clflush size : 64
cache_alignment : 64
address sizes : 46 bits physical, 48 bits virtual
power management:
(processor 1 is identical to processor 0 apart from apicid / initial apicid)
With this CPU:

Average fit time numpy: 22.72
Average predict time numpy: 1.082
Reference Issues/PRs
Towards: #32611