Skip to content

Commit 5e1fddf

Browse files
authored
Nick/shape support (#348)
* Set show figure to false in example to avoid pop ups during local testing. * Update shape support: * Coerce a broader class of things into a shape tuple * Use math prob on int tuple instead of np.prod to avoid overflow when possible * Demonstrate extended support for one dimensional data flexibility with ttv. * Push OneDArray to all top level usages of numpy arrays: * pyttb_utils is more internal so can probably be more specific * Broaden our inputs to take sequences over lists. * Fix scalar edge case * Fix our tutorial since we weren't actually checking subs or values before.
1 parent 70fc7c7 commit 5e1fddf

17 files changed

+360
-282
lines changed

docs/source/tutorial/class_sptensor.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@
271271
"outputs": [],
272272
"source": [
273273
"indices = np.array([[0, 0, 0], [1, 0, 0]])\n",
274-
"values = np.ones((2,))\n",
274+
"values = np.ones((2, 1))\n",
275275
"\n",
276276
"Y = ttb.sptensor.from_aggregator(indices, values) # Create a sparse tensor.\n",
277277
"Y"
@@ -300,7 +300,7 @@
300300
"outputs": [],
301301
"source": [
302302
"indices = np.array([[0, 0, 0], [2, 2, 2]])\n",
303-
"values = np.array([1, 3])\n",
303+
"values = np.array([[1], [3]])\n",
304304
"\n",
305305
"Y = ttb.sptensor.from_aggregator(indices, values) # Create a sparse tensor.\n",
306306
"Y"
@@ -555,7 +555,7 @@
555555
"outputs": [],
556556
"source": [
557557
"indices = np.array([[0], [2], [4]])\n",
558-
"values = np.array([1, 1, 1])\n",
558+
"values = np.array([[1], [1], [1]])\n",
559559
"X = ttb.sptensor.from_aggregator(indices, values)\n",
560560
"X"
561561
]

pyttb/cp_als.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@
66

77
from __future__ import annotations
88

9-
from typing import Dict, List, Literal, Optional, Tuple, Union
9+
from typing import Dict, Literal, Optional, Tuple, Union
1010

1111
import numpy as np
1212

1313
import pyttb as ttb
14+
from pyttb.pyttb_utils import OneDArray, parse_one_d
1415

1516

1617
def cp_als( # noqa: PLR0912,PLR0913,PLR0915
1718
input_tensor: Union[ttb.tensor, ttb.sptensor, ttb.ttensor, ttb.sumtensor],
1819
rank: int,
1920
stoptol: float = 1e-4,
2021
maxiters: int = 1000,
21-
dimorder: Optional[List[int]] = None,
22-
optdims: Optional[List[int]] = None,
22+
dimorder: Optional[OneDArray] = None,
23+
optdims: Optional[OneDArray] = None,
2324
init: Union[Literal["random"], Literal["nvecs"], ttb.ktensor] = "random",
2425
printitn: int = 1,
2526
fixsigns: bool = True,
@@ -109,8 +110,8 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
109110
[[0.1467... 0.0923...]
110111
[0.1862... 0.3455...]]
111112
>>> print(output["params"]) # doctest: +NORMALIZE_WHITESPACE
112-
{'stoptol': 0.0001, 'maxiters': 1000, 'dimorder': [0, 1],\
113-
'optdims': [0, 1], 'printitn': 1, 'fixsigns': True}
113+
{'stoptol': 0.0001, 'maxiters': 1000, 'dimorder': array([0, 1]),\
114+
'optdims': array([0, 1]), 'printitn': 1, 'fixsigns': True}
114115
115116
Example using "nvecs" initialization:
116117
@@ -135,15 +136,17 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
135136

136137
# Set up dimorder if not specified
137138
if dimorder is None:
138-
dimorder = list(range(N))
139-
elif not isinstance(dimorder, list):
140-
assert False, "Dimorder must be a list"
141-
elif tuple(range(N)) != tuple(sorted(dimorder)):
139+
dimorder = np.arange(N)
140+
else:
141+
dimorder = parse_one_d(dimorder)
142+
if tuple(range(N)) != tuple(sorted(dimorder)):
142143
assert False, "Dimorder must be a list or permutation of range(tensor.ndims)"
143144

144145
# Set up optdims if not specified
145146
if optdims is None:
146-
optdims = list(range(N))
147+
optdims = np.arange(N)
148+
else:
149+
optdims = parse_one_d(optdims)
147150

148151
# Error checking
149152
assert rank > 0, "Number of components requested must be positive"

pyttb/export_data.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
from __future__ import annotations
88

9-
from typing import Optional, TextIO, Tuple, Union
9+
from typing import Optional, TextIO, Union
1010

1111
import numpy as np
1212

1313
import pyttb as ttb
14+
from pyttb.pyttb_utils import Shape, parse_shape
1415

1516

1617
def export_data(
@@ -56,8 +57,9 @@ def export_data(
5657
export_array(fp, data, fmt_data)
5758

5859

59-
def export_size(fp: TextIO, shape: Tuple[int, ...]):
60+
def export_size(fp: TextIO, shape: Shape):
6061
"""Export the size of something to a file"""
62+
shape = parse_shape(shape)
6163
print(f"{len(shape)}", file=fp) # # of dimensions on one line
6264
shape_str = " ".join([str(d) for d in shape])
6365
print(f"{shape_str}", file=fp) # size of each dimensions on the next line

pyttb/gcp_opt.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
import logging
1010
import time
11-
from typing import Dict, List, Literal, Optional, Tuple, Union
11+
from math import prod
12+
from typing import Dict, Literal, Optional, Sequence, Tuple, Union
1213

1314
import numpy as np
1415

@@ -24,7 +25,7 @@ def gcp_opt( # noqa: PLR0912,PLR0913
2425
rank: int,
2526
objective: Union[Objectives, Tuple[function_type, function_type, float]],
2627
optimizer: Union[StochasticSolver, LBFGSB],
27-
init: Union[Literal["random"], ttb.ktensor, List[np.ndarray]] = "random",
28+
init: Union[Literal["random"], ttb.ktensor, Sequence[np.ndarray]] = "random",
2829
mask: Optional[Union[ttb.tensor, np.ndarray]] = None,
2930
sampler: Optional[GCPSampler] = None,
3031
printitn: int = 1,
@@ -74,7 +75,7 @@ def gcp_opt( # noqa: PLR0912,PLR0913
7475
if not isinstance(data, (ttb.tensor, ttb.sptensor)):
7576
raise ValueError("Input data must be tensor or sptensor.")
7677

77-
tensor_size = int(np.prod(data.shape))
78+
tensor_size = prod(data.shape)
7879

7980
if isinstance(data, ttb.tensor) and isinstance(mask, ttb.tensor):
8081
data *= mask
@@ -134,7 +135,7 @@ def gcp_opt( # noqa: PLR0912,PLR0913
134135
def _get_initial_guess(
135136
data: Union[ttb.tensor, ttb.sptensor],
136137
rank: int,
137-
init: Union[Literal["random"], ttb.ktensor, List[np.ndarray]],
138+
init: Union[Literal["random"], ttb.ktensor, Sequence[np.ndarray]],
138139
) -> ttb.ktensor:
139140
"""Get initial guess for gcp_opt
140141
@@ -143,7 +144,7 @@ def _get_initial_guess(
143144
Normalized ktensor.
144145
"""
145146
# TODO might be nice to merge with ALS/other CP methods
146-
if isinstance(init, list):
147+
if isinstance(init, Sequence) and not isinstance(init, str):
147148
return ttb.ktensor(init).normalize("all")
148149
if isinstance(init, ttb.ktensor):
149150
init.normalize("all")

pyttb/hosvd.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,22 @@
77
from __future__ import annotations
88

99
import warnings
10-
from typing import List, Optional
10+
from typing import Optional
1111

1212
import numpy as np
1313
import scipy
1414

1515
import pyttb as ttb
16+
from pyttb.pyttb_utils import OneDArray, parse_one_d
1617

1718

1819
def hosvd( # noqa: PLR0912,PLR0913,PLR0915
1920
input_tensor: ttb.tensor,
2021
tol: float,
2122
verbosity: float = 1,
22-
dimorder: Optional[List[int]] = None,
23+
dimorder: Optional[OneDArray] = None,
2324
sequential: bool = True,
24-
ranks: Optional[List[int]] = None,
25+
ranks: Optional[OneDArray] = None,
2526
) -> ttb.ttensor:
2627
"""Compute sequentially-truncated higher-order SVD (Tucker).
2728
@@ -57,21 +58,22 @@ def hosvd( # noqa: PLR0912,PLR0913,PLR0915
5758
# In tucker als this is N
5859
d = input_tensor.ndims
5960

60-
if ranks is not None:
61-
if len(ranks) != d:
62-
raise ValueError(
63-
f"Ranks must be a list of length tensor ndims. Ndims: {d} but got "
64-
f"ranks: {ranks}."
65-
)
61+
if ranks is None:
62+
ranks = np.zeros((d,), dtype=int)
6663
else:
67-
ranks = [0] * d
64+
ranks = parse_one_d(ranks)
65+
66+
if len(ranks) != d:
67+
raise ValueError(
68+
"Ranks must be a sequence of length tensor ndims."
69+
f" Ndims: {d} but got ranks: {ranks}."
70+
)
6871

6972
# Set up dimorder if not specified (this is copy past from tucker_als
70-
if not dimorder:
71-
dimorder = list(range(d))
73+
if dimorder is None:
74+
dimorder = np.arange(d)
7275
else:
73-
if not isinstance(dimorder, list):
74-
raise ValueError("Dimorder must be a list")
76+
dimorder = parse_one_d(dimorder)
7577
if tuple(range(d)) != tuple(sorted(dimorder)):
7678
raise ValueError(
7779
"Dimorder must be a list or permutation of range(tensor.ndims)"

0 commit comments

Comments
 (0)