Skip to content

Commit 382c0fb

Browse files
committed
err in recursive draws when size gt 1
1 parent dd96592 commit 382c0fb

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

project/sim_modules/drawdist.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,11 @@ def unif_int(self, low, high, size, trunc=-2):
6161
upper bound
6262
6363
"""
64-
trunc = np.floor((len(str(round(low))) - 1) / 2) * -1
65-
nn = np.random.randint(low, high+1, size)
64+
if type(low) is int or type(low) is float:
65+
trunc = np.floor((len(str(round(low))) - 1) / 2) * -1
66+
else:
67+
trunc = np.floor((len(str(round(low[0]))) - 1) / 2) * -1
68+
nn = np.random.randint(low, np.array(high)+1, size)
6669
return np.around(nn, int(trunc))
6770

6871
def log_unif_int(self, low, high, size, base=10):

project/sim_modules/recurtbi.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,35 @@
1515
cn = 0
1616

1717

18-
def itdict(d, size):
18+
def itdict(dt, size):
1919
global cn
20-
while cn < len(d):
21-
for k, v in d.items():
20+
while cn < len(dt):
21+
for k, v in dt.items():
2222
if len(v) == 3:
2323
dist, low, high = v
2424
if "tbi" in str(low):
25-
if type(d[low]) is int:
26-
d[k][1] = d[low]
25+
if len(dt[low]) == size:
26+
dt[k][1] = dt[low]
2727
elif "tbi" in str(high):
28-
if type(d[high]) is int:
29-
d[k][2] = d[high]
28+
if len(dt[high]) == size:
29+
dt[k][2] = dt[high]
3030
else:
3131
draw = getattr(DrawDist(), dist[1:])
32-
assert dist[1:] in avail_dist,"dist not recognized"
33-
d[k] = draw(float(low), float(high), size)
32+
assert dist[1:] in avail_dist, "dist not recognized"
33+
if type(low) is str and type(high) is str:
34+
# low is str, high is str '0' '123'
35+
dt[k] = draw(float(low), float(high), size)
36+
elif type(low) is str and len(high) == size:
37+
# low is str, high is array '0' ([1,2,3])
38+
dt[k] = draw([float(low)]*size, high, size)
39+
elif len(low) == size and type(high) is str:
40+
# low is array, high is str ([1,2,3]) '10'
41+
dt[k] = draw(low, [float(high)]*size, size)
42+
elif len(low) == size and len(high) == size:
43+
# low is array, high is array ([1,2,3]) ([1,2,3])
44+
dt[k] = draw(low, high, size)
3445
cn += 1
35-
return itdict(d, size)
36-
return d
46+
return itdict(dt, size)
47+
return dt
3748

3849
# d={"tbi1": [0, "tbi2"], "tbi2": ["tbi3", "tbi4"], "tbi3":[100, "tbi4"],"tbi4": [1000, 10000]}

0 commit comments

Comments
 (0)