-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy path__init__.py
91 lines (85 loc) · 3.63 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# ===============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
import os
from typing import Dict, Tuple
from ..utils.bench_case import get_bench_case_value, get_data_name
from ..utils.common import custom_format
from ..utils.custom_types import BenchCase
from .loaders import (
dataset_loading_functions,
load_custom_data,
load_openml_data,
load_sklearn_synthetic_data,
)
def load_data(bench_case: BenchCase) -> Tuple[Dict, Dict]:
    """Load the dataset described by the ``data`` section of a benchmark case.

    The loader is selected in this order:

    1. ``data:dataset`` is set: the matching entry of
       ``dataset_loading_functions`` is called, or ``load_custom_data`` when
       the name is not registered there.
    2. ``data:source`` starts with ``"make_"``: ``load_sklearn_synthetic_data``
       is called with ``data:generation_kwargs``.
    3. ``data:source`` equals ``"fetch_openml"``: ``load_openml_data`` is
       called with ``data:id``.

    Both cache directories are created before any loader is called.

    Parameters
    ----------
    bench_case : BenchCase
        Benchmark case configuration. ``data:generation_kwargs`` is copied
        before being adjusted, so it is not rewritten in place here.

    Returns
    -------
    Tuple[Dict, Dict]
        Whatever the selected loader returns.

    Raises
    ------
    ValueError
        If neither ``data:dataset`` nor a supported ``data:source`` is set.
    """
    # get data name and cache dirs
    data_name = get_data_name(bench_case, shortened=False)
    data_cache = get_bench_case_value(bench_case, "data:cache_directory", "data_cache")
    raw_data_cache = get_bench_case_value(
        bench_case, "data:raw_cache_directory", os.path.join(data_cache, "raw")
    )
    common_kwargs = {
        "data_name": data_name,
        "data_cache": data_cache,
        "raw_data_cache": raw_data_cache,
    }
    preproc_kwargs = get_bench_case_value(bench_case, "data:preprocessing_kwargs", {})

    # make cache directories
    os.makedirs(data_cache, exist_ok=True)
    os.makedirs(raw_data_cache, exist_ok=True)

    # load by dataset name
    dataset = get_bench_case_value(bench_case, "data:dataset")
    if dataset is not None:
        dataset_params = get_bench_case_value(bench_case, "data:dataset_kwargs", {})
        if dataset in dataset_loading_functions:
            # registered dataset loading branch
            return dataset_loading_functions[dataset](
                **common_kwargs,
                preproc_kwargs=preproc_kwargs,
                dataset_params=dataset_params,
            )
        # user-provided dataset loading branch
        return load_custom_data(**common_kwargs, preproc_kwargs=preproc_kwargs)

    # load by source
    source = get_bench_case_value(bench_case, "data:source")
    if source is not None:
        # sklearn.datasets functions
        if source.startswith("make_"):
            # Work on a shallow copy: the looked-up mapping may be the very
            # object stored in bench_case, and rewriting it in place would make
            # a repeated call see an already-expanded "center_box".
            generation_kwargs = dict(
                get_bench_case_value(bench_case, "data:generation_kwargs", {})
            )
            if "center_box" in generation_kwargs:
                center_box = generation_kwargs["center_box"]
                if isinstance(center_box, (list, tuple)):
                    # already an explicit pair of bounds: pass it through
                    # (multiplying a sequence by -1 would yield an empty one)
                    generation_kwargs["center_box"] = tuple(center_box)
                else:
                    # a scalar c is shorthand for the symmetric box (-c, c)
                    generation_kwargs["center_box"] = (-1 * center_box, center_box)
            return load_sklearn_synthetic_data(
                function_name=source,
                input_kwargs=generation_kwargs,
                preproc_kwargs=preproc_kwargs,
                **common_kwargs,
            )
        # openml dataset
        if source == "fetch_openml":
            openml_id = get_bench_case_value(bench_case, "data:id")
            return load_openml_data(
                openml_id=openml_id, preproc_kwargs=preproc_kwargs, **common_kwargs
            )

    # neither a dataset name nor a supported source was configured
    raise ValueError(
        "Unable to get data from bench_case:\n"
        f'{custom_format(get_bench_case_value(bench_case, "data"))}'
    )