diff --git a/optional-requirements.txt b/optional-requirements.txt index ea95c4fc9fb..2f609e087d0 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -19,3 +19,6 @@ ipython[all] ## pandas deps for some matplotlib functionality ## pandas + +## scipy deps for some FigureFactory functions ## +scipy diff --git a/plotly/tests/test_core/test_tools/test_figure_factory.py b/plotly/tests/test_core/test_tools/test_figure_factory.py index 658a48cecfc..ff90b0d2cea 100644 --- a/plotly/tests/test_core/test_tools/test_figure_factory.py +++ b/plotly/tests/test_core/test_tools/test_figure_factory.py @@ -686,3 +686,22 @@ def test_datetime_candlestick(self): self.assertEqual(candle, exp_candle) + +# class TestDistplot(TestCase): + +# def test_scipy_import_error(self): + +# # make sure Import Error is raised when _scipy_imported = False + +# hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, +# 3.5, 4.1, 4.4, 4.5, 4.5, +# 5.0, 5.0, 5.2, 5.5, 5.5, +# 5.5, 5.5, 5.5, 6.1, 7.0]] + +# group_labels = ['distplot example'] + +# self.assertRaisesRegexp(ImportError, +# "FigureFactory.create_distplot requires scipy", +# tls.FigureFactory.create_distplot, +# hist_data, group_labels) + diff --git a/plotly/tests/test_optional/test_opt_tracefactory.py b/plotly/tests/test_optional/test_opt_tracefactory.py index 2c87c20c232..b6de29eb8e0 100644 --- a/plotly/tests/test_optional/test_opt_tracefactory.py +++ b/plotly/tests/test_optional/test_opt_tracefactory.py @@ -9,6 +9,146 @@ import numpy as np +class TestDistplot(TestCase): + + def test_wrong_curve_type(self): + + # check: PlotlyError (and specific message) is raised if curve_type is + # not 'kde' or 'normal' + + kwargs = {'hist_data': [[1, 2, 3]], 'group_labels': ['group'], + 'curve_type': 'curve'} + self.assertRaisesRegexp(PlotlyError, "curve_type must be defined as " + "'kde' or 'normal'", + tls.FigureFactory.create_distplot, **kwargs) + + def test_wrong_histdata_format(self): + + # check: PlotlyError if hist_data is not a list of lists or list of + # np.ndarrays (if hist_data is entered as just a list the function + # will fail) + + kwargs = {'hist_data': [1, 2, 3], 'group_labels': ['group']} + self.assertRaises(PlotlyError, tls.FigureFactory.create_distplot, + **kwargs) + + def test_unequal_data_label_length(self): + kwargs = {'hist_data': [[1, 2]], 'group_labels': ['group', 'group2']} + self.assertRaises(PlotlyError, tls.FigureFactory.create_distplot, + **kwargs) + + kwargs = {'hist_data': [[1, 2], [1, 2, 3]], 'group_labels': ['group']} + self.assertRaises(PlotlyError, tls.FigureFactory.create_distplot, + **kwargs) + + def test_simple_distplot(self): + + # we should be able to create a single distplot with a simple dataset + # and default kwargs + + dp = tls.FigureFactory.create_distplot(hist_data=[[1, 2, 2, 3]], + group_labels=['distplot']) + expected_dp_layout = {'barmode': 'overlay', + 'hovermode': 'closest', + 'legend': {'traceorder': 'reversed'}, + 'xaxis1': {'anchor': 'y2', 'domain': [0.0, 1.0], 'zeroline': False}, + 'yaxis1': {'anchor': 'free', 'domain': [0.35, 1], 'position': 0.0}, + 'yaxis2': {'anchor': 'x1', + 'domain': [0, 0.25], + 'dtick': 1, + 'showticklabels': False}} + self.assertEqual(dp['layout'], expected_dp_layout) + + expected_dp_data_hist = {'autobinx': False, + 'histnorm': 'probability', + 'legendgroup': 'distplot', + 'marker': {'color': 'rgb(31, 119, 180)'}, + 'name': 'distplot', + 'opacity': 0.7, + 'type': 'histogram', + 'x': [1, 2, 2, 3], + 'xaxis': 'x1', + 'xbins': {'end': 3.0, 'size': 1.0, 'start': 1.0}, + 'yaxis': 'y1'} + self.assertEqual(dp['data'][0], expected_dp_data_hist) + + expected_dp_data_rug = {'legendgroup': 'distplot', + 'marker': {'color': 'rgb(31, 119, 180)', + 'symbol': 'line-ns-open'}, + 'mode': 'markers', + 'name': 'distplot', + 'showlegend': False, + 'text': None, + 'type': 'scatter', + 'x': [1, 2, 2, 3], + 'xaxis': 'x1', + 'y': ['distplot', 'distplot', + 'distplot', 'distplot'], + 'yaxis': 'y2'} + self.assertEqual(dp['data'][2], expected_dp_data_rug) + + def test_distplot_more_args(self): + + # we should be able to create a distplot with 2 datasets no + # rugplot, defined bin_size, and added title + + hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, + -0.9, -0.07, 1.95, 0.9, -0.2, + -0.5, 0.3, 0.4, -0.37, 0.6] + hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, + 1.0, 0.8, 1.7, 0.5, 0.8, + -0.3, 1.2, 0.56, 0.3, 2.2] + + hist_data = [hist1_x] + [hist2_x] + group_labels = ['2012', '2013'] + + dp = tls.FigureFactory.create_distplot(hist_data, group_labels, + show_rug=False, bin_size=.2) + dp['layout'].update(title='Dist Plot') + + expected_dp_layout = {'barmode': 'overlay', + 'hovermode': 'closest', + 'legend': {'traceorder': 'reversed'}, + 'title': 'Dist Plot', + 'xaxis1': {'anchor': 'y2', 'domain': [0.0, 1.0], + 'zeroline': False}, + 'yaxis1': {'anchor': 'free', 'domain': [0.0, 1], + 'position': 0.0}} + self.assertEqual(dp['layout'], expected_dp_layout) + + expected_dp_data_hist_1 = {'autobinx': False, + 'histnorm': 'probability', + 'legendgroup': '2012', + 'marker': {'color': 'rgb(31, 119, 180)'}, + 'name': '2012', + 'opacity': 0.7, + 'type': 'histogram', + 'x': [0.8, 1.2, 0.2, 0.6, 1.6, -0.9, -0.07, + 1.95, 0.9, -0.2, -0.5, 0.3, 0.4, + -0.37, 0.6], + 'xaxis': 'x1', + 'xbins': {'end': 1.95, 'size': 0.2, + 'start': -0.9}, + 'yaxis': 'y1'} + self.assertEqual(dp['data'][0], expected_dp_data_hist_1) + + expected_dp_data_hist_2 = {'autobinx': False, + 'histnorm': 'probability', + 'legendgroup': '2013', + 'marker': {'color': 'rgb(255, 127, 14)'}, + 'name': '2013', + 'opacity': 0.7, + 'type': 'histogram', + 'x': [0.8, 1.5, 1.5, 0.6, 0.59, 1.0, 0.8, + 1.7, 0.5, 0.8, -0.3, 1.2, 0.56, 0.3, + 2.2], + 'xaxis': 'x1', + 'xbins': {'end': 2.2, 'size': 0.2, + 'start': -0.3}, + 'yaxis': 'y1'} + self.assertEqual(dp['data'][1], expected_dp_data_hist_2) + + class TestStreamline(TestCase): def test_wrong_arrow_scale(self): diff --git a/plotly/tools.py b/plotly/tools.py index dc51a0a9ac5..bb175bba7ec 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -51,6 +51,13 @@ def warning_on_one_line(message, category, filename, lineno, except ImportError: _numpy_imported = False +try: + import scipy + import scipy.stats + _scipy_imported = True +except ImportError: + _scipy_imported = False + PLOTLY_DIR = os.path.join(os.path.expanduser("~"), ".plotly") CREDENTIALS_FILE = os.path.join(PLOTLY_DIR, ".credentials") CONFIG_FILE = os.path.join(PLOTLY_DIR, ".config") @@ -1460,7 +1467,7 @@ class FigureFactory(object): """ @staticmethod - def validate_equal_length(*args): + def _validate_equal_length(*args): """ Validates that data lists or ndarrays are the same length. @@ -1472,7 +1479,7 @@ def validate_equal_length(*args): "should be the same length.") @staticmethod - def validate_ohlc(open, high, low, close, direction, **kwargs): + def _validate_ohlc(open, high, low, close, direction, **kwargs): """ ohlc and candlestick specific validations @@ -1515,7 +1522,44 @@ def validate_ohlc(open, high, low, close, direction, **kwargs): "'both'") @staticmethod - def validate_positive_scalars(**kwargs): + def _validate_distplot(hist_data, curve_type): + """ + distplot specific validations + + :raises: (PlotlyError) If hist_data is not a list of lists + :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or + 'normal'). + """ + try: + import pandas as pd + _pandas_imported = True + except ImportError: + _pandas_imported = False + + hist_data_types = (list,) + if _numpy_imported: + hist_data_types += (np.ndarray,) + if _pandas_imported: + hist_data_types += (pd.core.series.Series,) + + if not isinstance(hist_data[0], hist_data_types): + raise exceptions.PlotlyError("Oops, this function was written " + "to handle multiple datasets, if " + "you want to plot just one, make " + "sure your hist_data variable is " + "still a list of lists, i.e. x = " + "[1, 2, 3] -> x = [[1, 2, 3]]") + + curve_opts = ('kde', 'normal') + if curve_type not in curve_opts: + raise exceptions.PlotlyError("curve_type must be defined as " + "'kde' or 'normal'") + + if _scipy_imported is False: + raise ImportError("FigureFactory.create_distplot requires scipy") + + @staticmethod + def _validate_positive_scalars(**kwargs): """ Validates that all values given in key/val pairs are positive. @@ -1532,7 +1576,7 @@ def validate_positive_scalars(**kwargs): .format(key, val)) @staticmethod - def validate_streamline(x, y): + def _validate_streamline(x, y): """ streamline specific validations @@ -1558,7 +1602,7 @@ def validate_streamline(x, y): "evenly spaced array") @staticmethod - def flatten(array): + def _flatten(array): """ Uses list comprehension to flatten array @@ -1656,9 +1700,9 @@ def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3, py.plot(fig, filename='quiver') ``` """ - FigureFactory.validate_equal_length(x, y, u, v) - FigureFactory.validate_positive_scalars(arrow_scale=arrow_scale, - scale=scale) + FigureFactory._validate_equal_length(x, y, u, v) + FigureFactory._validate_positive_scalars(arrow_scale=arrow_scale, + scale=scale) barb_x, barb_y = _Quiver(x, y, u, v, scale, arrow_scale, angle).get_barbs() @@ -1757,10 +1801,10 @@ def create_streamline(x, y, u, v, py.plot(fig, filename='streamline') ``` """ - FigureFactory.validate_equal_length(x, y) - FigureFactory.validate_equal_length(u, v) - FigureFactory.validate_streamline(x, y) - FigureFactory.validate_positive_scalars(density=density, + FigureFactory._validate_equal_length(x, y) + FigureFactory._validate_equal_length(u, v) + FigureFactory._validate_streamline(x, y) + FigureFactory._validate_positive_scalars(density=density, arrow_scale=arrow_scale) streamline_x, streamline_y = _Streamline(x, y, u, v, @@ -1984,10 +2028,10 @@ def create_ohlc(open, high, low, close, ``` """ if dates is not None: - FigureFactory.validate_equal_length(open, high, low, close, dates) + FigureFactory._validate_equal_length(open, high, low, close, dates) else: - FigureFactory.validate_equal_length(open, high, low, close) - FigureFactory.validate_ohlc(open, high, low, close, direction, + FigureFactory._validate_equal_length(open, high, low, close) + FigureFactory._validate_ohlc(open, high, low, close, direction, **kwargs) if direction is 'increasing': @@ -2214,10 +2258,10 @@ def create_candlestick(open, high, low, close, ``` """ if dates is not None: - FigureFactory.validate_equal_length(open, high, low, close, dates) + FigureFactory._validate_equal_length(open, high, low, close, dates) else: - FigureFactory.validate_equal_length(open, high, low, close) - FigureFactory.validate_ohlc(open, high, low, close, direction, + FigureFactory._validate_equal_length(open, high, low, close) + FigureFactory._validate_ohlc(open, high, low, close, direction, **kwargs) if direction is 'increasing': @@ -2238,6 +2282,186 @@ def create_candlestick(open, high, low, close, layout = graph_objs.Layout() return dict(data=data, layout=layout) + @staticmethod + def create_distplot(hist_data, group_labels, + bin_size=1., curve_type='kde', + colors=[], rug_text=[], + show_hist=True, show_curve=True, + show_rug=True): + """ + BETA function that creates a distplot similar to seaborn.distplot + + The distplot can be composed of all or any combination of the following + 3 components: (1) histogram, (2) curve: (a) kernal density estimation + or (b) normal curve, and (3) rug plot. Additionally, multiple distplots + (from multiple datasets) can be created in the same plot. + + :param (list[list]) hist_data: Use list of lists to plot multiple data + sets on the same plot. + :param (list[str]) group_labels: Names for each data set. + :param (float) bin_size: Size of histogram bins. Default = 1. + :param (str) curve_type: 'kde' or 'normal'. Default = 'kde' + :param (bool) show_hist: Add histogram to distplot? Default = True + :param (bool) show_curve: Add curve to distplot? Default = True + :param (bool) show_rug: Add rug to distplot? Default = True + :param (list[str]) colors: Colors for traces. + :param (list[list]) rug_text: Hovertext values for rug_plot, + :return (dict): Representation of a distplot figure. + + Example 1: Simple distplot of 1 data set + ``` + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + + hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, + 3.5, 4.1, 4.4, 4.5, 4.5, + 5.0, 5.0, 5.2, 5.5, 5.5, + 5.5, 5.5, 5.5, 6.1, 7.0]] + + group_labels = ['distplot example'] + + fig = FF.create_distplot(hist_data, group_labels) + + url = py.plot(fig, filename='Simple distplot', validate=False) + ``` + + Example 2: Two data sets and added rug text + ``` + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + + # Add histogram data + hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, + -0.9, -0.07, 1.95, 0.9, -0.2, + -0.5, 0.3, 0.4, -0.37, 0.6] + hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, + 1.0, 0.8, 1.7, 0.5, 0.8, + -0.3, 1.2, 0.56, 0.3, 2.2] + + # Group data together + hist_data = [hist1_x, hist2_x] + + group_labels = ['2012', '2013'] + + # Add text + rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1', + 'f1', 'g1', 'h1', 'i1', 'j1', + 'k1', 'l1', 'm1', 'n1', 'o1'] + + rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2', + 'f2', 'g2', 'h2', 'i2', 'j2', + 'k2', 'l2', 'm2', 'n2', 'o2'] + + # Group text together + rug_text_all = [rug_text_1, rug_text_2] + + # Create distplot + fig = FF.create_distplot( + hist_data, group_labels, rug_text=rug_text_all, bin_size=.2) + + # Add title + fig['layout'].update(title='Dist Plot') + + # Plot! + url = py.plot(fig, filename='Distplot with rug text', validate=False) + ``` + + Example 3: Plot with normal curve and hide rug plot + ``` + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + import numpy as np + + x1 = np.random.randn(190) + x2 = np.random.randn(200)+1 + x3 = np.random.randn(200)-1 + x4 = np.random.randn(210)+2 + + hist_data = [x1, x2, x3, x4] + group_labels = ['2012', '2013', '2014', '2015'] + + fig = FF.create_distplot( + hist_data, group_labels, curve_type='normal', + show_rug=False, bin_size=.4) + + url = py.plot(fig, filename='hist and normal curve', validate=False) + + Example 4: Distplot with Pandas + ``` + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + import numpy as np + import pandas as pd + + df = pd.DataFrame({'2012': np.random.randn(200), + '2013': np.random.randn(200)+1}) + py.iplot(FF.create_distplot([df[c] for c in df.columns], df.columns), + filename='examples/distplot with pandas', + validate=False) + ``` + """ + FigureFactory._validate_distplot(hist_data, curve_type) + FigureFactory._validate_equal_length(hist_data, group_labels) + + hist = _Distplot( + hist_data, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_hist() + + if curve_type == 'normal': + curve = _Distplot( + hist_data, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_normal() + else: + curve = _Distplot( + hist_data, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_kde() + + rug = _Distplot( + hist_data, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_rug() + + data = [] + if show_hist: + data.append(hist) + if show_curve: + data.append(curve) + if show_rug: + data.append(rug) + layout = graph_objs.Layout( + barmode='overlay', + hovermode='closest', + legend=dict(traceorder='reversed'), + xaxis1=dict(domain=[0.0, 1.0], + anchor='y2', + zeroline=False), + yaxis1=dict(domain=[0.35, 1], + anchor='free', + position=0.0), + yaxis2=dict(domain=[0, 0.25], + anchor='x1', + dtick=1, + showticklabels=False)) + else: + layout = graph_objs.Layout( + barmode='overlay', + hovermode='closest', + legend=dict(traceorder='reversed'), + xaxis1=dict(domain=[0.0, 1.0], + anchor='y2', + zeroline=False), + yaxis1=dict(domain=[0., 1], + anchor='free', + position=0.0)) + + data = sum(data, []) + dist_fig = dict(data=data, layout=layout) + + return dist_fig + class _Quiver(FigureFactory): """ @@ -2246,22 +2470,22 @@ class _Quiver(FigureFactory): def __init__(self, x, y, u, v, scale, arrow_scale, angle, **kwargs): try: - x = FigureFactory.flatten(x) + x = FigureFactory._flatten(x) except exceptions.PlotlyError: pass try: - y = FigureFactory.flatten(y) + y = FigureFactory._flatten(y) except exceptions.PlotlyError: pass try: - u = FigureFactory.flatten(u) + u = FigureFactory._flatten(u) except exceptions.PlotlyError: pass try: - v = FigureFactory.flatten(v) + v = FigureFactory._flatten(v) except exceptions.PlotlyError: pass @@ -2305,8 +2529,8 @@ def get_barbs(self): self.end_x = [i + j for i, j in zip(self.x, self.u)] self.end_y = [i + j for i, j in zip(self.y, self.v)] empty = [None] * len(self.x) - barb_x = self.flatten(zip(self.x, self.end_x, empty)) - barb_y = self.flatten(zip(self.y, self.end_y, empty)) + barb_x = FigureFactory._flatten(zip(self.x, self.end_x, empty)) + barb_y = FigureFactory._flatten(zip(self.y, self.end_y, empty)) return barb_x, barb_y def get_quiver_arrows(self): @@ -2376,8 +2600,10 @@ def get_quiver_arrows(self): # Combine lists to create arrow empty = [None] * len(self.end_x) - arrow_x = self.flatten(zip(point1_x, self.end_x, point2_x, empty)) - arrow_y = self.flatten(zip(point1_y, self.end_y, point2_y, empty)) + arrow_x = FigureFactory._flatten(zip(point1_x, self.end_x, + point2_x, empty)) + arrow_y = FigureFactory._flatten(zip(point1_y, self.end_y, + point2_y, empty)) return arrow_x, arrow_y @@ -2725,8 +2951,8 @@ def get_increase(self): trace, flat_increase_y: y=values for the increasing trace and text_increase: hovertext for the increasing trace """ - flat_increase_x = FigureFactory.flatten(self.increase_x) - flat_increase_y = FigureFactory.flatten(self.increase_y) + flat_increase_x = FigureFactory._flatten(self.increase_x) + flat_increase_y = FigureFactory._flatten(self.increase_y) text_increase = (("Open", "Open", "High", "Low", "Close", "Close", '') * (len(self.increase_x))) @@ -2741,8 +2967,8 @@ def get_decrease(self): trace, flat_decrease_y: y=values for the decreasing trace and text_decrease: hovertext for the decreasing trace """ - flat_decrease_x = FigureFactory.flatten(self.decrease_x) - flat_decrease_y = FigureFactory.flatten(self.decrease_y) + flat_decrease_x = FigureFactory._flatten(self.decrease_x) + flat_decrease_y = FigureFactory._flatten(self.decrease_y) text_decrease = (("Open", "Open", "High", "Low", "Close", "Close", '') * (len(self.decrease_x))) @@ -2785,7 +3011,7 @@ def get_candle_increase(self): increase_x.append(self.x[index]) increase_x = [[x, x, x, x, x, x] for x in increase_x] - increase_x = FigureFactory.flatten(increase_x) + increase_x = FigureFactory._flatten(increase_x) return increase_x, increase_y @@ -2809,7 +3035,159 @@ def get_candle_decrease(self): decrease_x.append(self.x[index]) decrease_x = [[x, x, x, x, x, x] for x in decrease_x] - decrease_x = FigureFactory.flatten(decrease_x) + decrease_x = FigureFactory._flatten(decrease_x) return decrease_x, decrease_y + +class _Distplot(FigureFactory): + """ + Refer to TraceFactory.create_distplot() for docstring + """ + def __init__(self, hist_data, group_labels, + bin_size, curve_type, colors, + rug_text, show_hist, show_curve): + self.hist_data = hist_data + self.group_labels = group_labels + self.bin_size = bin_size + self.show_hist = show_hist + self.show_curve = show_curve + self.trace_number = len(hist_data) + if rug_text: + self.rug_text = rug_text + else: + self.rug_text = [None] * self.trace_number + + self.start = [] + self.end = [] + if colors: + self.colors = colors + else: + self.colors = [ + "rgb(31, 119, 180)", "rgb(255, 127, 14)", + "rgb(44, 160, 44)", "rgb(214, 39, 40)", + "rgb(148, 103, 189)", "rgb(140, 86, 75)", + "rgb(227, 119, 194)", "rgb(127, 127, 127)", + "rgb(188, 189, 34)", "rgb(23, 190, 207)"] + self.curve_x = [None] * self.trace_number + self.curve_y = [None] * self.trace_number + + for trace in self.hist_data: + self.start.append(min(trace) * 1.) + self.end.append(max(trace) * 1.) + + def make_hist(self): + """ + Makes the histogram(s) for FigureFactory.create_distplot(). + + :rtype (list) hist: list of histogram representations + """ + hist = [None] * self.trace_number + + for index in range(self.trace_number): + hist[index] = dict(type='histogram', + x=self.hist_data[index], + xaxis='x1', + yaxis='y1', + histnorm='probability', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + marker=dict(color=self.colors[index]), + autobinx=False, + xbins=dict(start=self.start[index], + end=self.end[index], + size=self.bin_size), + opacity=.7) + return hist + + def make_kde(self): + """ + Makes the kernal density estimation(s) for create_distplot(). + + This is called when curve_type = 'kde' in create_distplot(). + + :rtype (list) curve: list of kde representations + """ + curve = [None] * self.trace_number + for index in range(self.trace_number): + self.curve_x[index] = [self.start[index] + + x * (self.end[index] - self.start[index]) + / 500 for x in range(500)] + self.curve_y[index] = (scipy.stats.gaussian_kde + (self.hist_data[index]) + (self.curve_x[index])) + self.curve_y[index] *= self.bin_size + + for index in range(self.trace_number): + curve[index] = dict(type='scatter', + x=self.curve_x[index], + y=self.curve_y[index], + xaxis='x1', + yaxis='y1', + mode='lines', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index])) + return curve + + def make_normal(self): + """ + Makes the normal curve(s) for create_distplot(). + + This is called when curve_type = 'normal' in create_distplot(). + + :rtype (list) curve: list of normal curve representations + """ + curve = [None] * self.trace_number + mean = [None] * self.trace_number + sd = [None] * self.trace_number + + for index in range(self.trace_number): + mean[index], sd[index] = (scipy.stats.norm.fit + (self.hist_data[index])) + self.curve_x[index] = [self.start[index] + + x * (self.end[index] - self.start[index]) + / 500 for x in range(500)] + self.curve_y[index] = scipy.stats.norm.pdf( + self.curve_x[index], loc=mean[index], scale=sd[index]) + self.curve_y[index] *= self.bin_size + + for index in range(self.trace_number): + curve[index] = dict(type='scatter', + x=self.curve_x[index], + y=self.curve_y[index], + xaxis='x1', + yaxis='y1', + mode='lines', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index])) + return curve + + def make_rug(self): + """ + Makes the rug plot(s) for create_distplot(). + + :rtype (list) rug: list of rug plot representations + """ + rug = [None] * self.trace_number + for index in range(self.trace_number): + + rug[index] = dict(type='scatter', + x=self.hist_data[index], + y=([self.group_labels[index]] * + len(self.hist_data[index])), + xaxis='x1', + yaxis='y2', + mode='markers', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=(False if self.show_hist or + self.show_curve else True), + text=self.rug_text[index], + marker=dict(color=self.colors[index], + symbol='line-ns-open')) + return rug + diff --git a/plotly/version.py b/plotly/version.py index cfe644736a1..fa2822c607a 100644 --- a/plotly/version.py +++ b/plotly/version.py @@ -1 +1 @@ -__version__ = '1.8.3' +__version__ = '1.8.4'