|
1 | 1 | from collections import OrderedDict |
| 2 | +from collections.abc import Iterable |
2 | 3 | from contextlib import ExitStack |
3 | 4 | import functools |
4 | 5 | import inspect |
@@ -446,8 +447,22 @@ def _plot_args(self, tup, kwargs, return_kwargs=False): |
446 | 447 | ncx, ncy = x.shape[1], y.shape[1] |
447 | 448 | if ncx > 1 and ncy > 1 and ncx != ncy: |
448 | 449 | raise ValueError(f"x has {ncx} columns but y has {ncy} columns") |
| 450 | + |
| 451 | + if ('label' in kwargs and isinstance(kwargs['label'], Iterable) |
| 452 | + and not isinstance(kwargs['label'], str)): |
| 453 | + if len(kwargs['label']) != max(ncx, ncy): |
| 454 | + raise ValueError(f"if label is iterable label and input data" |
| 455 | + f" must have same length, but have lengths " |
| 456 | + f"{len(kwargs['label'])} and " |
| 457 | + f"{max(ncx, ncy)}") |
| 458 | + |
| 459 | + result = (func(x[:, j % ncx], y[:, j % ncy], kw, |
| 460 | + {**kwargs, 'label':kwargs['label'][j]}) |
| 461 | + for j in range(max(ncx, ncy))) |
| 462 | + |
449 | 463 | result = (func(x[:, j % ncx], y[:, j % ncy], kw, kwargs) |
450 | | - for j in range(max(ncx, ncy))) |
| 464 | + for j in range(max(ncx, ncy))) |
| 465 | + |
451 | 466 | if return_kwargs: |
452 | 467 | return list(result) |
453 | 468 | else: |
|
0 commit comments