1
1
import numpy as np
2
2
import matplotlib .pyplot as plt
3
+ import tqdm
4
+ import xarray as xr
3
5
4
- def calculate_flowlines (x ,y ,u ,v ,seed_points ,max_error = 0.00001 ):
6
+ ################ This is the import statement required to reference scripts within the package
7
+ import os ,sys ,glob
8
+ ndh_tools_path_opts = [
9
+ '/mnt/data01/Code/' ,
10
+ '/mnt/l/mnt/data01/Code/' ,
11
+ '/home/common/HolschuhLab/Code/'
12
+ ]
13
+ for i in ndh_tools_path_opts :
14
+ if os .path .isfile (i ): sys .path .append (i )
15
+ ################################################################################################
16
+
17
+ import NDH_Tools as ndh
18
+
19
+ def calculate_flowlines (input_xr ,seed_points ,uv_varnames = ['u' ,'v' ],xy_varnames = ['x' ,'y' ],steps = 20000 ,ds = 2 ,forward0_both1_backward2 = 1 ):
5
20
"""
6
21
% (C) Nick Holschuh - Amherst College -- 2022 ([email protected] )
7
22
%
@@ -10,7 +25,7 @@ def calculate_flowlines(x,y,u,v,seed_points,max_error=0.00001):
10
25
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
11
26
% The inputs are:
12
27
%
13
- % input_array -- array of data to analyze
28
+ % input_xr -- xarray dataarray that has the gradient objects in it
14
29
%
15
30
%%%%%%%%%%%%%%%
16
31
% The outputs are:
@@ -20,34 +35,189 @@ def calculate_flowlines(x,y,u,v,seed_points,max_error=0.00001):
20
35
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
21
36
"""
22
37
23
- ################# This uses a modified plt.streamline to pass through a user-editable keyword
24
- ################# argument "max_error", which goes into the interpolater to guarantee
25
- ################# accurate streammline calculation. Copy this version of streamplot
26
- ################# into your matplotlib directory to enable the use of streamline
27
- ################# calculation from NDH_Tools
28
-
29
-
30
- sls = []
31
-
32
- if isinstance (seed_points ,list ):
33
- seed_points = np .array (seed_points )
34
-
35
- if len (seed_points .shape ) == 1 :
36
- seed_points = np .expand_dims (seed_points ,axis = 0 )
37
-
38
- fig = plt .figure ()
39
-
40
- for ind0 , sp in enumerate (seed_points [:,0 ]):
41
- streamlines = plt .streamplot (x ,y ,u ,v ,start_points = [seed_points [ind0 ,:]], max_error = max_error )
42
-
43
-
44
- ########### Here we extract the coordinate information along the streamline
45
- sl = [streamlines .lines .get_paths ()[0 ].vertices [0 ]]
46
- for i in streamlines .lines .get_paths ():
47
- sl .append (i .vertices [1 ])
48
-
49
- sls .append (np .array (sl ))
50
-
51
- plt .close (fig )
52
-
53
- return sls
38
+ ##################### Here, we standardize the naming convention within the xarray object
39
+ input_xr = input_xr .rename ({xy_varnames [0 ]:'x' ,xy_varnames [1 ]:'y' })
40
+ uv_scalar = np .sqrt (input_xr [uv_varnames [0 ]].values ** 2 + input_xr [uv_varnames [1 ]].values ** 2 )
41
+ input_xr [uv_varnames [0 ]] = (('y' ,'x' ),input_xr [uv_varnames [0 ]].values / uv_scalar )
42
+ input_xr [uv_varnames [1 ]] = (('y' ,'x' ),input_xr [uv_varnames [1 ]].values / uv_scalar )
43
+
44
+
45
+ #################### We initialize the objects for the flowline calculation
46
+ flowlines = []
47
+
48
+ #################### Here is the forward calculation
49
+ if forward0_both1_backward2 <= 1 :
50
+ temp_xs = np .expand_dims (seed_points [:,0 ],0 )
51
+ temp_ys = np .expand_dims (seed_points [:,1 ],0 )
52
+
53
+ for ind0 in tqdm .tqdm (np .arange (steps )):
54
+ x_search = xr .DataArray (temp_xs [- 1 ,:],dims = ['vector_index' ])
55
+ y_search = xr .DataArray (temp_ys [- 1 ,:],dims = ['vector_index' ])
56
+ new_u = input_xr [uv_varnames [0 ]].sel (x = x_search ,y = y_search ,method = 'nearest' )
57
+ new_v = input_xr [uv_varnames [1 ]].sel (x = x_search ,y = y_search ,method = 'nearest' )
58
+
59
+ ######### This is an order of magnitude slower
60
+ #new_u = input_xr[uv_varnames[0]].interp(x=x_search,y=y_search)
61
+ #new_v = input_xr[uv_varnames[1]].interp(x=x_search,y=y_search)
62
+
63
+ temp_xs = np .concatenate ([temp_xs ,temp_xs [- 1 :,:]+ new_u .values .T * ds ])
64
+ temp_ys = np .concatenate ([temp_ys ,temp_ys [- 1 :,:]+ new_v .values .T * ds ])
65
+
66
+ xs = temp_xs
67
+ ys = temp_ys
68
+ else :
69
+ xs = np .empty ([2 ,1 ])
70
+ ys = np .empty ([2 ,1 ])
71
+
72
+
73
+ #################### Here is the backward calculation
74
+ if forward0_both1_backward2 >= 1 :
75
+ temp_xs = np .expand_dims (seed_points [:,0 ],0 )
76
+ temp_ys = np .expand_dims (seed_points [:,1 ],0 )
77
+
78
+ for ind0 in tqdm .tqdm (np .arange (steps )):
79
+ x_search = xr .DataArray (temp_xs [- 1 ,:],dims = ['vector_index' ])
80
+ y_search = xr .DataArray (temp_ys [- 1 ,:],dims = ['vector_index' ])
81
+ new_u = input_xr [uv_varnames [0 ]].sel (x = x_search ,y = y_search ,method = 'nearest' )
82
+ new_v = input_xr [uv_varnames [1 ]].sel (x = x_search ,y = y_search ,method = 'nearest' )
83
+
84
+ ######### This is an order of magnitude slower
85
+ #new_u = input_xr[uv_varnames[0]].interp(x=x_search,y=y_search)
86
+ #new_v = input_xr[uv_varnames[1]].interp(x=x_search,y=y_search)
87
+
88
+ temp_xs = np .concatenate ([temp_xs ,temp_xs [- 1 :,:]- new_u .values .T * ds ])
89
+ temp_ys = np .concatenate ([temp_ys ,temp_ys [- 1 :,:]- new_v .values .T * ds ])
90
+
91
+ xs = np .concatenate ([np .flipud (temp_xs ),xs ])
92
+ ys = np .concatenate ([np .flipud (temp_ys ),ys ])
93
+
94
+
95
+
96
+ flowlines = []
97
+ for ind0 in np .arange (len (xs [0 ,:])):
98
+ xy = np .stack ([xs [:,ind0 ],ys [:,ind0 ]]).T
99
+ flowlines .append (xy )
100
+
101
+ return flowlines
102
+
103
+ ##########################################################################################
104
+ #### This version of the code doesn't work quite right...
105
+ ##########################################################################################
106
+ ##def calculate_flowlines(x,y,u,v,seed_points,max_error=0.00001,retry_count_threshold=10):
107
+ ## """
108
+ ## % (C) Nick Holschuh - Amherst College -- 2022 ([email protected] )
109
+ ## %
110
+ ## % This function prints out the minimum and maximum values of an array
111
+ ## %
112
+ ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
113
+ ## % The inputs are:
114
+ ## %
115
+ ## % input_array -- array of data to analyze
116
+ ## %
117
+ ## %%%%%%%%%%%%%%%
118
+ ## % The outputs are:
119
+ ## %
120
+ ## % output -- the min and max in a 1x2 array
121
+ ## %
122
+ ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
123
+ ## """
124
+ ##
125
+ ## ################# This uses a modified plt.streamline to pass through a user-editable keyword
126
+ ## ################# argument "max_error", which goes into the interpolater to guarantee
127
+ ## ################# accurate streamline calculation. Copy the updated version of streamplot
128
+ ## ################# into your matplotlib directory to enable the use of streamline
129
+ ## ################# calculation from NDH_Tools (in streamline.py, which calls _integrate_rk12)
130
+ ##
131
+ ## if isinstance(seed_points,list):
132
+ ## seed_points = np.array(seed_points)
133
+ ##
134
+ ## if len(seed_points.shape) == 1:
135
+ ## seed_points = np.expand_dims(seed_points,axis=0)
136
+ ##
137
+ ## ###################### Initialize the returned object
138
+ ## final_sls = []
139
+ ## for ind0 in np.arange(len(seed_points[:,0])):
140
+ ## final_sls.append([])
141
+ ##
142
+ ## retry_count = 0
143
+ ## retry_inds = np.arange(0,len(seed_points[:,0]))
144
+ ## seed_subset = seed_points
145
+ ##
146
+ ## while len(retry_inds) > 0:
147
+ ##
148
+ ## sls = []
149
+ ##
150
+ ## ################# Calculate the streamlines for all unfound seed points
151
+ ## fig = plt.figure()
152
+ ## if retry_count == 0:
153
+ ## print('The initial streamline calculation -- this can be slow. Finding '+str(len(seed_subset[:,0]))+' streamlines')
154
+ ## try:
155
+ ## streamlines = plt.streamplot(x,y,u,v,start_points=seed_points, max_error=max_error, density=100)
156
+ ## except:
157
+ ## streamlines = plt.streamplot(x,y,u,v,start_points=seed_points, density=100)
158
+ ## if retry_count == 0:
159
+ ## print('Note: You need to update your matplotlib streamline.py and reduce the max error for this to work properly')
160
+ ## plt.close(fig)
161
+ ##
162
+ ## ################# Here we extract the coordinate info from the streamlines
163
+ ## sl_deconstruct = []
164
+ ## for i in streamlines.lines.get_paths():
165
+ ## sl_deconstruct.append(i.vertices[1])
166
+ ## sl_deconstruct = np.array(sl_deconstruct)
167
+ ##
168
+ ## ################ Here we separate the streamlines based on large breaks in distance
169
+ ## sl_dist = ndh.distance_vector(sl_deconstruct[:,0],sl_deconstruct[:,1],1)
170
+ ## dist_mean = np.mean(sl_dist)
171
+ ## breaks = np.where(sl_dist > (dist_mean+1)*50)[0]
172
+ ## if len(breaks) > 0:
173
+ ## breaks = np.concatenate([np.array([-1]),breaks,np.array([len(sl_deconstruct[:,0])])])+1
174
+ ## else:
175
+ ## breaks = np.array([0,len(sl_deconstruct[:,0])+1])
176
+ ##
177
+ ## for ind0 in np.arange(len(breaks)-1):
178
+ ## sls.append(sl_deconstruct[breaks[ind0]:breaks[ind0+1],:])
179
+ ##
180
+ ## ################ Here we identify which streamline goes with which seed_point
181
+ ## matching = []
182
+ ## for ind0 in np.arange(len(seed_subset[:,0])):
183
+ ## dists = []
184
+ ## for ind1,sl in enumerate(sls):
185
+ ## comp_vals = ndh.find_nearest_xy(sl,seed_subset[ind0,:])
186
+ ## dists.append(comp_vals['distance'][0])
187
+ ## best = np.where(np.array(dists) < 1e-8)[0]
188
+ ## try:
189
+ ## matching.append(best[0])
190
+ ## except:
191
+ ## matching.append(-1)
192
+ ##
193
+ ## ################# populate the final object
194
+ ## for ind0,i in enumerate(matching):
195
+ ## if i != -1:
196
+ ## final_sls[retry_inds[ind0]] = sls[i]
197
+ ##
198
+ ## ################# Finally, we identify the new set of streamlines that need to be computed, based on which have no match
199
+ ## new_retry_inds = np.where(np.array(matching) == -1)[0]
200
+ ## seed_subset = seed_points[retry_inds[new_retry_inds],:]
201
+ ## retry_inds = retry_inds[new_retry_inds]
202
+ ##
203
+ ## if len(retry_inds) > 0:
204
+ ## retry_count = retry_count+1
205
+ ## print('Recalculating for nearly overlapping points -- try '+str(retry_count)+'. Finding '+str(len(seed_subset[:,0]))+' streamlines')
206
+ ##
207
+ ## if retry_count > retry_count_threshold:
208
+ ## break
209
+ ##
210
+ ## if 0:
211
+ ## plt.figure()
212
+ ## plt.plot(test_dist)
213
+ ## plt.axhline(dist_median,c='orange')
214
+ ##
215
+ ## if 0:
216
+ ## plt.figure()
217
+ ## plt.plot(test[:,0],test[:,1],c='blue')
218
+ ## for i in final_sls:
219
+ ## plt.plot(i[:,0],i[:,1],c='red')
220
+ ## plt.plot(seed_points[:,0],seed_points[:,1],'o')
221
+ ##
222
+ ##
223
+ ## return final_sls
0 commit comments