@@ -595,7 +595,26 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
595
595
logger .info ("extracting archive file {} to temp dir {}" .format (
596
596
resolved_archive_file , tempdir ))
597
597
with tarfile .open (resolved_archive_file , 'r:gz' ) as archive :
598
- archive .extractall (tempdir )
598
+ def is_within_directory (directory , target ):
599
+
600
+ abs_directory = os .path .abspath (directory )
601
+ abs_target = os .path .abspath (target )
602
+
603
+ prefix = os .path .commonprefix ([abs_directory , abs_target ])
604
+
605
+ return prefix == abs_directory
606
+
607
+ def safe_extract (tar , path = "." , members = None , * , numeric_owner = False ):
608
+
609
+ for member in tar .getmembers ():
610
+ member_path = os .path .join (path , member .name )
611
+ if not is_within_directory (path , member_path ):
612
+ raise Exception ("Attempted Path Traversal in Tar File" )
613
+
614
+ tar .extractall (path , members , numeric_owner = numeric_owner )
615
+
616
+
617
+ safe_extract (archive , tempdir )
599
618
serialization_dir = tempdir
600
619
# Load config
601
620
config_file = os .path .join (serialization_dir , CONFIG_NAME )
0 commit comments