Add SVM classifier based separability scoring #10
base: main
@duopeng we have automatic formatting with pre-commit. You can install by just running |
Thanks for adding this!
grassp/tools/scoring.py
Outdated
Parameters
----------
data : DataFrame or AnnData
The docstrings should not include the data types (: str, etc.). Put them on the arguments themselves instead, e.g.:
label_col: str = "consensus_graph_annnotation",
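A minimal sketch of that convention, for illustration only: the function name `separability_sketch` and its placeholder body are hypothetical, and only `label_col` with its default is taken from the diff above.

```python
from typing import Optional


def separability_sketch(
    data,
    label_col: str = "consensus_graph_annnotation",
    use_rep: Optional[str] = None,
) -> None:
    """Score pairwise label separability (hypothetical stub for illustration).

    Parameters
    ----------
    data
        DataFrame or AnnData holding the coordinates and labels.
    label_col
        Name of the column that holds the labels.
    use_rep
        Use the indicated representation. 'X' or any key for .obsm is valid.
    """
    # Placeholder body: the types live in the signature, not in the docstring.
    return None
```

The types stay discoverable through the signature (and `__annotations__`), so the docstring only has to describe meaning.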
grassp/tools/scoring.py
Outdated
auc_clustermap = sns.clustermap(auc_mat,
    square=True,
    annot=True,
    fmt=".2f",
    cmap="rocket",
    vmin=0.5,
    vmax=1,
    cbar_kws=dict(label=f"ROC-AUC ({auc_model.upper()})"),
    figsize=(heatmap_size[0], heatmap_size[1]))
auc_clustermap.fig.suptitle("Label separability\nPair-wise classifier-AUC")
auc_clustermap.ax_heatmap.set_xticklabels(
    auc_clustermap.ax_heatmap.get_xticklabels(), rotation=45, ha='right')
figures['auc_fig'] = auc_clustermap
Plotting should be handled separately from the calculations, in grassp.pl.
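A rough sketch of what such a split could look like. Both function names are hypothetical, and the score is a deliberately simple stand-in (Euclidean distance between class centroids), not the SVM-based ROC-AUC from this PR. The point is only that the calculation returns plain data and the plotting function consumes it.

```python
import numpy as np


def pairwise_centroid_distance(X, labels):
    """Calculation step: return (classes, matrix) with no plotting side effects.

    Stand-in score (distance between class centroids), NOT the PR's ROC-AUC.
    """
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    # One centroid per class, shape (n_classes, n_features)
    centroids = np.stack([X[labels == c].mean(axis=0) for c in classes])
    diff = centroids[:, None, :] - centroids[None, :, :]
    return classes, np.sqrt((diff ** 2).sum(axis=-1))


def plot_score_heatmap(classes, mat, ax=None):
    """Plotting step: only draws a precomputed matrix."""
    # Imported lazily so the calculation above has no plotting dependency.
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()
    im = ax.imshow(mat, cmap="viridis")
    ax.set_xticks(range(len(classes)))
    ax.set_xticklabels(classes, rotation=45, ha="right")
    ax.set_yticks(range(len(classes)))
    ax.set_yticklabels(classes)
    ax.figure.colorbar(im, ax=ax, label="score")
    return ax
```

With this layout the scoring function can be tested and reused without a plotting backend, and the figure options stay out of its signature.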
grassp/tools/scoring.py
Outdated
# Drop rows with missing coords or labels
df = df.dropna(subset=[label_col, *coord_cols])

X_all = df[list(coord_cols)].values
You're converting a np.array (the .obsm or .X) into a DataFrame and then back to a np.array. This seems inefficient.
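One way to avoid the round trip is to filter the array directly with a boolean mask. This is a sketch only: the helper name `drop_missing_rows` is hypothetical, and it assumes missing labels are stored as `None` or NaN.

```python
import numpy as np


def drop_missing_rows(X, labels):
    """Drop rows with missing coordinates or labels without building a DataFrame."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels, dtype=object)
    # A row is usable only if none of its coordinates is NaN ...
    has_coords = ~np.isnan(X).any(axis=1)
    # ... and its label is neither None nor NaN (NaN != NaN).
    has_label = np.array([lab is not None and lab == lab for lab in labels], dtype=bool)
    keep = has_coords & has_label
    return X[keep], labels[keep]
```

The same mask can be applied to any other per-row array, so the coordinates never leave NumPy.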
grassp/tools/scoring.py
Outdated
if isinstance(data, ad.AnnData):
    assert label_col in data.obs.columns, f"label_col {label_col} not in data.obs.columns"
    X_all = sc.tools._utils._choose_representation(data, use_rep=use_rep, n_pcs=n_pcs)
    df = pd.DataFrame(X_all)
You're still converting X_all into a DataFrame and then back into an array (line 316).
grassp/tools/scoring.py
Outdated
if DataFrame, then use column name as label
Defaults to "consensus_graph_annnotation"
use_rep : str, optional
coordinates (X in the classifier)
"X in the classifier" might not mean much to users. Scanpy has: "Use the indicated representation. 'X' or any key for .obsm is valid."
grassp/tools/scoring.py
Outdated
Defaults to "consensus_graph_annnotation"
use_rep : str, optional
coordinates (X in the classifier)
if AnnData, use .obsm[use_rep] if use_rep is a *str*, and .var[use_rep] if *list*
We don't need to support the .var[use_rep] case. Users can simply subset beforehand!
grassp/tools/scoring.py
Outdated
np.fill_diagonal(auc_mat.values, 0.5)

if inplace:
    data.uns[f"separability ({label_col})"] = {
Prefer no spaces or special characters in names: separability_{label_col}. Don't forget to also fix this in plotting.
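A small sketch of keeping the key in one place so the scoring and plotting code cannot drift apart. The constant and helper names are hypothetical, and a plain dict stands in for `data.uns`.

```python
SEPARABILITY_PREFIX = "separability_"  # hypothetical shared constant


def separability_key(label_col: str) -> str:
    """Build the .uns key for a label column, with no spaces or parentheses added."""
    return f"{SEPARABILITY_PREFIX}{label_col}"


# A plain dict stands in for data.uns in this sketch.
uns = {}
uns[separability_key("consensus_graph_annnotation")] = {"auc": None}
```

Importing the same helper in the plotting module means a later rename only has to happen once.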
grassp/plotting/heatmaps.py
Outdated
def sep_auc_heatmap(
    data: np.ndarray | pd.DataFrame,
This should also be able to take an AnnData object and look in .uns for entries with "separability_".
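The lookup could be sketched roughly as follows. The helper name is hypothetical, and it only assumes that `.uns` behaves like a mapping with string keys, so a plain dict is used in the usage example.

```python
from typing import Any, Dict, Mapping

PREFIX = "separability_"  # assumed prefix, following the naming suggested in this review


def find_separability_entries(uns: Mapping[str, Any]) -> Dict[str, Any]:
    """Return {label_col: entry} for every key in uns that starts with the prefix."""
    return {
        key[len(PREFIX):]: value
        for key, value in uns.items()
        if isinstance(key, str) and key.startswith(PREFIX)
    }


# Usage with a plain dict standing in for adata.uns
example_uns = {"separability_leiden": {"auc": 0.9}, "neighbors": {}}
found = find_separability_entries(example_uns)
```

The plotting function could then pick the only entry when one exists, or ask the caller to name a label column when several are present.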
…satility to plotting function
I tested the code after Marika's cleanup, and it works nicely! We can merge this branch with main!