diff --git a/docs/source/improve_synth.rst b/docs/source/improve_synth.rst
index 450b4b14..1f8cdbbe 100644
--- a/docs/source/improve_synth.rst
+++ b/docs/source/improve_synth.rst
@@ -36,7 +36,7 @@ use the configuration file is a more appropriate interface (see also our :doc:`c
 
     MetaFrame.fit_dataframe(
         df,
-        var_specs="your_config_file.toml"
+        config="your_config_file.toml"
     )
 
 This refers to a configuration file called ``your_config_file.toml``:
@@ -177,7 +177,7 @@ The most common use-case for this is to set the distribution type and/or paramet
 .. code-block:: python
 
     # In this example you put the specifications in the toml file.
-    MetaFrame.fit_dataframe(df, var_specs="your_config_file.toml")
+    MetaFrame.fit_dataframe(df, config="your_config_file.toml")
 
 .. code-block:: toml
 
diff --git a/metasyn/__main__.py b/metasyn/__main__.py
index ad9e9d74..610e3e11 100644
--- a/metasyn/__main__.py
+++ b/metasyn/__main__.py
@@ -134,7 +134,7 @@ def create_metadata() -> None:
     data_frame = pl.read_csv(args.input, try_parse_dates=True, infer_schema_length=10000,
                              null_values=["", "na", "NA", "N/A", "Na"],
                              ignore_errors=True)
-    meta_frame = MetaFrame.fit_dataframe(data_frame, meta_config)
+    meta_frame = MetaFrame.fit_dataframe(data_frame, config=meta_config)
     meta_frame.save(args.output)
 
 
diff --git a/metasyn/config.py b/metasyn/config.py
index 58c1b268..f494bcdc 100644
--- a/metasyn/config.py
+++ b/metasyn/config.py
@@ -55,7 +55,7 @@ def __init__(
         self.config_version = config_version
 
     @staticmethod
-    def _parse_var_spec(var_spec):
+    def _parse_var_spec(var_spec) -> VarSpec:
         if isinstance(var_spec, VarSpec):
             return var_spec
         return VarSpec.from_dict(var_spec)
@@ -72,6 +72,23 @@ def dist_providers(self, dist_providers):
         else:
             self._dist_providers = dist_providers
 
+    def update_varspecs(self, new_var_specs: Union[list[dict], list[VarSpec]]):
+        """Add variable specifications, replacing existing ones with the same name.
+
+        Parameters
+        ----------
+        new_var_specs:
+            Variable specifications to add, as VarSpec instances or dictionaries.
+        """
+        new_var_specs = [self._parse_var_spec(v) for v in new_var_specs]
+        for cur_new_var_spec in new_var_specs:
+            # Check if currently in varspecs and pop if it exists.
+            for i_var, old_var_spec in enumerate(self.var_specs):
+                if old_var_spec.name == cur_new_var_spec.name:
+                    self.var_specs.pop(i_var)
+                    break
+            self.var_specs.append(cur_new_var_spec)
+
     @classmethod
     def from_toml(cls, config_fp: Union[str, Path]) -> MetaConfig:
         """Create a MetaConfig class from a .toml file.
diff --git a/metasyn/metaframe.py b/metasyn/metaframe.py
index 1b9a4ce9..14712009 100644
--- a/metasyn/metaframe.py
+++ b/metasyn/metaframe.py
@@ -68,11 +68,12 @@ def n_columns(self) -> int:
     def fit_dataframe(  # noqa: PLR0912
             cls,
             df: Optional[pl.DataFrame],
-            var_specs: Optional[Union[list[VarSpec], pathlib.Path, str, MetaConfig]] = None,
+            var_specs: Optional[Union[list[VarSpec], list[dict]]] = None,
             dist_providers: Optional[list[str]] = None,
             privacy: Optional[Union[BasePrivacy, dict]] = None,
             n_rows: Optional[int] = None,
-            progress_bar: bool = True):
+            progress_bar: bool = True,
+            config: Optional[Union[pathlib.Path, str, MetaConfig]] = None):
         """Create a metasyn object from a polars (or pandas) dataframe.
 
         The Polars dataframe should be formatted already with the correct
@@ -100,21 +101,37 @@ def fit_dataframe(  # noqa: PLR0912
             of rows in the input dataframe.
         progress_bar:
             Whether to create a progress bar.
+        config:
+            A path or MetaConfig object that contains information about the variable
+            specifications, defaults, etc. Variable specs in the config parameter will be
+            overwritten by the var_specs parameter.
 
         Returns
         -------
         MetaFrame:
             Initialized metasyn metaframe.
         """
+        if isinstance(var_specs, (str, pathlib.Path, MetaConfig)):
+            if config is not None:
+                raise ValueError("var_specs must be a list of variable specifications when"
+                                 " config is supplied.")
+            warn("Supplying the configuration through var_specs is deprecated and will be removed"
+                 f" in metasyn version 2.0. Use config={var_specs!r} instead.",
+                 DeprecationWarning, stacklevel=2)
+            config = var_specs
+            var_specs = None
         # Parse the var_specs into a MetaConfig instance.
-        if isinstance(var_specs, (pathlib.Path, str)):
-            meta_config = MetaConfig.from_toml(var_specs)
-        elif isinstance(var_specs, MetaConfig):
-            meta_config = var_specs
-        elif var_specs is None:
+        if config is None:
             meta_config = MetaConfig([], dist_providers, defaults = {"privacy": privacy})
+        elif isinstance(config, (pathlib.Path, str)):
+            meta_config = MetaConfig.from_toml(config)
         else:
-            meta_config = MetaConfig(var_specs, dist_providers, defaults = {"privacy": privacy})
+            meta_config = config
+
+        # var_specs overrules variable specifications in the configuration (file).
+        if var_specs is not None:
+            meta_config.update_varspecs(var_specs)
+
         if dist_providers is not None:
             meta_config.dist_providers = dist_providers  # type: ignore
         if privacy is not None:
@@ -175,7 +192,7 @@ def from_config(cls, meta_config: MetaConfig) -> MetaFrame:
         -------
         A created MetaFrame.
         """
-        return cls.fit_dataframe(None, meta_config)
+        return cls.fit_dataframe(None, config=meta_config)
 
     def to_dict(self) -> Dict[str, Any]:
         """Create dictionary with the properties for recreation."""
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 1fe28997..eafd2697 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -96,8 +96,8 @@ def test_create_meta(tmp_dir, config):
     ]
     if config:
         cmd.extend(["--config", Path(tmp_dir) / 'config.ini'])
     result = subprocess.run(cmd, check=False, capture_output=True)
-    assert result.returncode == 0
+    assert result.returncode == 0, (result.stdout, result.stderr)
     assert out_file.is_file()
     meta_frame = MetaFrame.load_json(out_file)
     assert len(meta_frame.meta_vars) == 12
diff --git a/tests/test_toml.py b/tests/test_toml.py
index f66506b5..bf549b0d 100644
--- a/tests/test_toml.py
+++ b/tests/test_toml.py
@@ -21,7 +21,7 @@ def test_datafree_create(tmpdir):
     temp_toml = tmpdir / "test.toml"
     create_input_toml(temp_toml)
     assert cmp(temp_toml, Path("examples", "config_files", "example_all.toml"))
-    mf = MetaFrame.fit_dataframe(None, var_specs=Path(temp_toml))
+    mf = MetaFrame.fit_dataframe(None, config=Path(temp_toml))
     assert isinstance(mf, MetaFrame)
     assert mf.n_columns == len(BuiltinDistributionProvider.distributions)
 
@@ -35,11 +35,31 @@ def test_datafree_create(tmpdir):
 )
 def test_toml_save_load(tmpdir, toml_input, data):
     """Test whether TOML GMF files can be saved/loaded."""
-    mf = MetaFrame.fit_dataframe(data, toml_input)
+    mf = MetaFrame.fit_dataframe(data, config=toml_input)
     mf.save(tmpdir/"test.toml")
     new_mf = MetaFrame.load(tmpdir/"test.toml")
     assert mf.n_columns == new_mf.n_columns
 
+def test_varspec_update():
+    """Check whether overwriting the varspecs with the var_specs parameter works."""
+    toml_input = Path("examples", "config_files", "example_all.toml")
+    var_specs = [{
+        "name": "DiscreteTruncatedNormalDistribution",
+        "var_type": "discrete",
+        "distribution": {
+            "implements": "core.normal",
+            "unique": False,
+            "parameters": {
+                "mean": 0,
+                "sd": 1,
+            }
+        }
+    }]
+    mf_normal = MetaFrame.fit_dataframe(None, config=toml_input)
+    mf_varspec = MetaFrame.fit_dataframe(None, var_specs=var_specs, config=toml_input)
+    assert mf_normal["DiscreteTruncatedNormalDistribution"].distribution.implements == "core.truncated_normal"
+    assert mf_varspec["DiscreteTruncatedNormalDistribution"].distribution.implements == "core.normal"
+
 @mark.parametrize(
     "gmf_file", [
         Path("examples", "gmf_files", "example_gmf_simple.json"),