
Commit 28b671f

debugging 0.7.1.8
1 parent b8e69cc commit 28b671f


3 files changed: +143 -111 lines changed


setup.py

+1
@@ -55,6 +55,7 @@ def __init__(self):
         'matplotlib-inline',
         'powerlaw',
         'scikit-learn',
+        'safetensors', # maybe will remove, autoload
         'tqdm'],
     entry_points = '''
     [console_scripts]
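The new 'safetensors' entry is flagged as possibly temporary ("maybe will remove, autoload"). If the hard dependency is later dropped in favour of loading it on demand, a lazy-import guard is the usual pattern; a minimal sketch, not part of this commit and with a hypothetical helper name:

    # hypothetical helper; only needed if safetensors becomes an optional dependency
    def _require_safe_open():
        try:
            from safetensors import safe_open
            return safe_open
        except ImportError as err:
            raise ImportError(
                "reading *.safetensors files requires the safetensors package; "
                "install it with `pip install safetensors`") from err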

weightwatcher/constants.py

+43 -28
@@ -184,48 +184,63 @@
 DENSE = 'dense'
 CONV2D = 'conv2d'
 CONV1D = 'conv1d'
-class LAYER_TYPE(IntFlag):
-    UNKNOWN = auto()
-    STACKED = auto()
-    DENSE = auto()
-    CONV1D = auto()
-    CONV2D = auto()
-    FLATTENED = auto()
-    EMBEDDING = auto()
-    NORM = auto()
+
+class LAYER_TYPE():
+    UNKNOWN = UNKNOWN
+    STACKED = 'stacked'
+    DENSE = DENSE
+    CONV1D = CONV1D
+    CONV2D = CONV2D
+    FLATTENED = 'flattened'
+    EMBEDDING = 'embedding'
+    NORM = NORM
 
 # framework names
 KERAS = 'kers'
 PYTORCH = 'pytorch'
 PYSTATEDICT = 'pystatedict'
 ONNX = 'onnx'
-
-class FRAMEWORK(IntFlag):
-    UNKNOWN = auto()
-    PYTORCH = auto()
-    KERAS = auto()
-    ONNX = auto()
-    PYSTATEDICT = auto()
-    PYSTATEDICT_DIR = auto()
-    WW_FLATFILES = auto()
-    KERASH5 = auto()
-    KERASH5FILE = auto()
-
-
-class CHANNELS(IntFlag):
-    UNKNOWN = auto()
-    FIRST = auto()
-    LAST = auto()
+WW_FLATFILES = "ww_flatfiles"
+PYTORCH = "pytorch"
+SAFETENSORS = "safetensors"
+
+# class FRAMEWORK(IntFlag):
+#     UNKNOWN = auto()
+#     PYTORCH = auto()
+#     KERAS = auto()
+#     ONNX = auto()
+#     PYSTATEDICT = auto()
+#     PYSTATEDICT_DIR = auto()
+#     WW_FLATFILES = auto()
+#     KERASH5 = auto()
+#     KERASH5FILE = auto()
+
+class FRAMEWORK():
+    UNKNOWN = UNKNOWN
+    PYTORCH = 'pytorch'
+    KERAS = 'keras'
+    ONNX = 'onnx'
+    PYSTATEDICT = 'pystatedict'
+    PYSTATEDICT_DIR = 'pystatedict_dir'
+    WW_FLATFILES = WW_FLATFILES
+    KERAS_H5 = 'keras_h5'
+    KERAS_H5_FILE = 'keras_h5_file'
+
+
+class CHANNELS():
+    UNKNOWN = UNKNOWN
+    FIRST = 'first'
+    LAST = 'last'
 
 class METHODS(IntFlag):
     DESCRIBE = auto()
     ANALYZE = auto()
 
-
+# only used to extract into ww_flatfiels format
 class MODEL_FILE_FORMATS():
     PYTORCH = "pytorch"
     SAFETENSORS = "safetensors"
-    WW_FLATFILES = "ww_flatfiles"
+    #WW_FLATFILES = WW_FLATFILES
 
 
 # TODO either complete or remove thi
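A practical effect of replacing the IntFlag enums with plain string attributes: values such as FRAMEWORK.PYTORCH and LAYER_TYPE.DENSE now compare and serialize as ordinary strings rather than opaque integers, which matters when layer configs are written out with json.dumps. A minimal illustrative sketch, not code from this commit:

    import json
    from enum import IntFlag, auto

    class OLD_FRAMEWORK(IntFlag):      # the style removed by this commit
        PYTORCH = auto()

    class FRAMEWORK():                 # the new style: plain string constants
        PYTORCH = 'pytorch'

    # IntFlag members are int subclasses, so they dump as bare numbers
    print(json.dumps({'framework': OLD_FRAMEWORK.PYTORCH}))   # {"framework": 1}

    # string constants round-trip as readable text
    print(json.dumps({'framework': FRAMEWORK.PYTORCH}))       # {"framework": "pytorch"}
    assert FRAMEWORK.PYTORCH == 'pytorch'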

weightwatcher/weightwatcher.py

+99 -83
@@ -877,7 +877,7 @@ def __repr__(self):
 
     def __str__(self):
         return "WWLayer {} {} {} skipped {}".format(self.layer_id, self.name,
-                                                     self.the_type.name, self.skipped)
+                                                     self.the_type, self.skipped)
 
 
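This follows from the constants.py change above: LAYER_TYPE members are now plain strings, which have no .name attribute, so formatting the value directly avoids an AttributeError. A tiny illustration with hypothetical values:

    the_type = 'dense'     # LAYER_TYPE.DENSE is now just a string
    # the_type.name        # would raise AttributeError on a str
    print("WWLayer {} {} {} skipped {}".format(0, "fc1", the_type, False))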
@@ -1358,10 +1358,10 @@ def ww_layer_iter_(self):
             is_skipped = self.apply_filters(ww_layer)
             is_supported = self.layer_supported(ww_layer)
 
-
             if is_supported and not is_skipped:
                 yield ww_layer
 
+
     def make_layer_iter_(self):
         return self.ww_layer_iter_()

@@ -1815,7 +1815,6 @@ def infer_model_file_format(model_dir):
         fileglob = f"{model_dir}/*weight*npy"
         num_files = len(glob.glob(fileglob))
         if num_files > 0:
-            print("found ww files")
             format = MODEL_FILE_FORMATS.WW_FLATFILES
             return format, fileglob

@@ -2861,6 +2860,10 @@ def analyze(self, model=None, layers=[],
             raise Exception(msg)
         params = self.normalize_params(params)
 
+
+        #TODO: ?
+        # if we want deltas, we can just add the delta="model" in
+        #
         layer_iterator = self.make_layer_iterator(model=self.model, layers=layers, params=params)
 
         details = pd.DataFrame(columns=[])
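For context, analyze() is the public entry point this hunk touches. Typical usage, following the project README and assuming a supported pretrained model object named model:

    import weightwatcher as ww

    watcher = ww.WeightWatcher(model=model)   # e.g. a torch or keras model
    details = watcher.analyze()               # per-layer DataFrame built from layer_iterator
    summary = watcher.get_summary(details)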
@@ -4595,7 +4598,7 @@ def extract_safetensors_config_(weights_dir, model_name, state_dict_filename, st
 
     # helper methods for pre-processinf pytorch state_dict files
     @staticmethod
-    def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id = 0, format=None, save=None):
+    def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id = 0, format=MODEL_FILE_FORMATS.PYTORCH, save=None):
         """Read a pytorch state_dict file, and return a dict of layer configs
 
         Can read model.bin or model.safetensors (detects by filename)
@@ -4609,8 +4612,8 @@ def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, sta
 
         start_id: int to start layer id counter
 
-        format: N/A yet, just infer
-
+        format: default id MODEL_FILE_FORMATS.PYTORCH, you must set SAFETENSORS to extract it
+
         save: (not set by default) save is True for ww_flatfies , False otherwise (use
 
         Returns:
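With the new default, pulling layer configs out of a .safetensors shard now requires passing the format explicitly. A minimal sketch based on the signature above, with hypothetical paths and names:

    from weightwatcher import WeightWatcher
    from weightwatcher.constants import MODEL_FILE_FORMATS

    configs = WeightWatcher.extract_pytorch_statedict_(
        "/tmp/ww_weights", "my_model", "model-00001-of-00002.safetensors",
        start_id=0,
        format=MODEL_FILE_FORMATS.SAFETENSORS,   # default is now PYTORCH
        save=False)                              # safetensors weights are only re-saved when save=True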
@@ -4620,24 +4623,25 @@ def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, sta
 
         Note: Currently only process dense layers (i.e. transformers), and
         We may not want every layer in the state_dict
+
+        Moreover, there is no operational reason to extract safetensors;
+        it is only here for testing and may be removed
+        that is, to extract the config file
 
         """
 
         config = {}
 
-        # TODO: check that format matches files if it is set
-        if format is not None:
-            logger.warning(f"Format {format} ignored, just selecting files we find ")
-
+        weight_keys = {}
         if os.path.exists(state_dict_filename):
             logger.debug(f"Reading {state_dict_filename}")
-            if state_dict_filename.endswith(".bin"):
+            if format==MODEL_FILE_FORMATS.PYTORCH and state_dict_filename.endswith(".bin"):
                 state_dict = torch.load(state_dict_filename, map_location=torch.device('cpu'))
                 logger.info(f"Read pytorch state_dict: {state_dict_filename}, len={len(state_dict)}")
-                format = MODEL_FILE_FORMATS.PYTORCH
+                weight_keys = [key for key in state_dict.keys() if 'weight' in key.lower()]
 
-            elif state_dict_filename.endswith(".safetensors"):
-                format = MODEL_FILE_FORMATS.SAFETENSORS
+
+            elif format==MODEL_FILE_FORMATS.SAFETENSORS and state_dict_filename.endswith(".safetensors"):
 
                 #TODO: move this to its own method
                 from safetensors import safe_open
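Note the behavioural change in this hunk: previously the file extension alone determined the format, whereas now the caller-supplied format must also agree with the extension, otherwise control falls through to the fatal-error branch that follows. A simplified sketch of the new dispatch condition, not the literal code:

    # both the declared format and the filename extension must agree
    def _format_matches(format, filename):
        if format == MODEL_FILE_FORMATS.PYTORCH:
            return filename.endswith(".bin")
        if format == MODEL_FILE_FORMATS.SAFETENSORS:
            return filename.endswith(".safetensors")
        return False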
@@ -4646,87 +4650,101 @@ def extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, sta
                 state_dict = {}
                 with safe_open(state_dict_filename, framework="pt", device='cpu') as f:
                     for k in f.keys():
-                        state_dict[k] = f.get_tensor(k)
+                        weight_keys[k] = k
 
                 logger.info(f"Read safetensors: {state_dict_filename}, len={len(state_dict)}")
 
             else:
-                logger.fatal(f"PyTorch state_dict {state_dict_filename} not found, stopping")
-
+                logger.fatal(f"Format: {format } incorrect and /or PyTorch state_dict {state_dict_filename} not found, stopping")
+
         # we only want the modell but sometimes the state dict has more info
         if 'model' in [str(x) for x in state_dict.keys()]:
             state_dict = state_dict['model']
 
-        weight_keys = [key for key in state_dict.keys() if 'weight' in key.lower()]
 
         for layer_id, weight_key in enumerate(weight_keys):
 
+            if 'weight' in weight_key or 'bias' in weight_key:
             # TODO: do not save weights by default
-            # and change bias file name depending on safetensors or not
-            #
-            layer_id_updated = layer_id+start_id
-            name = f"{model_name}.{layer_id_updated}"
-            longname = re.sub('.weight$', '', weight_key)
-
-            T = state_dict[weight_key]
-
-            shape = len(T.shape)
-            #if shape==2:
-            W = torch_T_to_np(T)
-
-            # TODO: make sure this works with safetensors also
-            the_type = WWFlatFile.layer_type_as_str(W)
+                # and change bias file name depending on safetensors or not
+                #
+                layer_id_updated = layer_id+start_id
+                name = f"{model_name}.{layer_id_updated}"
+                longname = re.sub('.weight$', '', weight_key)
+
+                if format==MODEL_FILE_FORMATS.PYTORCH:
+                    T = state_dict[weight_key]
+                elif format==MODEL_FILE_FORMATS.SAFETENSORS:
+                    with safe_open(state_dict_filename, framework="pt", device='cpu') as f:
+
+                        T= f.get_tensor(weight_key)
 
-            # TODO: is this always corret
-            has_bias = None
-
-            bias_key = re.sub('weight$', 'bias', weight_key)
-            if bias_key in state_dict:
-                T = state_dict[bias_key]
-                b = torch_T_to_np(T)
-                has_bias = True
+                shape = len(T.shape)
+                #if shape==2:
+                W = torch_T_to_np(T)
 
-            # TODO: what about perceptron layers ?
-
-            # Save files by default for ww_flatfiles, but only is save=True for safetensores
-            if format==MODEL_FILE_FORMATS.WW_FLATFILES or save:
-                weightfile = f"{name}.weight.npy"
-                filename = os.path.join(weights_dir,weightfile)
-                logger.debug(f"saving weights to {filename}")
-                np.save(filename, W)
-
-                if has_bias:
-                    biasfile = f"{name}.bias.npy"
-                    filename = os.path.join(weights_dir,biasfile)
-                    logger.debug(f"saving biases to {filename}")
-                    np.save(filename, b)
+                # TODO: make sure this works with safetensors also
+                the_type = WWFlatFile.layer_type_as_str(W)
 
-            # safetensors
-            elif format==MODEL_FILE_FORMATS.SAFETENSORS:
-                weightfile = state_dict_filename
-                biasfile = None
-                if has_bias:
-                    biasfile = state_dict_filename
+                # TODO: is this always corret
+                has_bias = None
+
+                bias_key = re.sub('weight$', 'bias', weight_key)
+                if bias_key in state_dict:
 
-            else:
-                logger.fatal("Unknown format {format}, stopping")
+                    if format==MODEL_FILE_FORMATS.PYTORCH:
+                        T = state_dict[bias_key]
+                    elif format==MODEL_FILE_FORMATS.SAFETENSORS:
+                        with safe_open(state_dict_filename, framework="pt", device='cpu') as f:
+                            T= f.get_tensor(bias_key)
+
+                    b = torch_T_to_np(T)
+                    has_bias = True
+
+                # TODO: what about perceptron layers ?
+
+                # Save files by default for ww_flatfiles, but only is save=True for safetensores
+                if format!=MODEL_FILE_FORMATS.SAFETENSORS or save:
+                    weightfile = f"{name}.weight.npy"
+                    biasfile = None
+
+                    filename = os.path.join(weights_dir,weightfile)
+                    logger.debug(f"saving weights to {filename}")
+                    np.save(filename, W)
+
+                    if has_bias:
+                        biasfile = f"{name}.bias.npy"
+                        filename = os.path.join(weights_dir,biasfile)
+                        logger.debug(f"saving biases to {filename}")
+                        np.save(filename, b)
+
+                # safetensors
+                elif format==MODEL_FILE_FORMATS.SAFETENSORS:
+                    weightfile = state_dict_filename
+                    biasfile = None
+                    if has_bias:
+                        biasfile = state_dict_filename
+
+                else:
+                    logger.fatal("Unknown format {format}, stopping")
+
+
+
+                # TODO
+                # add the position id, 0 by default for weights and bias individuallu
+                # allow other, percepton layers, because we need these later
+                # = allow unknown types!
 
-
-
-            # TODO
-            # add the position id, 0 by default for weights and bias individuallu
-            # allow other, percepton layers, because we need these later
-            # = allow unknown types!
-
-            layer_config = {}
-            layer_config['key']=weight_key
-            layer_config['longname']=longname
-            layer_config['weightfile']=weightfile
-            layer_config['biasfile']=biasfile
-            layer_config['type']=the_type
-            layer_config['dims']=json.dumps(W.shape)
-
-            config[int(layer_id_updated)]=layer_config
+                layer_config = {}
+                layer_config['key']=weight_key
+                layer_config['name']=name
+                layer_config['longname']=longname
+                layer_config['weightfile']=weightfile
+                layer_config['biasfile']=biasfile
+                layer_config['type']=the_type
+                layer_config['dims']=json.dumps(W.shape)
+
+                config[int(layer_id_updated)]=layer_config
 
         return config
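For reference, each entry of the returned config dict now carries the new 'name' field alongside the existing keys; an illustrative entry with hypothetical values:

    config = {
        0: {
            'key':        'encoder.layer.0.attention.self.query.weight',
            'name':       'my_model.0',             # added by this commit
            'longname':   'encoder.layer.0.attention.self.query',
            'weightfile': 'my_model.0.weight.npy',  # or the .safetensors path when format is SAFETENSORS
            'biasfile':   'my_model.0.bias.npy',
            'type':       'dense',
            'dims':       '[768, 768]',             # json.dumps(W.shape)
        },
    }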

@@ -4875,21 +4893,19 @@ def extract_pytorch_bins(model_dir=None, model_name=None, tmp_dir="/tmp", format
         config['framework'] = FRAMEWORK.PYTORCH
         config['weights_dir'] = weights_dir
 
-        print("the format", format)
         # TODO: infer the format ?
         if format is None:
             format, fileglob = WeightWatcher.infer_model_file_format(model_dir)
             logger.info(f"Inferred format, found: {format}")
 
         config['format'] = format
         if format==MODEL_FILE_FORMATS.PYTORCH:
-           fileglob = f"{model_dir}/pytorch_model*bin"
-        elif format==MODEL_FILE_FORMATS.SAFETENSORS:
+            fileglob = f"{model_dir}/pytorch_model*bin"
+        elif format==MODEL_FILE_FORMATS.SAFETENSORS:
             fileglob = f"{model_dir}/model*safetensors"
         else:
             logger.fatal(f"Unknown file format {format}, quitting")
 
-        print(config)
 
         logger.debug(f"searching for files {fileglob}")

@@ -4921,7 +4937,7 @@ def extract_pytorch_bins(model_dir=None, model_name=None, tmp_dir="/tmp", format
         for state_dict_filename in sorted(glob.glob(fileglob)):
             logger.info(f"reading and extracting {state_dict_filename}")
             # TODO: update layer ids
-            layer_configs = WeightWatcher.extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id, format=format)
+            layer_configs = WeightWatcher.extract_pytorch_statedict_(weights_dir, model_name, state_dict_filename, start_id, format=format)
             start_id = np.max([int(x) for x in layer_configs.keys()]) + 1
             config['layers'].update(layer_configs)
             logger.debug(f"next start_id = {start_id}")
