11#:include "common.fypp"
2+ #:set RANKS = range(2, MAXRANK + 1)
23submodule(stdlib_specialfunctions) stdlib_specialfunctions_activations
34 implicit none
45
@@ -192,73 +193,44 @@ end function
192193! Softmax
193194!==================================================
194195#:for rk, rt in REAL_KINDS_TYPES
195- pure module function Softmax_r1_${rk}$( x ) result( y )
196+ pure module function Softmax_r1_${rk}$( x , dim ) result( y )
196197 ${rt}$, intent(in) :: x(:)
197198 ${rt}$ :: y(size(x))
199+ integer, intent(in), optional :: dim
198200
199201 y = exp(x - maxval(x))
200202 y = y / sum(y)
201203end function
202204
203- pure module function Softmax_r2_${rk}$( x , dim ) result( y )
204- ${rt}$, intent(in) :: x(:,:)
205- ${rt}$ :: y(size(x,dim=1),size(x,dim=2))
205+ #:for rank in RANKS
206+ pure module function Softmax_r${rank}$_${rk}$( x , dim ) result( y )
207+ ${rt}$, intent(in) :: x${ranksuffix(rank)}$
208+ ${rt}$ :: y${shape('x', rank)}$
206209
207210 integer, intent(in), optional :: dim
208211 integer :: dim_, j
209212
210213 dim_ = 1; if(present(dim)) dim_ = dim
211214
212- if(dim_==1)then
213- do j = 1, size(x,dim=2)
214- y(:,j) = Softmax( x(:,j) )
215+ if(dim_<${rank}$)then
216+ do j = 1, size(x,dim=${rank}$)
217+ #:if rank == 2
218+ y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$ )
219+ #:else
220+ y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
221+ #:endif
215222 end do
216223 else
217224 do j = 1, size(x,dim=1)
218- y(j,:) = Softmax( x(j,:) )
219- end do
220- end if
221- end function
222-
223- pure module function Softmax_r3_${rk}$( x , dim ) result( y )
224- ${rt}$, intent(in) :: x(:,:,:)
225- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
226-
227- integer, intent(in), optional :: dim
228- integer :: dim_, j
229-
230- dim_ = 1; if(present(dim)) dim_ = dim
231-
232- if(dim_<=2)then
233- do j = 1, size(x,dim=3)
234- y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ )
235- end do
236- else
237- do j = 1, size(x,dim=1)
238- y(j,:,:) = Softmax( x(j,:,:) , dim = 2 )
239- end do
240- end if
241- end function
242-
243- pure module function Softmax_r4_${rk}$( x , dim ) result( y )
244- ${rt}$, intent(in) :: x(:,:,:,:)
245- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
246-
247- integer, intent(in), optional :: dim
248- integer :: dim_, j
249-
250- dim_ = 1; if(present(dim)) dim_ = dim
251-
252- if(dim_<=3)then
253- do j = 1, size(x,dim=4)
254- y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ )
255- end do
256- else
257- do j = 1, size(x,dim=1)
258- y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 )
225+ #:if rank == 2
226+ y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$ )
227+ #:else
228+ y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
229+ #:endif
259230 end do
260231 end if
261232end function
233+ #:endfor
262234
263235pure module function Softmax_grad_r1_${rk}$( x ) result( y )
264236 ${rt}$, intent(in) :: x(:)
@@ -268,9 +240,10 @@ pure module function Softmax_grad_r1_${rk}$( x ) result( y )
268240 y = y * (1._${rk}$ - y)
269241end function
270242
271- pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
272- ${rt}$, intent(in) :: x(:,:)
273- ${rt}$ :: y(size(x,dim=1),size(x,dim=2))
243+ #:for rank in RANKS
244+ pure module function Softmax_grad_r${rank}$_${rk}$( x , dim ) result( y )
245+ ${rt}$, intent(in) :: x${ranksuffix(rank)}$
246+ ${rt}$ :: y${shape('x', rank)}$
274247
275248 integer, intent(in), optional :: dim
276249 integer :: dim_
@@ -280,32 +253,51 @@ pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
280253 y = Softmax(x,dim_)
281254 y = y * (1._${rk}$ - y)
282255end function
256+ #:endfor
283257
284- pure module function Softmax_grad_r3_${rk}$( x , dim ) result( y )
285- ${rt}$, intent(in) :: x(:,:,:)
286- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
287-
288- integer, intent(in), optional :: dim
289- integer :: dim_
258+ #:endfor
290259
291- dim_ = 1; if(present(dim)) dim_ = dim
292-
293- y = Softmax(x,dim_)
294- y = y * (1._${rk}$ - y)
260+ !==================================================
261+ ! LogSoftmax
262+ !==================================================
263+ #:for rk, rt in REAL_KINDS_TYPES
264+ pure module function LogSoftmax_r1_${rk}$( x, dim ) result( y )
265+ ${rt}$, intent(in) :: x(:)
266+ ${rt}$ :: y(size(x))
267+ integer, intent(in), optional :: dim
268+ y = x - maxval(x)
269+ y = y - log( sum(exp(y)) )
295270end function
296271
297- pure module function Softmax_grad_r4_${rk}$( x , dim ) result( y )
298- ${rt}$, intent(in) :: x(:,:,:,:)
299- ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
300-
272+ #:for rank in RANKS
273+ pure module function LogSoftmax_r${rank}$_${rk}$( x , dim ) result( y )
274+ ${rt}$, intent(in) :: x${ranksuffix(rank)}$
275+ ${rt}$ :: y${shape('x', rank)}$
276+
301277 integer, intent(in), optional :: dim
302- integer :: dim_
278+ integer :: dim_, j
303279
304280 dim_ = 1; if(present(dim)) dim_ = dim
305-
306- y = Softmax(x,dim_)
307- y = y * (1._${rk}$ - y)
281+
282+ if(dim_<${rank}$)then
283+ do j = 1, size(x,dim=${rank}$)
284+ #:if rank == 2
285+ y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$ )
286+ #:else
287+ y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
288+ #:endif
289+ end do
290+ else
291+ do j = 1, size(x,dim=1)
292+ #:if rank == 2
293+ y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$ )
294+ #:else
295+ y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
296+ #:endif
297+ end do
298+ end if
308299end function
300+ #:endfor
309301
310302#:endfor
311303
0 commit comments