Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit fe86651

Browse files
px_combine prototype now smarter
reflows colors, changes adds figure titles to legend names to differentiate the source figures in the final plot, includes all annotation-like objects in the final plot.
1 parent 440774a commit fe86651

File tree

6 files changed

+233
-28
lines changed

6 files changed

+233
-28
lines changed

proto/px_combine/find_field.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import plotly.graph_objects as go
2+
from plotly import basedatatypes
3+
4+
# Search down an object's composition tree and find fields with a given name
5+
6+
7+
def find_field(obj, field, basepath="", max_path_len=80, forbidden=["parent"]):
8+
if obj is not None and len(basepath) < max_path_len:
9+
for f in dir(obj):
10+
joined_path = ".".join([basepath, f])
11+
if f == field:
12+
print(joined_path)
13+
if (
14+
(f not in forbidden)
15+
and (not f.startswith("_"))
16+
and (not f.endswith("_"))
17+
):
18+
find_field(eval("obj.%s" % (f,)), field, joined_path)
19+
20+
21+
def find_all_xy_traces():
22+
for field in dir(go):
23+
call_str = "go.%s" % (field,)
24+
call = eval(call_str)
25+
try:
26+
if issubclass(call, basedatatypes.BaseTraceType):
27+
obj = call()
28+
if "xaxis" in obj and "yaxis" in obj:
29+
yield (call_str)
30+
except TypeError:
31+
pass
32+
33+
34+
# s=go.Scatter()
35+
# s=go.Bar()
36+
# find_field(s,"color",basepath="scatter")
37+
# print()
38+
# find_field(s,"color",basepath="bar")
39+
40+
for call_str in find_all_xy_traces():
41+
call = eval(call_str)
42+
find_field(call(), "color", basepath=call_str)
43+
print()
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import test_data
2+
import numpy as np
3+
import plotly.express as px
4+
from px_combine import px_combine_secondary_y, px_simple_combine
5+
6+
df = test_data.multilayered_data(d_divs=[2, 3, 4, 2], rwalk=0.1)
7+
print(df)
8+
last_cat = df.columns[3]
9+
figs = []
10+
for px_call, last_cat_0 in zip([px.line, px.bar], list(set(df[last_cat]))):
11+
df_slice = df.loc[df[last_cat] == last_cat_0]
12+
fig = px_call(
13+
df_slice,
14+
x="x",
15+
y="y",
16+
facet_row=df.columns[0],
17+
facet_col=df.columns[1],
18+
color=df.columns[2],
19+
)
20+
fig.update_layout(title="%s=%s" % (last_cat, last_cat_0,))
21+
figs.append(fig)
22+
23+
figs[0].add_hline(y=1, row=1, col="all")
24+
figs[1].add_vline(x=10, row="all", col=2)
25+
figs[0].add_annotation(
26+
x=0.25, y=0.5, xref="x domain", yref="y domain", row=2, col=3, text="yo"
27+
)
28+
figs[1].add_annotation(
29+
x=0.5, y=0.35, xref="x domain", yref="y domain", row=1, col=2, text="budday"
30+
)
31+
figs[0].layout.barmode = "group"
32+
figs[1].layout.barmode = "relative"
33+
final_fig = px_simple_combine(*figs)
34+
for fig in figs:
35+
fig.show()
36+
final_fig.show()

proto/px_combine_proto/px_combine.py renamed to proto/px_combine/px_combine.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from plotly.subplots import make_subplots
99
import test_data
1010
import json
11-
from itertools import product
11+
from itertools import product, cycle, chain
12+
from functools import reduce
1213

1314

1415
def multi_index(*kwargs):
@@ -58,21 +59,93 @@ def px_simple_combine(fig0, fig1):
5859
raise ValueError(
5960
"Only two figures with the same subplot geometry can be combined."
6061
)
61-
if fig0.layout.annotations != fig1.layout.annotations:
62-
raise ValueError(
63-
"Only two figures created with Plotly Express with "
64-
"identical faceting can be combined."
65-
)
66-
fig = go.Figure(data=fig0.data + fig1.data, layout=fig0.layout)
62+
# reflow the colors
63+
colorway = fig0.layout.template.layout.colorway
64+
fig = make_subplots(*fig_grid_ref_shape(fig0))
65+
for r, c in multi_index(*fig_grid_ref_shape(fig)):
66+
for (tr, title), color in zip(
67+
chain(
68+
*[
69+
zip(
70+
f.select_traces(row=r + 1, col=c + 1),
71+
cycle([f.layout.title.text]),
72+
)
73+
for f in [fig0, fig1]
74+
]
75+
),
76+
cycle(colorway),
77+
):
78+
set_main_trace_color(tr, color)
79+
# use figure title to differentiate the legend items
80+
tr["name"] = "%s %s" % (title, tr["name"])
81+
# TODO: argument to group legend items?
82+
tr["legendgroup"] = None
83+
fig.add_trace(tr, row=r + 1, col=c + 1)
84+
fig.update_layout(fig0.layout)
85+
# title will be wrong
86+
fig.layout.title = None
87+
# preserve bar mode
88+
# if both figures have barmode set, the first is taken, otherwise the set one is taken
89+
# TODO argument to force barmode? or the user can just update it after
90+
fig.layout.barmode = get_first_set_barmode([fig0, fig1])
91+
# also include annotations, shapes and layout images from fig1
92+
for kw in ["annotations", "shapes", "images"]:
93+
fig.layout[kw] += fig1.layout[kw]
6794
return fig
6895

6996

97+
def select_all_traces(figs):
98+
traces = list(
99+
reduce(
100+
lambda a, b: a + b,
101+
map(lambda t: list(go.Figure.select_traces(t)), figs),
102+
[],
103+
)
104+
)
105+
return traces
106+
107+
108+
def check_trace_type_xy(tr):
109+
return ("xaxis" in tr) and ("yaxis" in tr)
110+
111+
112+
def check_figs_trace_types_xy(figs):
113+
traces = select_all_traces(figs)
114+
xy_traces = list(map(check_trace_type_xy, traces))
115+
return xy_traces
116+
117+
118+
def set_main_trace_color(tr, color):
119+
# Set the main color of a trace
120+
if type(tr) == type(go.Scatter()):
121+
if tr["mode"] == "lines":
122+
tr["line_color"] = color
123+
else:
124+
tr["marker_color"] = color
125+
elif type(tr) == type(go.Bar()):
126+
tr["marker_color"] = color
127+
128+
129+
def get_first_set_barmode(figs):
130+
barmode = None
131+
try:
132+
barmode = list(
133+
filter(lambda x: x is not None, [f.layout.barmode for f in figs])
134+
)[0]
135+
except IndexError:
136+
# if no figure sets barmode, then it is not set
137+
pass
138+
return barmode
139+
140+
70141
def px_combine_secondary_y(fig0, fig1):
71142
"""
72143
Combines two figures that have the same faceting but whose y axes refer
73144
to different data by referencing the second figure's y-data to secondary
74145
y-axes.
75146
"""
147+
if not all(check_figs_trace_types_xy([fig0, fig1])):
148+
raise ValueError('Only subplots containing "xy" trace types may be combined')
76149
grid_ref_shape = fig_grid_ref_shape(fig0)
77150
if grid_ref_shape != fig_grid_ref_shape(fig1):
78151
raise ValueError(
@@ -148,7 +221,9 @@ def secondary_y_combine_example():
148221
return fig
149222

150223

151-
fig_simple = simple_combine_example()
152-
fig_secondary_y = secondary_y_combine_example()
153-
fig_simple.show()
154-
fig_secondary_y.show()
224+
if __name__ == "__main__":
225+
fig_simple = simple_combine_example()
226+
fig_secondary_y = secondary_y_combine_example()
227+
fig_simple.show()
228+
fig_secondary_y.show()
229+
fig_secondary_y.write_json("/tmp/fig.json")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#/bin/bash
2+
PYTHONPATH=proto/px_combine python3 proto/px_combine/multilayered_data_test.py

proto/px_combine/test_data.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import numpy as np
2+
import plotly.express as px
3+
import pandas as pd
4+
from random import sample
5+
from itertools import product
6+
from functools import reduce
7+
8+
# some made up data for demos
9+
10+
11+
def words(remove_non_letters=True):
12+
with open("/usr/share/dict/british-english", "r") as fd:
13+
ws = fd.readlines()
14+
return [w.strip().replace("'s", "") for w in ws]
15+
16+
17+
def aug_tips():
18+
""" The tips data buf with "calories consumed". """
19+
tips = px.data.tips()
20+
calories = np.clip(
21+
tips["total_bill"] * 30 + np.random.standard_normal(tips.shape[0]) * 100,
22+
100,
23+
None,
24+
)
25+
tips["calories_consumed"] = calories
26+
return tips
27+
28+
29+
def take(it, N):
30+
return [next(it) for n in range(N)]
31+
32+
33+
def multilayered_data(
34+
N=20, d_divs=[2, 3, 4], rseed=np.random.RandomState(seed=2), rwalk=0.1
35+
):
36+
"""
37+
Generate data that can be faceted in len(d_divs) ways (e.g., row, col and
38+
trace color/linestyle. etc.)
39+
"""
40+
ws = words()
41+
tot_divs = np.cumprod(d_divs)[-1]
42+
sample_i = np.arange(len(ws), dtype="int")
43+
rseed.shuffle(sample_i)
44+
names = iter(ws[i] for i in sample_i[: tot_divs + len(d_divs)])
45+
x = np.arange(N)
46+
cat_div_names = []
47+
for div in d_divs:
48+
# generate category names
49+
div_names = [next(names) for _ in range(div)]
50+
cat_div_names.append(div_names)
51+
cat_names = [next(names) for _ in d_divs]
52+
dfs = []
53+
for cat_combo in product(*cat_div_names):
54+
d = dict()
55+
for cat_name, c in zip(cat_names, cat_combo):
56+
d[cat_name] = c
57+
d["x"] = x
58+
if rwalk is not None:
59+
y = np.cumsum(rseed.standard_normal(N)) * rwalk
60+
else:
61+
y = rseed.standard_normal(N)
62+
d["y"] = y
63+
dfs.append(pd.DataFrame(d))
64+
# combine all the dicts
65+
df = reduce(lambda a, b: pd.concat([a, b]), dfs, pd.DataFrame())
66+
return df

proto/px_combine_proto/test_data.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)