11from collections import defaultdict
22from pathlib import Path
3+ import sys
34
45from matplotlib import pyplot as plt
56import numpy as np
@@ -54,13 +55,15 @@ def create_plot(ax, x_data: list, y_data: list, auc: float, type: str, color) ->
5455
5556def main ():
5657 print ("Generating figures" )
57- species_list = ["elegans" , "fly" , "bsub" , "yeast" , "zfish" ]
58+ species_list = ["elegans" , "fly" , "bsub" , "yeast" , "zfish" , "athaliana" , "ecoli" ]
5859 species_title = [
5960 "C. elegans" ,
6061 "D. melanogaster" ,
6162 "B. subtilis" ,
6263 "S. cerevisiae" ,
6364 "D. rerio" ,
65+ "A. thaliana" ,
66+ "E. coli" ,
6467 ]
6568
6669 file_directories = [
@@ -123,10 +126,10 @@ def main():
123126
124127 # Create a figure with 2 subplots (one for each species)
125128 fig , axes = plt .subplots (
126- 2 , 3 , figsize = (17 , 12 )
129+ 2 , 4 , figsize = (17 , 12 )
127130 ) # Create a 2x3 grid of subplots
128131 axes = axes .flatten ()
129- colors = ["red" , "green" , "blue" , "orange" , "purple" ]
132+ colors = ["red" , "green" , "blue" , "orange" , "purple" , "black" , "yellow" ]
130133
131134 for idx , species in enumerate (species_list ):
132135 ax = axes [idx ] # Get the subplot axis for the current species
@@ -143,25 +146,31 @@ def main():
143146
144147 ax .set_xlim ([0.0 , 1.0 ])
145148 ax .set_ylim ([0.0 , 1.05 ])
146- ax .set_xlabel ("False Positive Rate" , fontsize = 14 )
147- ax .set_ylabel ("True Positive Rate" , fontsize = 14 )
149+ ax .set_xlabel ("False Positive Rate" , fontsize = 14 )
150+ ax .set_ylabel ("True Positive Rate" , fontsize = 14 )
148151 ax .set_title (f"${ species_title [idx ].capitalize ()} $" , fontsize = 14 )
149- ax .legend (loc = "lower right" , fontsize = 12 )
152+ ax .legend (loc = "lower right" , fontsize = 12 )
150153
151- axes [5 ].set_visible (False )
154+ axes [7 ].set_visible (False )
152155 fig .suptitle ("ROC Curve for All Species w/ " + subplot_titles [k ], fontsize = 20 )
153156 # Adjust layout to prevent overlap
154157 plt .tight_layout () # Adjust rect to accommodate legends
155158 # Adjust the space between subplots
156159 plt .subplots_adjust (wspace = 0.2 )
157- plt .savefig (Path ("./results/images/" , f"roc_{ subplot_titles [k ].lower ().replace (" " , "_" )} .pdf" ), format = "pdf" )
160+ plt .savefig (
161+ Path (
162+ "./results/images/" ,
163+ f"roc_{ subplot_titles [k ].lower ().replace (" " , "_" )} .pdf" ,
164+ ),
165+ format = "pdf" ,
166+ )
158167 plt .show ()
159168
160169 fig , axes = plt .subplots (
161- 2 , 3 , figsize = (17 , 12 )
170+ 2 , 4 , figsize = (17 , 12 )
162171 ) # Create a 2x3 grid of subplots
163172 axes = axes .flatten ()
164- colors = ["red" , "green" , "blue" , "orange" , "purple" ]
173+ colors = ["red" , "green" , "blue" , "orange" , "purple" , "black" , "yellow" ]
165174
166175 for idx , species in enumerate (species_list ):
167176 ax = axes [idx ] # Get the subplot axis for the current species
@@ -178,12 +187,12 @@ def main():
178187
179188 ax .set_xlim ([0.0 , 1.0 ])
180189 ax .set_ylim ([0.0 , 1.05 ])
181- ax .set_xlabel ("Recall" , fontsize = 14 )
190+ ax .set_xlabel ("Recall" , fontsize = 14 )
182191 ax .set_ylabel ("Precision" , fontsize = 14 )
183192 ax .set_title (f"${ species_title [idx ].capitalize ()} $" , fontsize = 14 )
184- ax .legend (loc = "lower right" , fontsize = 12 )
193+ ax .legend (loc = "lower right" , fontsize = 12 )
185194
186- axes [5 ].set_visible (False )
195+ axes [7 ].set_visible (False )
187196 fig .suptitle (
188197 "Precision/Recall Curve for All Species w/ " + subplot_titles [k ],
189198 fontsize = 20 ,
@@ -192,13 +201,19 @@ def main():
192201 plt .tight_layout () # Adjust rect to accommodate legends
193202 # Adjust the space between subplots
194203 plt .subplots_adjust (wspace = 0.2 )
195- plt .savefig (Path ("./results/images/" , f"pr_{ subplot_titles [k ].lower ().replace (" " , "_" )} .pdf" ), format = "pdf" )
204+ plt .savefig (
205+ Path (
206+ "./results/images/" ,
207+ f"pr_{ subplot_titles [k ].lower ().replace (" " , "_" )} .pdf" ,
208+ ),
209+ format = "pdf" ,
210+ )
196211 plt .show ()
197212 k += 1
198213
199214 # generate RW figures
200215
201- species_list = ["elegans" , "fly" , "bsub" , "yeast" , "zfish" ]
216+ species_list = ["elegans" , "fly" , "bsub" , "yeast" , "zfish" , "athaliana" , "ecoli" ]
202217 file_directories = [
203218 "./results/final-rw-inferred-regular/" ,
204219 "./results/final-rw-inferred-pro-go/" ,
@@ -237,7 +252,7 @@ def main():
237252 fig , axs = plt .subplots (1 , 4 , figsize = (40 , 12 )) # 2 rows, 2 columns
238253 axs = axs .flatten () # Flatten to easily index the subplots
239254
240- colors = ["red" , "green" , "blue" , "orange" , "purple" ]
255+ colors = ["red" , "green" , "blue" , "orange" , "purple" , "black" , "magenta" ]
241256
242257 # Plot data for each directory on a subplot
243258 for idx , directory in enumerate (file_directories ):
@@ -272,7 +287,7 @@ def main():
272287 fig , axs = plt .subplots (1 , 4 , figsize = (40 , 12 )) # 2 rows, 2 columns
273288 axs = axs .flatten () # Flatten to easily index the subplots
274289
275- colors = ["red" , "green" , "blue" , "orange" , "purple" ]
290+ colors = ["red" , "green" , "blue" , "orange" , "purple" , "black" , "magenta" ]
276291
277292 # Plot data for each directory on a subplot
278293 for idx , directory in enumerate (file_directories ):
0 commit comments