@@ -671,3 +671,65 @@ def test_no_warn_big_data_when_loc_specified():
671671 ax .plot (np .arange (5000 ), label = idx )
672672 legend = ax .legend ('best' )
673673 fig .draw_artist (legend ) # Check that no warning is emitted.
674+
675+
676+ @pytest .mark .parametrize ('label_array' , [['low' , 'high' ],
677+ ('low' , 'high' ),
678+ np .array (['low' , 'high' ])])
679+ def test_plot_multiple_input_multiple_label (label_array ):
680+ # test ax.plot() with multidimensional input
681+ # and multiple labels
682+ x = [1 , 2 , 3 ]
683+ y = [[1 , 2 ],
684+ [2 , 5 ],
685+ [4 , 9 ]]
686+
687+ fig , ax = plt .subplots ()
688+ ax .plot (x , y , label = label_array )
689+ leg = ax .legend ()
690+ legend_texts = [entry .get_text () for entry in leg .get_texts ()]
691+ assert legend_texts == ['low' , 'high' ]
692+
693+
694+ @pytest .mark .parametrize ('label' , ['one' , 1 , int ])
695+ def test_plot_multiple_input_single_label (label ):
696+ # test ax.plot() with multidimensional input
697+ # and single label
698+ x = [1 , 2 , 3 ]
699+ y = [[1 , 2 ],
700+ [2 , 5 ],
701+ [4 , 9 ]]
702+
703+ fig , ax = plt .subplots ()
704+ ax .plot (x , y , label = label )
705+ leg = ax .legend ()
706+ legend_texts = [entry .get_text () for entry in leg .get_texts ()]
707+ assert legend_texts == [str (label )] * 2
708+
709+
710+ @pytest .mark .parametrize ('label_array' , [['low' , 'high' ],
711+ ('low' , 'high' ),
712+ np .array (['low' , 'high' ])])
713+ def test_plot_single_input_multiple_label (label_array ):
714+ # test ax.plot() with 1D array like input
715+ # and iterable label
716+ x = [1 , 2 , 3 ]
717+ y = [2 , 5 , 6 ]
718+ fig , ax = plt .subplots ()
719+ ax .plot (x , y , label = label_array )
720+ leg = ax .legend ()
721+ assert len (leg .get_texts ()) == 1
722+ assert leg .get_texts ()[0 ].get_text () == str (label_array )
723+
724+
725+ def test_plot_multiple_label_incorrect_length_exception ():
726+ # check that excepton is raised if multiple labels
727+ # are given, but number of on labels != number of lines
728+ with pytest .raises (ValueError ):
729+ x = [1 , 2 , 3 ]
730+ y = [[1 , 2 ],
731+ [2 , 5 ],
732+ [4 , 9 ]]
733+ label = ['high' , 'low' , 'medium' ]
734+ fig , ax = plt .subplots ()
735+ ax .plot (x , y , label = label )
0 commit comments