1
1
#:include "common.fypp"
2
+ #:set RANKS = range(2, MAXRANK + 1)
2
3
submodule(stdlib_specialfunctions) stdlib_specialfunctions_activations
3
4
implicit none
4
5
@@ -192,73 +193,44 @@ end function
192
193
! Softmax
193
194
!==================================================
194
195
#:for rk, rt in REAL_KINDS_TYPES
195
- pure module function Softmax_r1_${rk}$( x ) result( y )
196
+ pure module function Softmax_r1_${rk}$( x , dim ) result( y )
196
197
${rt}$, intent(in) :: x(:)
197
198
${rt}$ :: y(size(x))
199
+ integer, intent(in), optional :: dim
198
200
199
201
y = exp(x - maxval(x))
200
202
y = y / sum(y)
201
203
end function
202
204
203
- pure module function Softmax_r2_${rk}$( x , dim ) result( y )
204
- ${rt}$, intent(in) :: x(:,:)
205
- ${rt}$ :: y(size(x,dim=1),size(x,dim=2))
205
+ #:for rank in RANKS
206
+ pure module function Softmax_r${rank}$_${rk}$( x , dim ) result( y )
207
+ ${rt}$, intent(in) :: x${ranksuffix(rank)}$
208
+ ${rt}$ :: y${shape('x', rank)}$
206
209
207
210
integer, intent(in), optional :: dim
208
211
integer :: dim_, j
209
212
210
213
dim_ = 1; if(present(dim)) dim_ = dim
211
214
212
- if(dim_==1)then
213
- do j = 1, size(x,dim=2)
214
- y(:,j) = Softmax( x(:,j) )
215
+ if(dim_<${rank}$)then
216
+ do j = 1, size(x,dim=${rank}$)
217
+ #:if rank == 2
218
+ y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$ )
219
+ #:else
220
+ y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
221
+ #:endif
215
222
end do
216
223
else
217
224
do j = 1, size(x,dim=1)
218
- y(j,:) = Softmax( x(j,:) )
219
- end do
220
- end if
221
- end function
222
-
223
- pure module function Softmax_r3_${rk}$( x , dim ) result( y )
224
- ${rt}$, intent(in) :: x(:,:,:)
225
- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
226
-
227
- integer, intent(in), optional :: dim
228
- integer :: dim_, j
229
-
230
- dim_ = 1; if(present(dim)) dim_ = dim
231
-
232
- if(dim_<=2)then
233
- do j = 1, size(x,dim=3)
234
- y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ )
235
- end do
236
- else
237
- do j = 1, size(x,dim=1)
238
- y(j,:,:) = Softmax( x(j,:,:) , dim = 2 )
239
- end do
240
- end if
241
- end function
242
-
243
- pure module function Softmax_r4_${rk}$( x , dim ) result( y )
244
- ${rt}$, intent(in) :: x(:,:,:,:)
245
- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
246
-
247
- integer, intent(in), optional :: dim
248
- integer :: dim_, j
249
-
250
- dim_ = 1; if(present(dim)) dim_ = dim
251
-
252
- if(dim_<=3)then
253
- do j = 1, size(x,dim=4)
254
- y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ )
255
- end do
256
- else
257
- do j = 1, size(x,dim=1)
258
- y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 )
225
+ #:if rank == 2
226
+ y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$ )
227
+ #:else
228
+ y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
229
+ #:endif
259
230
end do
260
231
end if
261
232
end function
233
+ #:endfor
262
234
263
235
pure module function Softmax_grad_r1_${rk}$( x ) result( y )
264
236
${rt}$, intent(in) :: x(:)
@@ -268,9 +240,10 @@ pure module function Softmax_grad_r1_${rk}$( x ) result( y )
268
240
y = y * (1._${rk}$ - y)
269
241
end function
270
242
271
- pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
272
- ${rt}$, intent(in) :: x(:,:)
273
- ${rt}$ :: y(size(x,dim=1),size(x,dim=2))
243
+ #:for rank in RANKS
244
+ pure module function Softmax_grad_r${rank}$_${rk}$( x , dim ) result( y )
245
+ ${rt}$, intent(in) :: x${ranksuffix(rank)}$
246
+ ${rt}$ :: y${shape('x', rank)}$
274
247
275
248
integer, intent(in), optional :: dim
276
249
integer :: dim_
@@ -280,32 +253,51 @@ pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
280
253
y = Softmax(x,dim_)
281
254
y = y * (1._${rk}$ - y)
282
255
end function
256
+ #:endfor
283
257
284
- pure module function Softmax_grad_r3_${rk}$( x , dim ) result( y )
285
- ${rt}$, intent(in) :: x(:,:,:)
286
- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
287
-
288
- integer, intent(in), optional :: dim
289
- integer :: dim_
258
+ #:endfor
290
259
291
- dim_ = 1; if(present(dim)) dim_ = dim
292
-
293
- y = Softmax(x,dim_)
294
- y = y * (1._${rk}$ - y)
260
+ !==================================================
261
+ ! LogSoftmax
262
+ !==================================================
263
+ #:for rk, rt in REAL_KINDS_TYPES
264
+ pure module function LogSoftmax_r1_${rk}$( x, dim ) result( y )
265
+ ${rt}$, intent(in) :: x(:)
266
+ ${rt}$ :: y(size(x))
267
+ integer, intent(in), optional :: dim
268
+ y = x - maxval(x)
269
+ y = y - log( sum(exp(y)) )
295
270
end function
296
271
297
- pure module function Softmax_grad_r4_${rk}$( x , dim ) result( y )
298
- ${rt}$, intent(in) :: x(:,:,:,:)
299
- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
300
-
272
+ #:for rank in RANKS
273
+ pure module function LogSoftmax_r${rank}$_${rk}$( x , dim ) result( y )
274
+ ${rt}$, intent(in) :: x${ranksuffix(rank)}$
275
+ ${rt}$ :: y${shape('x', rank)}$
276
+
301
277
integer, intent(in), optional :: dim
302
- integer :: dim_
278
+ integer :: dim_, j
303
279
304
280
dim_ = 1; if(present(dim)) dim_ = dim
305
-
306
- y = Softmax(x,dim_)
307
- y = y * (1._${rk}$ - y)
281
+
282
+ if(dim_<${rank}$)then
283
+ do j = 1, size(x,dim=${rank}$)
284
+ #:if rank == 2
285
+ y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$ )
286
+ #:else
287
+ y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
288
+ #:endif
289
+ end do
290
+ else
291
+ do j = 1, size(x,dim=1)
292
+ #:if rank == 2
293
+ y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$ )
294
+ #:else
295
+ y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
296
+ #:endif
297
+ end do
298
+ end if
308
299
end function
300
+ #:endfor
309
301
310
302
#:endfor
311
303
0 commit comments