@@ -13,6 +13,7 @@ def plot_recall(
1313    x_absolute = False ,
1414    y_absolute = False ,
1515    show_random = True ,
16+     show_perfect = True ,
1617    show_legend = True ,
1718    legend_values = None ,
1819    legend_kwargs = None ,
@@ -34,6 +35,8 @@ def plot_recall(
3435        If False, the fraction of all included records found is on the y-axis. 
3536    show_random: bool 
3637        Show the random curve in the plot. 
38+     show_perfect: bool 
39+         Show the perfect curve in the plot. 
3740    show_legend: bool 
3841        If state_obj contains multiple states, show a legend in the plot. 
3942    legend_values: list[str] 
@@ -61,6 +64,7 @@ def plot_recall(
6164        x_absolute = x_absolute ,
6265        y_absolute = y_absolute ,
6366        show_random = show_random ,
67+         show_perfect = show_perfect ,
6468        show_legend = show_legend ,
6569        legend_values = legend_values ,
6670        legend_kwargs = legend_kwargs ,
@@ -237,6 +241,7 @@ def _plot_recall(
237241    x_absolute = False ,
238242    y_absolute = False ,
239243    show_random = True ,
244+     show_perfect = True ,
240245    show_legend = True ,
241246    legend_values = None ,
242247    legend_kwargs = None ,
@@ -258,7 +263,10 @@ def _plot_recall(
258263    ax  =  _add_recall_info (ax , labels , x_absolute , y_absolute )
259264
260265    if  show_random :
261-         ax  =  _add_random_curve (ax , labels , x_absolute , y_absolute )
266+         ax  =  _add_random_curve (ax , labels , x_absolute , y_absolute )    
267+ 
268+     if  show_perfect :
269+         ax  =  _add_perfect_curve (ax , labels , x_absolute , y_absolute )
262270
263271    if  show_legend :
264272        if  legend_kwargs  is  None :
@@ -398,6 +406,33 @@ def _add_random_curve(ax, labels, x_absolute, y_absolute):
398406    return  ax 
399407
400408
409+ def  _add_perfect_curve (ax , labels , x_absolute , y_absolute ):
410+     """Add a perfect curve to a plot using step-wise increments. 
411+ 
412+     Returns 
413+     ------- 
414+     plt.axes.Axes 
415+         Axes with perfect curve added. 
416+     """ 
417+     # get total amount of positive labels 
418+     if  isinstance (labels [0 ], list ):
419+         n_pos_docs  =  max (sum (label_set ) for  label_set  in  labels )
420+         n_docs  =  max (len (label_set ) for  label_set  in  labels )
421+     else :
422+         n_pos_docs  =  sum (labels )
423+         n_docs  =  len (labels )
424+ 
425+     # Create x and y arrays for step plot 
426+     x  =  np .arange (0 , n_pos_docs  +  1 ) if  x_absolute  else  np .arange (0 , n_pos_docs  +  1 ) /  n_docs   # noqa: E501 
427+     y  =  np .arange (0 , n_pos_docs  +  1 ) if  y_absolute  else  np .arange (0 , n_pos_docs  +  1 ) /  n_pos_docs   # noqa: E501 
428+ 
429+     # Plot the stepwise perfect curve 
430+     ax .step (x , y , color = "grey" , where = "post" )
431+ 
432+     return  ax 
433+ 
434+ 
435+ 
401436def  _add_wss_curve (ax , labels , x_absolute = False , y_absolute = False , legend_label = None ):
402437    x , y  =  _wss_values (labels , x_absolute = x_absolute , y_absolute = y_absolute )
403438    ax .step (x , y , where = "post" , label = legend_label )
0 commit comments