Skip to content

Commit 5c47bf0

Browse files
committed
add any rank support for softmax and logsoftmax
1 parent bc2bf5a commit 5c47bf0

File tree

4 files changed

+185
-92
lines changed

4 files changed

+185
-92
lines changed

include/common.fypp

+23
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,29 @@ ${prefix + joinstr.join([line.strip() for line in txt.split("\n")]) + suffix}$
194194
#:endif
195195
#:enddef
196196

197+
#! Brace enclosed, comma separated Fortran expressions for a shape.
198+
#!
199+
#! It defines an output variable with the same shape as the input variable.
200+
#!
201+
#! Args:
202+
#! varname (str): Name of the variable to be used as origin
203+
#! origrank (int): Rank of the original variable
204+
#!
205+
#! Returns:
206+
#! Shape expression enclosed in braces, so that it can be used as suffix to
207+
#! define array shapes in declarations.
208+
#!
209+
#:def shape(varname, origrank)
210+
#:assert origrank > 0
211+
#:if origrank > 1
212+
#:call join_lines(joinstr=", ", prefix="(", suffix=")")
213+
#:for i in range(1, origrank+1)
214+
size(${varname}$, ${i}$)
215+
#:endfor
216+
#:endcall
217+
#:endif
218+
#:enddef
219+
197220

198221
#! Generates a routine name from a generic name, rank, type and kind
199222
#!

src/stdlib_specialfunctions.fypp

+30-23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#:include "common.fypp"
2+
#:set RANKS = range(2, MAXRANK + 1)
23
module stdlib_specialfunctions
34
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp
45

@@ -271,26 +272,19 @@ module stdlib_specialfunctions
271272
!!
272273
!! Softmax function. Available for ranks 1 to 4
273274
#:for rk, rt in REAL_KINDS_TYPES
274-
pure module function Softmax_r1_${rk}$( x ) result( y )
275+
pure module function Softmax_r1_${rk}$( x , dim ) result( y )
275276
${rt}$, intent(in) :: x(:)
276277
${rt}$ :: y(size(x))
277-
end function
278-
pure module function Softmax_r2_${rk}$( x , dim ) result( y )
279-
${rt}$, intent(in) :: x(:,:)
280-
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
281-
integer, intent(in), optional :: dim
282-
end function
283-
pure module function Softmax_r3_${rk}$( x , dim ) result( y )
284-
${rt}$, intent(in) :: x(:,:,:)
285-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
286278
integer, intent(in), optional :: dim
287279
end function
288-
pure module function Softmax_r4_${rk}$( x , dim ) result( y )
289-
${rt}$, intent(in) :: x(:,:,:,:)
290-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
280+
#:for rank in RANKS
281+
pure module function Softmax_r${rank}$_${rk}$( x , dim ) result( y )
282+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
283+
${rt}$ :: y${shape('x', rank)}$
291284
integer, intent(in), optional :: dim
292285
end function
293286
#:endfor
287+
#:endfor
294288
end interface
295289
public :: softmax
296290

@@ -303,24 +297,37 @@ module stdlib_specialfunctions
303297
${rt}$, intent(in) :: x(:)
304298
${rt}$ :: y(size(x))
305299
end function
306-
pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
307-
${rt}$, intent(in) :: x(:,:)
308-
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
300+
#:for rank in RANKS
301+
pure module function Softmax_grad_r${rank}$_${rk}$( x , dim ) result( y )
302+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
303+
${rt}$ :: y${shape('x', rank)}$
309304
integer, intent(in), optional :: dim
310305
end function
311-
pure module function Softmax_grad_r3_${rk}$( x , dim ) result( y )
312-
${rt}$, intent(in) :: x(:,:,:)
313-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
306+
#:endfor
307+
#:endfor
308+
end interface
309+
public :: Softmax_grad
310+
311+
interface LogSoftmax
312+
!! Version: experimental
313+
!!
314+
!! Softmax function. Available for ranks 1 to 4
315+
#:for rk, rt in REAL_KINDS_TYPES
316+
pure module function LogSoftmax_r1_${rk}$( x, dim ) result( y )
317+
${rt}$, intent(in) :: x(:)
318+
${rt}$ :: y(size(x))
314319
integer, intent(in), optional :: dim
315320
end function
316-
pure module function Softmax_grad_r4_${rk}$( x , dim ) result( y )
317-
${rt}$, intent(in) :: x(:,:,:,:)
318-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
321+
#:for rank in RANKS
322+
pure module function LogSoftmax_r${rank}$_${rk}$( x , dim ) result( y )
323+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
324+
${rt}$ :: y${shape('x', rank)}$
319325
integer, intent(in), optional :: dim
320326
end function
321327
#:endfor
328+
#:endfor
322329
end interface
323-
public :: Softmax_grad
330+
public :: LogSoftmax
324331

325332
interface Softplus
326333
!! Version: experimental

src/stdlib_specialfunctions_activations.fypp

+61-69
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#:include "common.fypp"
2+
#:set RANKS = range(2, MAXRANK + 1)
23
submodule(stdlib_specialfunctions) stdlib_specialfunctions_activations
34
implicit none
45

@@ -192,73 +193,44 @@ end function
192193
! Softmax
193194
!==================================================
194195
#:for rk, rt in REAL_KINDS_TYPES
195-
pure module function Softmax_r1_${rk}$( x ) result( y )
196+
pure module function Softmax_r1_${rk}$( x , dim ) result( y )
196197
${rt}$, intent(in) :: x(:)
197198
${rt}$ :: y(size(x))
199+
integer, intent(in), optional :: dim
198200

199201
y = exp(x - maxval(x))
200202
y = y / sum(y)
201203
end function
202204

203-
pure module function Softmax_r2_${rk}$( x , dim ) result( y )
204-
${rt}$, intent(in) :: x(:,:)
205-
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
205+
#:for rank in RANKS
206+
pure module function Softmax_r${rank}$_${rk}$( x , dim ) result( y )
207+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
208+
${rt}$ :: y${shape('x', rank)}$
206209

207210
integer, intent(in), optional :: dim
208211
integer :: dim_, j
209212

210213
dim_ = 1; if(present(dim)) dim_ = dim
211214

212-
if(dim_==1)then
213-
do j = 1, size(x,dim=2)
214-
y(:,j) = Softmax( x(:,j) )
215+
if(dim_<${rank}$)then
216+
do j = 1, size(x,dim=${rank}$)
217+
#:if rank == 2
218+
y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$ )
219+
#:else
220+
y${select_subarray(rank, [(rank, 'j')])}$ = Softmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
221+
#:endif
215222
end do
216223
else
217224
do j = 1, size(x,dim=1)
218-
y(j,:) = Softmax( x(j,:) )
219-
end do
220-
end if
221-
end function
222-
223-
pure module function Softmax_r3_${rk}$( x , dim ) result( y )
224-
${rt}$, intent(in) :: x(:,:,:)
225-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
226-
227-
integer, intent(in), optional :: dim
228-
integer :: dim_, j
229-
230-
dim_ = 1; if(present(dim)) dim_ = dim
231-
232-
if(dim_<=2)then
233-
do j = 1, size(x,dim=3)
234-
y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ )
235-
end do
236-
else
237-
do j = 1, size(x,dim=1)
238-
y(j,:,:) = Softmax( x(j,:,:) , dim = 2 )
239-
end do
240-
end if
241-
end function
242-
243-
pure module function Softmax_r4_${rk}$( x , dim ) result( y )
244-
${rt}$, intent(in) :: x(:,:,:,:)
245-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
246-
247-
integer, intent(in), optional :: dim
248-
integer :: dim_, j
249-
250-
dim_ = 1; if(present(dim)) dim_ = dim
251-
252-
if(dim_<=3)then
253-
do j = 1, size(x,dim=4)
254-
y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ )
255-
end do
256-
else
257-
do j = 1, size(x,dim=1)
258-
y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 )
225+
#:if rank == 2
226+
y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$ )
227+
#:else
228+
y${select_subarray(rank, [(1, 'j')])}$ = Softmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
229+
#:endif
259230
end do
260231
end if
261232
end function
233+
#:endfor
262234

263235
pure module function Softmax_grad_r1_${rk}$( x ) result( y )
264236
${rt}$, intent(in) :: x(:)
@@ -268,9 +240,10 @@ pure module function Softmax_grad_r1_${rk}$( x ) result( y )
268240
y = y * (1._${rk}$ - y)
269241
end function
270242

271-
pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
272-
${rt}$, intent(in) :: x(:,:)
273-
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
243+
#:for rank in RANKS
244+
pure module function Softmax_grad_r${rank}$_${rk}$( x , dim ) result( y )
245+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
246+
${rt}$ :: y${shape('x', rank)}$
274247

275248
integer, intent(in), optional :: dim
276249
integer :: dim_
@@ -280,32 +253,51 @@ pure module function Softmax_grad_r2_${rk}$( x , dim ) result( y )
280253
y = Softmax(x,dim_)
281254
y = y * (1._${rk}$ - y)
282255
end function
256+
#:endfor
283257

284-
pure module function Softmax_grad_r3_${rk}$( x , dim ) result( y )
285-
${rt}$, intent(in) :: x(:,:,:)
286-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
287-
288-
integer, intent(in), optional :: dim
289-
integer :: dim_
258+
#:endfor
290259

291-
dim_ = 1; if(present(dim)) dim_ = dim
292-
293-
y = Softmax(x,dim_)
294-
y = y * (1._${rk}$ - y)
260+
!==================================================
261+
! LogSoftmax
262+
!==================================================
263+
#:for rk, rt in REAL_KINDS_TYPES
264+
pure module function LogSoftmax_r1_${rk}$( x, dim ) result( y )
265+
${rt}$, intent(in) :: x(:)
266+
${rt}$ :: y(size(x))
267+
integer, intent(in), optional :: dim
268+
y = x - maxval(x)
269+
y = y - log( sum(exp(y)) )
295270
end function
296271

297-
pure module function Softmax_grad_r4_${rk}$( x , dim ) result( y )
298-
${rt}$, intent(in) :: x(:,:,:,:)
299-
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
300-
272+
#:for rank in RANKS
273+
pure module function LogSoftmax_r${rank}$_${rk}$( x , dim ) result( y )
274+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
275+
${rt}$ :: y${shape('x', rank)}$
276+
301277
integer, intent(in), optional :: dim
302-
integer :: dim_
278+
integer :: dim_, j
303279

304280
dim_ = 1; if(present(dim)) dim_ = dim
305-
306-
y = Softmax(x,dim_)
307-
y = y * (1._${rk}$ - y)
281+
282+
if(dim_<${rank}$)then
283+
do j = 1, size(x,dim=${rank}$)
284+
#:if rank == 2
285+
y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$ )
286+
#:else
287+
y${select_subarray(rank, [(rank, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim_ )
288+
#:endif
289+
end do
290+
else
291+
do j = 1, size(x,dim=1)
292+
#:if rank == 2
293+
y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$ )
294+
#:else
295+
y${select_subarray(rank, [(1, 'j')])}$ = LogSoftmax( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
296+
#:endif
297+
end do
298+
end if
308299
end function
300+
#:endfor
309301

310302
#:endfor
311303

test/specialfunctions/test_specialfunctions_activations.fypp

+71
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ contains
2222

2323
testsuite = [ &
2424
new_unittest("sigmoid", test_sigmoid), &
25+
new_unittest("logsoftmax", test_logsoftmax), &
2526
new_unittest("gelu" , test_gelu ), &
2627
new_unittest("softmax", test_softmax) &
2728
]
@@ -134,6 +135,76 @@ contains
134135

135136
end subroutine test_softmax
136137

138+
subroutine test_logsoftmax(error)
139+
type(error_type), allocatable, intent(out) :: error
140+
141+
real(sp) :: x(3,3,3), y(3,3,3), y_ref(3,3,3)
142+
143+
x = reshape( [ 0.755308866500854,-0.789980888366699, 0.88806813955307 ,&
144+
-1.210636496543884, 0.746919095516205, 0.177668794989586,&
145+
0.540819883346558, 0.291532933712006,-0.324642956256866,&
146+
147+
1.94184136390686 , 0.951070547103882,-2.303410291671753,&
148+
0.59752631187439 , 1.189722180366516, 1.401878595352173,&
149+
-0.262732744216919, 0.421907186508179,-0.200457707047462,&
150+
151+
-0.702468574047089, 0.153426378965378, 0.330110251903534,&
152+
-1.16956090927124 ,-0.845042765140533,-1.364316940307617,&
153+
-1.679381608963013,-1.497506022453308,-1.194215059280396 ] ,[3,3,3] )
154+
155+
!> LogSoftmax on dim = 1
156+
y = LogSoftmax(x,dim=1)
157+
158+
y_ref = reshape( [ -0.856636286,-2.40192604,-0.723877013,&
159+
-2.49238253,-0.534826934,-1.10407722 ,&
160+
-0.788554132,-1.03784108,-1.65401697 ,&
161+
162+
-0.326149583,-1.31692040,-4.57140112 ,&
163+
-1.61804128,-1.02584541,-0.813688993 ,&
164+
-1.39805317,-0.713413179,-1.33577800 ,&
165+
166+
-1.81836534,-0.962470412,-0.785786569,&
167+
-1.16514850,-0.840630412,-1.35990453 ,&
168+
-1.34127355,-1.15939808,-0.856107056 ],[3,3,3] )
169+
170+
!> LogSoftmax on dim = 2
171+
y = LogSoftmax(x,dim=2)
172+
173+
y_ref = reshape( [ -0.666278005,-2.15167999, -0.581566215,&
174+
-2.63222337 ,-0.614779949,-1.29196548 ,&
175+
-0.880766988,-1.07016611,-1.79427731 ,&
176+
177+
-0.315551817,-1.05034387,-3.90906072 ,&
178+
-1.65986681 ,-0.811692238,-0.203771874,&
179+
-2.52012587 ,-1.57950723 ,-1.80610812 ,&
180+
181+
-0.694792688,-0.444887042,-0.337523341,&
182+
-1.16188502 ,-1.44335616 ,-2.03195047 ,&
183+
-1.67170572 ,-2.09581947 ,-1.86184871 ],[3,3,3] )
184+
185+
call check(error, norm2(y-y_ref) < tol_sp )
186+
if (allocated(error)) return
187+
188+
!> LogSoftmax on dim = 3
189+
y = LogSoftmax(x,dim=3)
190+
191+
y_ref = reshape( [ -1.50595474 , -2.22700500 ,-0.478398114,&
192+
-2.09693313 , -1.01544499 ,-1.52940571 ,&
193+
-0.442325860, -0.835677147,-0.936625183,&
194+
195+
-0.319422185, -0.485953659,-3.66987658 ,&
196+
-0.288770229, -0.572641909,-0.305195898,&
197+
-1.24587846 , -0.705302894,-0.812439919,&
198+
199+
-2.96373224 , -1.28359783 ,-1.03635597 ,&
200+
-2.05585742 , -2.60740685 ,-3.07139134 ,&
201+
-2.66252732 , -2.62471604 ,-1.80619729 ],[3,3,3] )
202+
203+
call check(error, norm2(y-y_ref) < tol_sp )
204+
if (allocated(error)) return
205+
206+
end subroutine test_logsoftmax
207+
137208

138209
end module test_specialfunctions_activation
139210

0 commit comments

Comments
 (0)