@@ -877,7 +877,7 @@ def __repr__(self):
     def __str__(self):
         return "WWLayer {} {} {} skipped {}".format(self.layer_id, self.name,
-                                                     self.the_type.name, self.skipped)
+                                                     self.the_type, self.skipped)
@@ -1358,10 +1358,10 @@ def ww_layer_iter_(self):
             is_skipped = self.apply_filters(ww_layer)
             is_supported = self.layer_supported(ww_layer)
 
-
             if is_supported and not is_skipped:
                 yield ww_layer
 
+
     def make_layer_iter_(self):
         return self.ww_layer_iter_()
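The hunk above only shuffles a blank line, but it shows that ww_layer_iter_ is a plain generator: it yields only layers that pass both the filter check and the support check. A minimal sketch of that pattern, with apply_filters and layer_supported as stand-in callables rather than the library's instance methods:

# Minimal sketch of the generator-filter pattern used by ww_layer_iter_();
# apply_filters / layer_supported are stand-ins for the real instance methods.
def layer_iter(layers, apply_filters, layer_supported):
    for ww_layer in layers:
        is_skipped = apply_filters(ww_layer)       # True if a filter excludes this layer
        is_supported = layer_supported(ww_layer)   # True if the layer can be analyzed
        if is_supported and not is_skipped:
            yield ww_layer

# usage: kept = list(layer_iter(all_layers, lambda l: False, lambda l: True))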
@@ -1815,7 +1815,6 @@ def infer_model_file_format(model_dir):
         fileglob = f"{model_dir}/*weight*npy"
         num_files = len(glob.glob(fileglob))
         if num_files > 0:
-            print("found ww files")
             format = MODEL_FILE_FORMATS.WW_FLATFILES
             return format, fileglob
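For context, infer_model_file_format decides the format purely by globbing the model directory: *weight*npy marks WW_FLATFILES here, and the extract_pytorch_bins hunk further down uses pytorch_model*bin and model*safetensors. A hedged sketch of that idea; the enum values and the helper name infer_format are illustrative stand-ins, only the glob patterns come from this diff:

# Sketch of glob-based format inference; MODEL_FILE_FORMATS values and
# infer_format are illustrative, the glob patterns appear in this diff.
import glob
from enum import Enum

class MODEL_FILE_FORMATS(Enum):
    WW_FLATFILES = "ww_flatfiles"
    PYTORCH = "pytorch"
    SAFETENSORS = "safetensors"

def infer_format(model_dir):
    candidates = [
        (MODEL_FILE_FORMATS.WW_FLATFILES, f"{model_dir}/*weight*npy"),
        (MODEL_FILE_FORMATS.PYTORCH, f"{model_dir}/pytorch_model*bin"),
        (MODEL_FILE_FORMATS.SAFETENSORS, f"{model_dir}/model*safetensors"),
    ]
    for fmt, fileglob in candidates:
        if len(glob.glob(fileglob)) > 0:   # first pattern with matches wins
            return fmt, fileglob
    return None, None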
@@ -2861,6 +2860,10 @@ def analyze(self, model=None, layers=[],
                 raise Exception(msg)
             params = self.normalize_params(params)
 
+
+            #TODO: ?
+            # if we want deltas, we can just add delta="model" here
+            #
             layer_iterator = self.make_layer_iterator(model=self.model, layers=layers, params=params)
 
             details = pd.DataFrame(columns=[])
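The analyze() change above only adds a TODO, but the surrounding lines show the core flow: build a layer iterator, then accumulate one row of results per layer into the details DataFrame. A rough sketch of that accumulation; the row fields are illustrative, the real analyze() computes far more per layer:

# Sketch of accumulating per-layer rows into a details DataFrame, mirroring
# how analyze() consumes make_layer_iterator(); the columns are illustrative.
import pandas as pd

def collect_details(layer_iterator):
    rows = []
    for ww_layer in layer_iterator:
        rows.append({
            "layer_id": ww_layer.layer_id,
            "name": ww_layer.name,
            "type": str(ww_layer.the_type),
        })
    return pd.DataFrame(rows)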
@@ -4595,7 +4598,7 @@ def extract_safetensors_config_(weights_dir, model_name, state_dict_filename, st
     # helper methods for pre-processing pytorch state_dict files
     @staticmethod
-    def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id=0, format=None, save=None):
+    def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id=0, format=MODEL_FILE_FORMATS.PYTORCH, save=None):
         """Read a pytorch state_dict file, and return a dict of layer configs
 
         Can read model.bin or model.safetensors (detects by filename)
@@ -4609,8 +4612,8 @@ def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, sta
         start_id: int to start layer id counter
 
-        format: N/A yet, just infer
-
+        format: default is MODEL_FILE_FORMATS.PYTORCH; you must set SAFETENSORS to extract safetensors files
+
         save: (not set by default) save is True for ww_flatfiles, False otherwise (use
 
         Returns:
@@ -4620,24 +4623,25 @@ def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, sta
         Note: Currently only processes dense layers (i.e. transformers), and
               we may not want every layer in the state_dict
+
+              Moreover, there is no operational reason to extract safetensors;
+              it is only here for testing (that is, to extract the config file)
+              and may be removed
 
         """
 
         config = {}
 
-        # TODO: check that format matches files if it is set
-        if format is not None:
-            logger.warning(f"Format {format} ignored, just selecting files we find")
-
+        weight_keys = {}
         if os.path.exists(state_dict_filename):
             logger.debug(f"Reading {state_dict_filename}")
-            if state_dict_filename.endswith(".bin"):
+            if format == MODEL_FILE_FORMATS.PYTORCH and state_dict_filename.endswith(".bin"):
                 state_dict = torch.load(state_dict_filename, map_location=torch.device('cpu'))
                 logger.info(f"Read pytorch state_dict: {state_dict_filename}, len={len(state_dict)}")
-                format = MODEL_FILE_FORMATS.PYTORCH
+                weight_keys = [key for key in state_dict.keys() if 'weight' in key.lower()]
 
-            elif state_dict_filename.endswith(".safetensors"):
-                format = MODEL_FILE_FORMATS.SAFETENSORS
+
+            elif format == MODEL_FILE_FORMATS.SAFETENSORS and state_dict_filename.endswith(".safetensors"):
 
                 #TODO: move this to its own method
                 from safetensors import safe_open
@@ -4646,87 +4650,101 @@ def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, sta
                 state_dict = {}
                 with safe_open(state_dict_filename, framework="pt", device='cpu') as f:
                     for k in f.keys():
-                        state_dict[k] = f.get_tensor(k)
+                        weight_keys[k] = k
 
                 logger.info(f"Read safetensors: {state_dict_filename}, len={len(state_dict)}")
 
         else:
-            logger.fatal(f"PyTorch state_dict {state_dict_filename} not found, stopping")
-
+            logger.fatal(f"Format: {format} incorrect and/or PyTorch state_dict {state_dict_filename} not found, stopping")
+
         # we only want the model, but sometimes the state_dict has more info
         if 'model' in [str(x) for x in state_dict.keys()]:
             state_dict = state_dict['model']
 
-        weight_keys = [key for key in state_dict.keys() if 'weight' in key.lower()]
 
         for layer_id, weight_key in enumerate(weight_keys):
 
+            if 'weight' in weight_key or 'bias' in weight_key:
                 # TODO: do not save weights by default
-            # and change bias file name depending on safetensors or not
-            #
-            layer_id_updated = layer_id + start_id
-            name = f"{model_name}.{layer_id_updated}"
-            longname = re.sub('.weight$', '', weight_key)
-
-            T = state_dict[weight_key]
-
-            shape = len(T.shape)
-            #if shape==2:
-            W = torch_T_to_np(T)
-
-            # TODO: make sure this works with safetensors also
-            the_type = WWFlatFile.layer_type_as_str(W)
+                # and change bias file name depending on safetensors or not
+                #
+                layer_id_updated = layer_id + start_id
+                name = f"{model_name}.{layer_id_updated}"
+                longname = re.sub('.weight$', '', weight_key)
+
+                if format == MODEL_FILE_FORMATS.PYTORCH:
+                    T = state_dict[weight_key]
+                elif format == MODEL_FILE_FORMATS.SAFETENSORS:
+                    with safe_open(state_dict_filename, framework="pt", device='cpu') as f:
+
+                        T = f.get_tensor(weight_key)
 
-            # TODO: is this always correct?
-            has_bias = None
-
-            bias_key = re.sub('weight$', 'bias', weight_key)
-            if bias_key in state_dict:
-                T = state_dict[bias_key]
-                b = torch_T_to_np(T)
-                has_bias = True
+                shape = len(T.shape)
+                #if shape==2:
+                W = torch_T_to_np(T)
 
-            # TODO: what about perceptron layers ?
-
-            # Save files by default for ww_flatfiles, but only if save=True for safetensors
-            if format == MODEL_FILE_FORMATS.WW_FLATFILES or save:
-                weightfile = f"{name}.weight.npy"
-                filename = os.path.join(weights_dir, weightfile)
-                logger.debug(f"saving weights to {filename}")
-                np.save(filename, W)
-
-                if has_bias:
-                    biasfile = f"{name}.bias.npy"
-                    filename = os.path.join(weights_dir, biasfile)
-                    logger.debug(f"saving biases to {filename}")
-                    np.save(filename, b)
+                # TODO: make sure this works with safetensors also
+                the_type = WWFlatFile.layer_type_as_str(W)
 
-            # safetensors
-            elif format == MODEL_FILE_FORMATS.SAFETENSORS:
-                weightfile = state_dict_filename
-                biasfile = None
-                if has_bias:
-                    biasfile = state_dict_filename
+                # TODO: is this always correct?
+                has_bias = None
+
+                bias_key = re.sub('weight$', 'bias', weight_key)
+                if bias_key in state_dict:
 
-            else:
-                logger.fatal("Unknown format {format}, stopping")
+                    if format == MODEL_FILE_FORMATS.PYTORCH:
+                        T = state_dict[bias_key]
+                    elif format == MODEL_FILE_FORMATS.SAFETENSORS:
+                        with safe_open(state_dict_filename, framework="pt", device='cpu') as f:
+                            T = f.get_tensor(bias_key)
+
+                    b = torch_T_to_np(T)
+                    has_bias = True
+
+                # TODO: what about perceptron layers ?
+
+                # Save files by default for ww_flatfiles, but only if save=True for safetensors
+                if format != MODEL_FILE_FORMATS.SAFETENSORS or save:
+                    weightfile = f"{name}.weight.npy"
+                    biasfile = None
+
+                    filename = os.path.join(weights_dir, weightfile)
+                    logger.debug(f"saving weights to {filename}")
+                    np.save(filename, W)
+
+                    if has_bias:
+                        biasfile = f"{name}.bias.npy"
+                        filename = os.path.join(weights_dir, biasfile)
+                        logger.debug(f"saving biases to {filename}")
+                        np.save(filename, b)
+
+                # safetensors
+                elif format == MODEL_FILE_FORMATS.SAFETENSORS:
+                    weightfile = state_dict_filename
+                    biasfile = None
+                    if has_bias:
+                        biasfile = state_dict_filename
+
+                else:
+                    logger.fatal(f"Unknown format {format}, stopping")
+
+
+
+                # TODO
+                # add the position id, 0 by default for weights and bias individually
+                # allow other, perceptron layers, because we need these later
+                # = allow unknown types!
 
-
-
-            # TODO
-            # add the position id, 0 by default for weights and bias individually
-            # allow other, perceptron layers, because we need these later
-            # = allow unknown types!
-
-            layer_config = {}
-            layer_config['key'] = weight_key
-            layer_config['longname'] = longname
-            layer_config['weightfile'] = weightfile
-            layer_config['biasfile'] = biasfile
-            layer_config['type'] = the_type
-            layer_config['dims'] = json.dumps(W.shape)
-
-            config[int(layer_id_updated)] = layer_config
+                layer_config = {}
+                layer_config['key'] = weight_key
+                layer_config['name'] = name
+                layer_config['longname'] = longname
+                layer_config['weightfile'] = weightfile
+                layer_config['biasfile'] = biasfile
+                layer_config['type'] = the_type
+                layer_config['dims'] = json.dumps(W.shape)
+
+                config[int(layer_id_updated)] = layer_config
 
         return config
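Taken together, the rewritten extract_pytorch_statedict_ has two read paths, torch.load for pytorch_model*.bin files and safetensors.safe_open for *.safetensors files, and for safetensors it now records only the keys instead of materializing every tensor up front. A hedged sketch of the two paths and of the per-layer config entries the method returns; the concrete file names and values below are made up for illustration:

# Sketch of the two read paths; the helper name and filename are illustrative.
import torch
from safetensors import safe_open

def list_weight_keys(state_dict_filename, use_safetensors=False):
    if use_safetensors and state_dict_filename.endswith(".safetensors"):
        with safe_open(state_dict_filename, framework="pt", device="cpu") as f:
            return list(f.keys())     # keys only; tensors are fetched lazily later
    state_dict = torch.load(state_dict_filename, map_location=torch.device("cpu"))
    return [key for key in state_dict.keys() if "weight" in key.lower()]

# Each entry of the returned config is keyed by the renumbered layer id and
# looks roughly like this (values are hypothetical):
example_layer_config = {
    "key": "encoder.layer.0.attention.self.query.weight",
    "name": "model.0",
    "longname": "encoder.layer.0.attention.self.query",
    "weightfile": "model.0.weight.npy",   # or the .safetensors file itself
    "biasfile": "model.0.bias.npy",       # or None
    "type": "dense",                      # whatever WWFlatFile.layer_type_as_str returns
    "dims": "[768, 768]",
}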
@@ -4875,21 +4893,19 @@ def extract_pytorch_bins(model_dir=None, model_name=None, tmp_dir="/tmp", format
         config['framework'] = FRAMEWORK.PYTORCH
         config['weights_dir'] = weights_dir
 
-        print("the format", format)
         # TODO: infer the format ?
         if format is None:
             format, fileglob = WeightWatcher.infer_model_file_format(model_dir)
             logger.info(f"Inferred format, found: {format}")
 
         config['format'] = format
         if format == MODEL_FILE_FORMATS.PYTORCH:
-            fileglob = f"{model_dir}/pytorch_model*bin"
-        elif format == MODEL_FILE_FORMATS.SAFETENSORS:
+            fileglob = f"{model_dir}/pytorch_model*bin"
+        elif format == MODEL_FILE_FORMATS.SAFETENSORS:
             fileglob = f"{model_dir}/model*safetensors"
         else:
             logger.fatal(f"Unknown file format {format}, quitting")
 
-        print(config)
 
         logger.debug(f"searching for files {fileglob}")
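After this setup, config holds the bookkeeping fields and fileglob selects which state_dict shards to read. A rough snapshot of what that can look like; every value below is made up for illustration:

# Illustrative snapshot of the top-level config and fileglob assembled above.
model_dir = "/path/to/downloaded/model"        # hypothetical
config = {
    "framework": "pytorch",                    # FRAMEWORK.PYTORCH
    "weights_dir": "/tmp/ww_weights",          # hypothetical extraction directory
    "format": "safetensors",                   # inferred or passed in
    "layers": {},                              # filled in per state_dict file below
}
fileglob = f"{model_dir}/model*safetensors"    # chosen by the format branch above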
@@ -4921,7 +4937,7 @@ def extract_pytorch_bins(model_dir=None, model_name=None, tmp_dir="/tmp", format
         for state_dict_filename in sorted(glob.glob(fileglob)):
             logger.info(f"reading and extracting {state_dict_filename}")
             # TODO: update layer ids
-            layer_configs = WeightWatcher.extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id, format=format)
+            layer_configs = WeightWatcher.extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id, format=format)
             start_id = np.max([int(x) for x in layer_configs.keys()]) + 1
             config['layers'].update(layer_configs)
             logger.debug(f"next start_id = {start_id}")