Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1cec7a4

Browse files
Add sizes to scatter plots (#289)
* Added a size feature. Not sure if it works yet, still need to add tests. * Added a few tests for the new sizes feature. * Forgot to add this file last commit. * Improved scatter_size.py example * Made some changes addressing Kushal's comments. * fixed an error caused by me forgetting to remove a cell in one of the notebooks * scattter plot added * updated to snake_case * updated to fixed some missing dependencies and remove unecessary code in the notebooks
1 parent 1aa83ab commit 1cec7a4

File tree

7 files changed

+331
-18
lines changed

7 files changed

+331
-18
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Scatter Plot
3+
============
4+
Example showing point size change for scatter plot.
5+
"""
6+
7+
# test_example = true
8+
import numpy as np
9+
import fastplotlib as fpl
10+
11+
# grid with 2 rows and 3 columns
12+
grid_shape = (2,1)
13+
14+
# pan-zoom controllers for each view
15+
# views are synced if they have the
16+
# same controller ID
17+
controllers = [
18+
[0],
19+
[0]
20+
]
21+
22+
23+
# you can give string names for each subplot within the gridplot
24+
names = [
25+
["scalar_size"],
26+
["array_size"]
27+
]
28+
29+
# Create the grid plot
30+
plot = fpl.GridPlot(
31+
shape=grid_shape,
32+
controllers=controllers,
33+
names=names,
34+
size=(1000, 1000)
35+
)
36+
37+
# get y_values using sin function
38+
angles = np.arange(0, 20*np.pi+0.001, np.pi / 20)
39+
y_values = 30*np.sin(angles) # 1 thousand points
40+
x_values = np.array([x for x in range(len(y_values))], dtype=np.float32)
41+
42+
data = np.column_stack([x_values, y_values])
43+
44+
plot["scalar_size"].add_scatter(data=data, sizes=5, colors="blue") # add a set of scalar sizes
45+
46+
non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5
47+
plot["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, colors="red")
48+
49+
for graph in plot:
50+
graph.auto_scale(maintain_aspect=True)
51+
52+
plot.show()
53+
54+
if __name__ == "__main__":
55+
print(__doc__)
56+
fpl.run()
Lines changed: 3 additions & 0 deletions
Loading
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"from time import time\n",
10+
"\n",
11+
"import numpy as np\n",
12+
"import fastplotlib as fpl\n",
13+
"\n",
14+
"plot = fpl.Plot()\n",
15+
"\n",
16+
"points = np.array([[-1,0,1],[-1,0,1]], dtype=np.float32).swapaxes(0,1)\n",
17+
"size_delta_scales = np.array([10, 40, 100], dtype=np.float32)\n",
18+
"min_sizes = 6\n",
19+
"\n",
20+
"def update_positions():\n",
21+
" current_time = time()\n",
22+
" newPositions = points + np.sin(((current_time / 4) % 1)*np.pi)\n",
23+
" plot.graphics[0].data = newPositions\n",
24+
" plot.camera.width = 4*np.max(newPositions[0,:])\n",
25+
" plot.camera.height = 4*np.max(newPositions[1,:])\n",
26+
"\n",
27+
"def update_sizes():\n",
28+
" current_time = time()\n",
29+
" sin_sample = np.sin(((current_time / 4) % 1)*np.pi)\n",
30+
" size_delta = sin_sample*size_delta_scales\n",
31+
" plot.graphics[0].sizes = min_sizes + size_delta\n",
32+
"\n",
33+
"points = np.array([[0,0], \n",
34+
" [1,1], \n",
35+
" [2,2]])\n",
36+
"scatter = plot.add_scatter(points, colors=[\"red\", \"green\", \"blue\"], sizes=12)\n",
37+
"plot.add_animations(update_positions, update_sizes)\n",
38+
"plot.show(autoscale=True)"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"metadata": {},
45+
"outputs": [],
46+
"source": []
47+
}
48+
],
49+
"metadata": {
50+
"kernelspec": {
51+
"display_name": "fastplotlib-dev",
52+
"language": "python",
53+
"name": "python3"
54+
},
55+
"language_info": {
56+
"codemirror_mode": {
57+
"name": "ipython",
58+
"version": 3
59+
},
60+
"file_extension": ".py",
61+
"mimetype": "text/x-python",
62+
"name": "python",
63+
"nbconvert_exporter": "python",
64+
"pygments_lexer": "ipython3",
65+
"version": "3.11.4"
66+
},
67+
"orig_nbformat": 4
68+
},
69+
"nbformat": 4,
70+
"nbformat_minor": 2
71+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"\"\"\"\n",
10+
"Scatter Plot\n",
11+
"============\n",
12+
"Example showing point size change for scatter plot.\n",
13+
"\"\"\"\n",
14+
"\n",
15+
"# test_example = true\n",
16+
"import numpy as np\n",
17+
"import fastplotlib as fpl\n",
18+
"\n",
19+
"# grid with 2 rows and 3 columns\n",
20+
"grid_shape = (2,1)\n",
21+
"\n",
22+
"# pan-zoom controllers for each view\n",
23+
"# views are synced if they have the \n",
24+
"# same controller ID\n",
25+
"controllers = [\n",
26+
" [0],\n",
27+
" [0]\n",
28+
"]\n",
29+
"\n",
30+
"\n",
31+
"# you can give string names for each subplot within the gridplot\n",
32+
"names = [\n",
33+
" [\"scalar_size\"],\n",
34+
" [\"array_size\"]\n",
35+
"]\n",
36+
"\n",
37+
"# Create the grid plot\n",
38+
"plot = fpl.GridPlot(\n",
39+
" shape=grid_shape,\n",
40+
" controllers=controllers,\n",
41+
" names=names,\n",
42+
" size=(1000, 1000)\n",
43+
")\n",
44+
"\n",
45+
"# get y_values using sin function\n",
46+
"angles = np.arange(0, 20*np.pi+0.001, np.pi / 20)\n",
47+
"y_values = 30*np.sin(angles) # 1 thousand points\n",
48+
"x_values = np.array([x for x in range(len(y_values))], dtype=np.float32)\n",
49+
"\n",
50+
"data = np.column_stack([x_values, y_values])\n",
51+
"\n",
52+
"plot[\"scalar_size\"].add_scatter(data=data, sizes=5, colors=\"blue\") # add a set of scalar sizes\n",
53+
"\n",
54+
"non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5\n",
55+
"plot[\"array_size\"].add_scatter(data=data, sizes=non_scalar_sizes, colors=\"red\")\n",
56+
"\n",
57+
"for graph in plot:\n",
58+
" graph.auto_scale(maintain_aspect=True)\n",
59+
"\n",
60+
"plot.show()"
61+
]
62+
}
63+
],
64+
"metadata": {
65+
"kernelspec": {
66+
"display_name": "fastplotlib-dev",
67+
"language": "python",
68+
"name": "python3"
69+
},
70+
"language_info": {
71+
"codemirror_mode": {
72+
"name": "ipython",
73+
"version": 3
74+
},
75+
"file_extension": ".py",
76+
"mimetype": "text/x-python",
77+
"name": "python",
78+
"nbconvert_exporter": "python",
79+
"pygments_lexer": "ipython3",
80+
"version": "3.11.4"
81+
},
82+
"orig_nbformat": 4
83+
},
84+
"nbformat": 4,
85+
"nbformat_minor": 2
86+
}

fastplotlib/graphics/_features/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature
22
from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature
3+
from ._sizes import PointsSizesFeature
34
from ._present import PresentFeature
45
from ._thickness import ThicknessFeature
56
from ._base import GraphicFeature, GraphicFeatureIndexable, FeatureEvent, to_gpu_supported_dtype
@@ -11,6 +12,7 @@
1112
"ImageCmapFeature",
1213
"HeatmapCmapFeature",
1314
"PointsDataFeature",
15+
"PointsSizesFeature",
1416
"ImageDataFeature",
1517
"HeatmapDataFeature",
1618
"PresentFeature",
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
5+
import pygfx
6+
7+
from ._base import (
8+
GraphicFeatureIndexable,
9+
cleanup_slice,
10+
FeatureEvent,
11+
to_gpu_supported_dtype,
12+
cleanup_array_slice,
13+
)
14+
15+
16+
class PointsSizesFeature(GraphicFeatureIndexable):
17+
"""
18+
Access to the vertex buffer data shown in the graphic.
19+
Supports fancy indexing if the data array also supports it.
20+
"""
21+
22+
def __init__(self, parent, sizes: Any, collection_index: int = None):
23+
sizes = self._fix_sizes(sizes, parent)
24+
super(PointsSizesFeature, self).__init__(
25+
parent, sizes, collection_index=collection_index
26+
)
27+
28+
@property
29+
def buffer(self) -> pygfx.Buffer:
30+
return self._parent.world_object.geometry.sizes
31+
32+
def __getitem__(self, item):
33+
return self.buffer.data[item]
34+
35+
def _fix_sizes(self, sizes, parent):
36+
graphic_type = parent.__class__.__name__
37+
38+
n_datapoints = parent.data().shape[0]
39+
if not isinstance(sizes, (list, tuple, np.ndarray)):
40+
sizes = np.full(n_datapoints, sizes, dtype=np.float32) # force it into a float to avoid weird gpu errors
41+
elif not isinstance(sizes, np.ndarray): # if it's not a ndarray already, make it one
42+
sizes = np.array(sizes, dtype=np.float32) # read it in as a numpy.float32
43+
if (sizes.ndim != 1) or (sizes.size != parent.data().shape[0]):
44+
raise ValueError(
45+
f"sequence of `sizes` must be 1 dimensional with "
46+
f"the same length as the number of datapoints"
47+
)
48+
49+
sizes = to_gpu_supported_dtype(sizes)
50+
51+
if any(s < 0 for s in sizes):
52+
raise ValueError("All sizes must be positive numbers greater than or equal to 0.0.")
53+
54+
if sizes.ndim == 1:
55+
if graphic_type == "ScatterGraphic":
56+
sizes = np.array(sizes)
57+
else:
58+
raise ValueError(f"Sizes must be an array of shape (n,) where n == the number of data points provided.\
59+
Received shape={sizes.shape}.")
60+
61+
return np.array(sizes)
62+
63+
def __setitem__(self, key, value):
64+
if isinstance(key, np.ndarray):
65+
# make sure 1D array of int or boolean
66+
key = cleanup_array_slice(key, self._upper_bound)
67+
68+
# put sizes into right shape if they're only indexing datapoints
69+
if isinstance(key, (slice, int, np.ndarray, np.integer)):
70+
value = self._fix_sizes(value, self._parent)
71+
# otherwise assume that they have the right shape
72+
# numpy will throw errors if it can't broadcast
73+
74+
if value.size != self.buffer.data[key].size:
75+
raise ValueError(f"{value.size} is not equal to buffer size {self.buffer.data[key].size}.\
76+
If you want to set size to a non-scalar value, make sure it's the right length!")
77+
78+
self.buffer.data[key] = value
79+
self._update_range(key)
80+
# avoid creating dicts constantly if there are no events to handle
81+
if len(self._event_handlers) > 0:
82+
self._feature_changed(key, value)
83+
84+
def _update_range(self, key):
85+
self._update_range_indices(key)
86+
87+
def _feature_changed(self, key, new_data):
88+
if key is not None:
89+
key = cleanup_slice(key, self._upper_bound)
90+
if isinstance(key, (int, np.integer)):
91+
indices = [key]
92+
elif isinstance(key, slice):
93+
indices = range(key.start, key.stop, key.step)
94+
elif isinstance(key, np.ndarray):
95+
indices = key
96+
elif key is None:
97+
indices = None
98+
99+
pick_info = {
100+
"index": indices,
101+
"collection-index": self._collection_index,
102+
"world_object": self._parent.world_object,
103+
"new_data": new_data,
104+
}
105+
106+
event_data = FeatureEvent(type="sizes", pick_info=pick_info)
107+
108+
self._call_event_handlers(event_data)

fastplotlib/graphics/scatter.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55

66
from ..utils import parse_cmap_values
77
from ._base import Graphic
8-
from ._features import PointsDataFeature, ColorFeature, CmapFeature
8+
from ._features import PointsDataFeature, ColorFeature, CmapFeature, PointsSizesFeature
99

1010

1111
class ScatterGraphic(Graphic):
12-
feature_events = ("data", "colors", "cmap", "present")
12+
feature_events = ("data", "sizes", "colors", "cmap", "present")
1313

1414
def __init__(
1515
self,
1616
data: np.ndarray,
17-
sizes: Union[int, np.ndarray, list] = 1,
17+
sizes: Union[int, float, np.ndarray, list] = 1,
1818
colors: np.ndarray = "w",
1919
alpha: float = 1.0,
2020
cmap: str = None,
@@ -86,24 +86,11 @@ def __init__(
8686
self, self.colors(), cmap_name=cmap, cmap_values=cmap_values
8787
)
8888

89-
if isinstance(sizes, int):
90-
sizes = np.full(self.data().shape[0], sizes, dtype=np.float32)
91-
elif isinstance(sizes, np.ndarray):
92-
if (sizes.ndim != 1) or (sizes.size != self.data().shape[0]):
93-
raise ValueError(
94-
f"numpy array of `sizes` must be 1 dimensional with "
95-
f"the same length as the number of datapoints"
96-
)
97-
elif isinstance(sizes, list):
98-
if len(sizes) != self.data().shape[0]:
99-
raise ValueError(
100-
"list of `sizes` must have the same length as the number of datapoints"
101-
)
102-
89+
self.sizes = PointsSizesFeature(self, sizes)
10390
super(ScatterGraphic, self).__init__(*args, **kwargs)
10491

10592
world_object = pygfx.Points(
106-
pygfx.Geometry(positions=self.data(), sizes=sizes, colors=self.colors()),
93+
pygfx.Geometry(positions=self.data(), sizes=self.sizes(), colors=self.colors()),
10794
material=pygfx.PointsMaterial(vertex_colors=True, vertex_sizes=True),
10895
)
10996

0 commit comments

Comments
 (0)