Skip to content

Commit 5d862d5

Browse files
committed
add docs
1 parent 9007cab commit 5d862d5

File tree

3 files changed

+105
-1
lines changed

3 files changed

+105
-1
lines changed

lib/cifar10.ex

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,35 @@ defmodule AxonDatasets.CIFAR10 do
1414
end
1515
end
1616

17+
@doc """
18+
Downloads the CIFAR10 dataset or fetches it locally.
19+
## Options
20+
* `datapath` - path where the dataset .gz should be stored locally
21+
* `transform_images/1` - accepts accept a tuple like
22+
`{binary_data, tensor_type, data_shape}` which can be used for
23+
converting the `binary_data` to a tensor with a function like
24+
fn {labels_binary, type, _shape} ->
25+
labels_binary
26+
|> Nx.from_binary(type)
27+
|> Nx.new_axis(-1)
28+
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
29+
|> Nx.to_batched_list(32)
30+
end
31+
* `transform_labels/1` - similar to `transform_images/1` but applied to
32+
dataset labels
33+
34+
Examples:
35+
iex> AxonDatasets.CIFAR10.download()
36+
Fetching cifar-10-binary.tar.gz from https://www.cs.toronto.edu/~kriz/
37+
38+
{{<<59, 43, 50, 68, 98, 119, 139, 145, 149, 149, 131, 125, 142, 144, 137, 129,
39+
137, 134, 124, 139, 139, 133, 136, 139, 152, 163, 168, 159, 158, 158, 152,
40+
148, 16, 0, 18, 51, 88, 120, 128, 127, 126, 116, 106, 101, 105, 113, 109,
41+
112, ...>>, {:u, 8}, {50000, 3, 32, 32}},
42+
{<<6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9, 9, 3, 2, 6, 4, 3, 6, 6, 2,
43+
6, 3, 5, 4, 0, 0, 9, 1, 3, 4, 0, 3, 7, 3, 3, 5, 2, 2, 7, 1, 1, 1, ...>>,
44+
{:u, 8}, {50000}}}
45+
"""
1746
def download(opts \\ []) do
1847
data_path = opts[:data_path] || @default_data_path
1948
transform_images = opts[:transform_images] || fn out -> out end

lib/fashionmnist.ex

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,71 @@ defmodule AxonDatasets.FashionMNIST do
1818

1919
defp download_labels(opts) do
2020
data_path = opts[:data_path] || @default_data_path
21-
transform = opts[:transform_images] || fn out -> out end
21+
transform = opts[:transform_labels] || fn out -> out end
2222

2323
<<_::32, n_labels::32, labels::binary>> =
2424
Utils.unzip_cache_or_download(@base_url, @label_file, data_path)
2525

2626
transform.({labels, {:u, 8}, {n_labels}})
2727
end
2828

29+
@doc """
30+
Downloads the FashionMNIST dataset or fetches it locally.
31+
## Options
32+
* `datapath` - path where the dataset .gz should be stored locally
33+
* `transform_images/1` - accepts accept a tuple like
34+
`{binary_data, tensor_type, data_shape}` which can be used for
35+
converting the `binary_data` to a tensor with a function like
36+
fn {labels_binary, type, _shape} ->
37+
labels_binary
38+
|> Nx.from_binary(type)
39+
|> Nx.new_axis(-1)
40+
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
41+
|> Nx.to_batched_list(32)
42+
end
43+
* `transform_labels/1` - similar to `transform_images/1` but applied to
44+
dataset labels
45+
46+
Examples:
47+
iex> AxonDatasets.FashionMNIST.download()
48+
Fetching train-images-idx3-ubyte.gz from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/
49+
50+
Fetching train-labels-idx1-ubyte.gz from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/
51+
52+
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
53+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
54+
{:u, 8}, {60000, 28, 28}},
55+
{<<9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8, 4,
56+
3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5, 9, 2, 7, ...>>,
57+
{:u, 8}, {60000}}}
58+
59+
iex> transform_labels = fn {labels_binary, type, _shape} ->
60+
iex> labels_binary
61+
iex> |> Nx.from_binary(type)
62+
iex> |> Nx.new_axis(-1)
63+
iex> |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
64+
iex> |> Nx.to_batched_list(32)
65+
iex> end
66+
#Function<7.126501267/1 in :erl_eval.expr/5>
67+
iex> AxonDatasets.FashionMNIST.download(transform_labels: transform_labels)
68+
Using train-images-idx3-ubyte.gz from tmp/fashionmnist
69+
70+
Using train-labels-idx1-ubyte.gz from tmp/fashionmnist
71+
72+
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
73+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
74+
{:u, 8}, {60000, 28, 28}}, #Nx.Tensor<
75+
u8[60000][10]
76+
[
77+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
78+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
79+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
80+
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
81+
[1, 0, 0, 0, 0, 0, 0, 0, ...],
82+
...
83+
]
84+
>}
85+
"""
2986
def download(opts \\ []),
3087
do: {download_images(opts), download_labels(opts)}
3188
end

lib/mnist.ex

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@ defmodule AxonDatasets.MNIST do
2626
transform.({labels, {:u, 8}, {n_labels}})
2727
end
2828

29+
@doc """
30+
Downloads the MNIST dataset or fetches it locally.
31+
32+
## Options
33+
* `datapath` - path where the dataset .gz should be stored locally
34+
* `transform_images/1` - accepts accept a tuple like
35+
`{binary_data, tensor_type, data_shape}` which can be used for
36+
converting the `binary_data` to a tensor with a function like
37+
fn {labels_binary, type, _shape} ->
38+
labels_binary
39+
|> Nx.from_binary(type)
40+
|> Nx.new_axis(-1)
41+
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
42+
|> Nx.to_batched_list(32)
43+
end
44+
* `transform_labels/1` - similar to `transform_images/1` but applied to
45+
dataset labels
46+
"""
2947
def download(opts \\ []),
3048
do: {download_images(opts), download_labels(opts)}
3149
end

0 commit comments

Comments
 (0)