@@ -18,14 +18,71 @@ defmodule AxonDatasets.FashionMNIST do
18
18
19
19
defp download_labels ( opts ) do
20
20
data_path = opts [ :data_path ] || @ default_data_path
21
- transform = opts [ :transform_images ] || fn out -> out end
21
+ transform = opts [ :transform_labels ] || fn out -> out end
22
22
23
23
<< _ :: 32 , n_labels :: 32 , labels :: binary >> =
24
24
Utils . unzip_cache_or_download ( @ base_url , @ label_file , data_path )
25
25
26
26
transform . ( { labels , { :u , 8 } , { n_labels } } )
27
27
end
28
28
29
+ @ doc """
30
+ Downloads the FashionMNIST dataset or fetches it locally.
31
+ ## Options
32
+ * `datapath` - path where the dataset .gz should be stored locally
33
+ * `transform_images/1` - accepts accept a tuple like
34
+ `{binary_data, tensor_type, data_shape}` which can be used for
35
+ converting the `binary_data` to a tensor with a function like
36
+ fn {labels_binary, type, _shape} ->
37
+ labels_binary
38
+ |> Nx.from_binary(type)
39
+ |> Nx.new_axis(-1)
40
+ |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
41
+ |> Nx.to_batched_list(32)
42
+ end
43
+ * `transform_labels/1` - similar to `transform_images/1` but applied to
44
+ dataset labels
45
+
46
+ Examples:
47
+ iex> AxonDatasets.FashionMNIST.download()
48
+ Fetching train-images-idx3-ubyte.gz from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/
49
+
50
+ Fetching train-labels-idx1-ubyte.gz from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/
51
+
52
+ {{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
53
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
54
+ {:u, 8}, {60000, 28, 28}},
55
+ {<<9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8, 4,
56
+ 3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5, 9, 2, 7, ...>>,
57
+ {:u, 8}, {60000}}}
58
+
59
+ iex> transform_labels = fn {labels_binary, type, _shape} ->
60
+ iex> labels_binary
61
+ iex> |> Nx.from_binary(type)
62
+ iex> |> Nx.new_axis(-1)
63
+ iex> |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
64
+ iex> |> Nx.to_batched_list(32)
65
+ iex> end
66
+ #Function<7.126501267/1 in :erl_eval.expr/5>
67
+ iex> AxonDatasets.FashionMNIST.download(transform_labels: transform_labels)
68
+ Using train-images-idx3-ubyte.gz from tmp/fashionmnist
69
+
70
+ Using train-labels-idx1-ubyte.gz from tmp/fashionmnist
71
+
72
+ {{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
73
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
74
+ {:u, 8}, {60000, 28, 28}}, #Nx.Tensor<
75
+ u8[60000][10]
76
+ [
77
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
78
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
79
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
80
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
81
+ [1, 0, 0, 0, 0, 0, 0, 0, ...],
82
+ ...
83
+ ]
84
+ >}
85
+ """
29
86
def download ( opts \\ [ ] ) ,
30
87
do: { download_images ( opts ) , download_labels ( opts ) }
31
88
end
0 commit comments