Skip to content

Commit 1cda1d4

Browse files
rspwarnaarUTReSurfEMG
authored andcommitted
Fix overwrite config file
1 parent 000b1a7 commit 1cda1d4

File tree

1 file changed

+32
-7
lines changed

1 file changed

+32
-7
lines changed

resurfemg/data_connector/config.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,29 @@ class Config:
6565

6666
required_directories = ['root_data']
6767

68-
def __init__(self, location=None, verbose=False):
68+
def __init__(self, location=None, verbose=False, force=False):
69+
"""
70+
This function initializes the configuration file. If no location is
71+
specified it will try to load the configuration file from the default
72+
locations:
73+
- ./config.json
74+
- ~/.resurfemg/config.json
75+
- /etc/resurfemg/config.json
76+
- PROJECT_ROOT/config.json
77+
-----------------------------------------------------------------------
78+
:param location: The location of the configuration file.
79+
:type location: str
80+
:param verbose: A boolean to print the loaded configuration.
81+
:type verbose: bool
82+
:param force: A boolean to overwrite the configuration file.
83+
:type force: bool
84+
:raises ValueError: If the configuration file is not found.
85+
"""
6986
self._raw = None
7087
self._loaded = None
7188
self.example = 'config_example_resurfemg.json'
7289
self.repo_root = find_repo_root(self.example)
90+
self.force = force
7391
self.created_config = False
7492
# In the ResurfEMG project, the test data is stored in ./test_data
7593
test_path = os.path.join(self.repo_root, 'test_data')
@@ -138,6 +156,7 @@ def usage(self):
138156
def create_config_from_example(
139157
self,
140158
location: str,
159+
force=False,
141160
):
142161
"""
143162
This function creates a config file from an example file.
@@ -147,10 +166,15 @@ def create_config_from_example(
147166
:raises ValueError: If the example file is not found.
148167
"""
149168
config_path = location.replace(self.example, 'config.json')
150-
with open(location, 'r') as f:
151-
example = json.load(f)
152-
with open(config_path, 'w') as f:
153-
json.dump(example, f, indent=4, sort_keys=True)
169+
if os.path.isfile(config_path) and not force:
170+
raise ValueError(
171+
f'Config file already exists at {config_path}.'
172+
+ ' Use `force=True` to overwrite.')
173+
else:
174+
with open(location, 'r') as f:
175+
example = json.load(f)
176+
with open(config_path, 'w') as f:
177+
json.dump(example, f, indent=4, sort_keys=True)
154178

155179
def load(self, location, verbose=False):
156180
"""
@@ -180,9 +204,10 @@ def load(self, location, verbose=False):
180204
logging.info('Failed to load %s: %s', _path, e)
181205
else:
182206
if (self.repo_root is not None and os.path.isfile(
183-
os.path.join(self.repo_root, 'config.json'))):
207+
os.path.join(self.repo_root, self.example))):
184208
self.create_config_from_example(
185-
os.path.join(self.repo_root, self.example))
209+
os.path.join(self.repo_root, self.example),
210+
force=self.force,)
186211
root_path = os.path.join(self.repo_root, 'not_pushed')
187212
if not os.path.isdir(root_path):
188213
os.makedirs(root_path)

0 commit comments

Comments
 (0)