Skip to content

Commit 1e623ad

Browse files
authored
Add Fix{N} for fixing a single positional argument at any position (#829)
1 parent f8af0d1 commit 1e623ad

File tree

4 files changed

+203
-1
lines changed

4 files changed

+203
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Compat"
22
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
3-
version = "4.15.0"
3+
version = "4.16.0"
44

55
[deps]
66
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ changes in `julia`.
7272

7373
## Supported features
7474

75+
* `Compat.Fix{N}` which fixes an argument at the `N`th position ([#54653]) (since Compat 4.16.0)
76+
7577
* `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0)
7678

7779
* `logrange(lo, hi; length)` is like `range` but with a constant ratio, not difference. ([#39071]) (since Compat 4.14.0) Note that on Julia 1.8 and earlier, the version from Compat has slightly lower floating-point accuracy than the one in Base (Julia 1.11 and later).
@@ -192,3 +194,4 @@ Note that you should specify the correct minimum version for `Compat` in the
192194
[#47679]: https://github.com/JuliaLang/julia/pull/47679
193195
[#48038]: https://github.com/JuliaLang/julia/issues/48038
194196
[#50105]: https://github.com/JuliaLang/julia/issues/50105
197+
[#54653]: https://github.com/JuliaLang/julia/issues/54653

src/Compat.jl

+68
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,74 @@ if VERSION < v"1.8.0-DEV.1016"
11221122
export chopprefix, chopsuffix
11231123
end
11241124

1125+
# https://github.com/JuliaLang/julia/pull/54653: add Fix
1126+
@static if !isdefined(Base, :Fix) # VERSION < v"1.12.0-DEV.981"
1127+
@static if !isdefined(Base, :_stable_typeof)
1128+
_stable_typeof(x) = typeof(x)
1129+
_stable_typeof(::Type{T}) where {T} = Type{T}
1130+
else
1131+
using Base: _stable_typeof
1132+
end
1133+
1134+
@doc """
1135+
Fix{N}(f, x)
1136+
1137+
A type representing a partially-applied version of a function `f`, with the argument
1138+
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
1139+
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
1140+
1141+
!!! note
1142+
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
1143+
available arguments, rather than an absolute ordering on the target function. For example,
1144+
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
1145+
fixes the first and third arg.
1146+
1147+
!!! note
1148+
Note that `Compat.Fix{1}`/`Fix{2}` are not the same as `Base.Fix1`/`Fix2` on Julia
1149+
versions earlier than `1.12.0-DEV.981`. Therefore, if you wish to use this as a way
1150+
to _dispatch_ on `Fix{N}`, you may wish to declare a method for both
1151+
`Compat.Fix{1}`/`Fix{2}` as well as `Base.Fix1`/`Fix2`, conditional on
1152+
a `@static if !isdefined(Base, :Fix); ...; end`.
1153+
""" Fix
1154+
1155+
struct Fix{N,F,T} <: Function
1156+
f::F
1157+
x::T
1158+
1159+
function Fix{N}(f::F, x) where {N,F}
1160+
if !(N isa Int)
1161+
throw(ArgumentError("expected type parameter in `Fix` to be `Int`, but got `$N::$(typeof(N))`"))
1162+
elseif N < 1
1163+
throw(ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got $N"))
1164+
end
1165+
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
1166+
end
1167+
end
1168+
1169+
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
1170+
M < N-1 && throw(ArgumentError("expected at least $(N-1) arguments to `Fix{$N}`, but got $M"))
1171+
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
1172+
end
1173+
1174+
# Special cases for improved constant propagation
1175+
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
1176+
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
1177+
1178+
@doc """
1179+
Alias for `Fix{1}`. See [`Fix`](@ref Compat.Fix).
1180+
""" Fix1
1181+
1182+
const Fix1{F,T} = Fix{1,F,T}
1183+
1184+
@doc """
1185+
Alias for `Fix{2}`. See [`Fix`](@ref Compat.Fix).
1186+
""" Fix2
1187+
1188+
const Fix2{F,T} = Fix{2,F,T}
1189+
else
1190+
using Base: Fix, Fix1, Fix2
1191+
end
1192+
11251193
include("deprecated.jl")
11261194

11271195
end # module Compat

test/runtests.jl

+131
Original file line numberDiff line numberDiff line change
@@ -907,3 +907,134 @@ end
907907
@test isa(chopsuffix(S("foo"), "oo"), SubString)
908908
end
909909
end
910+
911+
# https://github.com/JuliaLang/julia/pull/54653: add Fix
912+
@testset "Fix" begin
913+
function test_fix1(Fix1=Compat.Fix1)
914+
increment = Fix1(+, 1)
915+
@test increment(5) == 6
916+
@test increment(-1) == 0
917+
@test increment(0) == 1
918+
@test map(increment, [1, 2, 3]) == [2, 3, 4]
919+
920+
concat_with_hello = Fix1(*, "Hello ")
921+
@test concat_with_hello("World!") == "Hello World!"
922+
# Make sure inference is good:
923+
@inferred concat_with_hello("World!")
924+
925+
one_divided_by = Fix1(/, 1)
926+
@test one_divided_by(10) == 1/10.0
927+
@test one_divided_by(-5) == 1/-5.0
928+
929+
return nothing
930+
end
931+
932+
function test_fix2(Fix2=Compat.Fix2)
933+
return_second = Fix2((x, y) -> y, 999)
934+
@test return_second(10) == 999
935+
@inferred return_second(10)
936+
@test return_second(-5) == 999
937+
938+
divide_by_two = Fix2(/, 2)
939+
@test map(divide_by_two, (2, 4, 6)) == (1.0, 2.0, 3.0)
940+
@inferred map(divide_by_two, (2, 4, 6))
941+
942+
concat_with_world = Fix2(*, " World!")
943+
@test concat_with_world("Hello") == "Hello World!"
944+
@inferred concat_with_world("Hello World!")
945+
946+
return nothing
947+
end
948+
949+
# Test with normal Base.Fix1 and Base.Fix2
950+
test_fix1()
951+
test_fix2()
952+
953+
# Now, repeat the Fix1 and Fix2 tests, but
954+
# with a Fix lambda function used in their place
955+
test_fix1((op, arg) -> Compat.Fix{1}(op, arg))
956+
test_fix2((op, arg) -> Compat.Fix{2}(op, arg))
957+
958+
# Now, we do more complex tests of Fix:
959+
let Fix=Compat.Fix
960+
@testset "Argument Fixation" begin
961+
let f = (x, y, z) -> x + y * z
962+
fixed_f1 = Fix{1}(f, 10)
963+
@test fixed_f1(2, 3) == 10 + 2 * 3
964+
965+
fixed_f2 = Fix{2}(f, 5)
966+
@test fixed_f2(1, 4) == 1 + 5 * 4
967+
968+
fixed_f3 = Fix{3}(f, 3)
969+
@test fixed_f3(1, 2) == 1 + 2 * 3
970+
end
971+
end
972+
@testset "Helpful errors" begin
973+
let g = (x, y) -> x - y
974+
# Test minimum N
975+
fixed_g1 = Fix{1}(g, 100)
976+
@test fixed_g1(40) == 100 - 40
977+
978+
# Test maximum N
979+
fixed_g2 = Fix{2}(g, 100)
980+
@test fixed_g2(150) == 150 - 100
981+
982+
# One over
983+
fixed_g3 = Fix{3}(g, 100)
984+
@test_throws ArgumentError("expected at least 2 arguments to `Fix{3}`, but got 1") fixed_g3(1)
985+
end
986+
end
987+
@testset "Type Stability and Inference" begin
988+
let h = (x, y) -> x / y
989+
fixed_h = Fix{2}(h, 2.0)
990+
@test @inferred(fixed_h(4.0)) == 2.0
991+
end
992+
end
993+
@testset "Interaction with varargs" begin
994+
vararg_f = (x, y, z...) -> x + 10 * y + sum(z; init=zero(x))
995+
fixed_vararg_f = Fix{2}(vararg_f, 6)
996+
997+
# Can call with variable number of arguments:
998+
@test fixed_vararg_f(1, 2, 3, 4) == 1 + 10 * 6 + sum((2, 3, 4))
999+
if VERSION >= v"1.7.0"
1000+
@inferred fixed_vararg_f(1, 2, 3, 4)
1001+
end
1002+
@test fixed_vararg_f(5) == 5 + 10 * 6
1003+
if VERSION >= v"1.7.0"
1004+
@inferred fixed_vararg_f(5)
1005+
end
1006+
end
1007+
@testset "Errors should propagate normally" begin
1008+
error_f = (x, y) -> sin(x * y)
1009+
fixed_error_f = Fix{2}(error_f, Inf)
1010+
@test_throws DomainError fixed_error_f(10)
1011+
end
1012+
@testset "Chaining Fix together" begin
1013+
f1 = Fix{1}(*, "1")
1014+
f2 = Fix{1}(f1, "2")
1015+
f3 = Fix{1}(f2, "3")
1016+
@test f3() == "123"
1017+
1018+
g1 = Fix{2}(*, "1")
1019+
g2 = Fix{2}(g1, "2")
1020+
g3 = Fix{2}(g2, "3")
1021+
@test g3("") == "123"
1022+
end
1023+
@testset "Zero arguments" begin
1024+
f = Fix{1}(x -> x, 'a')
1025+
@test f() == 'a'
1026+
end
1027+
@testset "Dummy-proofing" begin
1028+
@test_throws ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got 0") Fix{0}(>, 1)
1029+
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `0.5::Float64`") Fix{0.5}(>, 1)
1030+
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `1::UInt64`") Fix{UInt64(1)}(>, 1)
1031+
end
1032+
@testset "Specialize to structs not in `Base`" begin
1033+
struct MyStruct
1034+
x::Int
1035+
end
1036+
f = Fix{1}(MyStruct, 1)
1037+
@test f isa Fix{1,Type{MyStruct},Int}
1038+
end
1039+
end
1040+
end

0 commit comments

Comments
 (0)