Skip to content

Commit 11ab5d5

Browse files
authored
Merge pull request #56 from characat0/patch-2
Auth using environment variables
2 parents d0dee22 + 0a92b04 commit 11ab5d5

File tree

6 files changed

+137
-8
lines changed

6 files changed

+137
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["@deyandyankov, @pebeto, and contributors"]
44
version = "0.6.0"
55

66
[deps]
7+
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
78
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
89
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
910
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
@@ -12,6 +13,7 @@ URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
1213
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1314

1415
[compat]
16+
Base64 = "1.0"
1517
HTTP = "1.0"
1618
JSON = "0.21"
1719
ShowCases = "0.1"

src/MLFlowClient.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module MLFlowClient
1414
using Dates
1515
using UUIDs
1616
using HTTP
17+
using Base64
1718
using URIs
1819
using JSON
1920
using ShowCases

src/types/mlflow.jl

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,19 @@ Base type which defines location and version for MLFlow API service.
66
# Fields
77
- `apiroot::String`: API root URL, e.g. `http://localhost:5000/api`
88
- `apiversion::Union{Integer, AbstractFloat}`: used API version, e.g. `2.0`
9-
- `headers::Dict`: HTTP headers to be provided with the REST API requests (useful for
10-
authetication tokens) Default is `false`, using the REST API endpoint.
9+
- `headers::Dict`: HTTP headers to be provided with the REST API requests.
10+
- `username::Union{Nothing, String}`: username for basic authentication.
11+
- `password::Union{Nothing, String}`: password for basic authentication.
12+
13+
!!! warning
14+
You cannot provide an `Authorization` header when an `username` and `password` are
15+
provided. An error will be thrown in that case.
16+
17+
!!! note
18+
- If `MLFLOW_TRACKING_URI` is set, the provided `apiroot` will be ignored.
19+
- If `MLFLOW_TRACKING_USERNAME` is set, the provided `username` will be ignored.
20+
- If `MLFLOW_TRACKING_PASSWORD` is set, the provided `password` will be ignored.
21+
These indications will be displayed as warnings.
1122
1223
# Examples
1324
@@ -19,17 +30,49 @@ mlf = MLFlow()
1930
remote_url="https://<your-server>.cloud.databricks.com"; # address of your remote server
2031
mlf = MLFlow(remote_url, headers=Dict("Authorization" => "Bearer <your-secret-token>"))
2132
```
22-
2333
"""
2434
struct MLFlow
2535
apiroot::String
2636
apiversion::AbstractFloat
2737
headers::Dict
38+
username::Union{Nothing,String}
39+
password::Union{Nothing,String}
40+
41+
function MLFlow(apiroot, apiversion, headers, username, password)
42+
if haskey(ENV, "MLFLOW_TRACKING_URI")
43+
@warn "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set."
44+
apiroot = ENV["MLFLOW_TRACKING_URI"]
45+
end
46+
47+
if haskey(ENV, "MLFLOW_TRACKING_USERNAME")
48+
@warn "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set."
49+
username = ENV["MLFLOW_TRACKING_USERNAME"]
50+
end
51+
52+
if haskey(ENV, "MLFLOW_TRACKING_PASSWORD")
53+
@warn "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set."
54+
password = ENV["MLFLOW_TRACKING_PASSWORD"]
55+
end
56+
57+
if username |> !isnothing && password |> !isnothing
58+
if haskey(headers, "Authorization")
59+
error("You cannot provide an Authorization header when an username and password are provided.")
60+
end
61+
encoded_credentials = Base64.base64encode("$(username):$(password)")
62+
headers =
63+
merge(headers, Dict("Authorization" => "Basic $(encoded_credentials)"))
64+
end
65+
new(apiroot, apiversion, headers, username, password)
66+
end
2867
end
29-
MLFlow(apiroot; apiversion=2.0, headers=Dict()) = MLFlow(apiroot, apiversion, headers)
30-
MLFlow(; apiroot="http://localhost:5000/api", apiversion=2.0, headers=Dict()) =
31-
MLFlow((haskey(ENV, "MLFLOW_TRACKING_URI") ?
32-
ENV["MLFLOW_TRACKING_URI"] : apiroot), apiversion, headers)
68+
MLFlow(apiroot::String; apiversion::AbstractFloat=2.0, headers::Dict=Dict(),
69+
username::Union{Nothing,String}=nothing,
70+
password::Union{Nothing,String}=nothing)::MLFlow =
71+
MLFlow(apiroot, apiversion, headers, username, password)
72+
MLFlow(; apiroot::String="http://localhost:5000/api", apiversion::AbstractFloat=2.0,
73+
headers::Dict=Dict(), username::Union{Nothing,String}=nothing,
74+
password::Union{Nothing,String}=nothing)::MLFlow =
75+
MLFlow(apiroot, apiversion, headers, username, password)
3376

3477
Base.show(io::IO, t::MLFlow) =
3578
show(io, ShowCase(t, [:apiroot, :apiversion], new_lines=true))

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ end
44

55
include("base.jl")
66

7+
include("types/mlflow.jl")
8+
79
include("services/run.jl")
810
include("services/misc.jl")
911
include("services/logger.jl")

test/services/user.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ end
3636
updateuserpassword(getmlfinstance(encoded_credentials), "missy", "ana")
3737
encoded_credentials = Base64.base64encode("$(user.username):ana")
3838

39-
@test_nowarn searchexperiments(getmlfinstance(encoded_credentials))
39+
@test begin
40+
try
41+
searchexperiments(getmlfinstance(encoded_credentials))
42+
true
43+
catch
44+
false
45+
end
46+
end
4047
deleteuser(mlf, user.username)
4148
end
4249

test/types/mlflow.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
@testset verbose = true "instantiate mlflow" begin
2+
mlflow_tracking_uri = ENV["MLFLOW_TRACKING_URI"]
3+
4+
@testset "using default constructor" begin
5+
delete!(ENV, "MLFLOW_TRACKING_URI")
6+
7+
instance = MLFlow("test", 2.0, Dict(), nothing, nothing)
8+
9+
@test instance.apiroot == "test"
10+
@test instance.apiversion == 2.0
11+
@test instance.headers == Dict()
12+
@test isnothing(instance.username)
13+
@test isnothing(instance.password)
14+
15+
ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
16+
end
17+
18+
@testset "using apiroot-only constructor" begin
19+
delete!(ENV, "MLFLOW_TRACKING_URI")
20+
21+
instance = MLFlow("test")
22+
23+
@test instance.apiroot == "test"
24+
@test instance.apiversion == 2.0
25+
@test instance.headers == Dict()
26+
@test isnothing(instance.username)
27+
@test isnothing(instance.password)
28+
29+
ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
30+
end
31+
32+
@testset "using constructor with keyword arguments" begin
33+
delete!(ENV, "MLFLOW_TRACKING_URI")
34+
35+
instance = MLFlow(; username="test", password="test")
36+
37+
@test instance.apiroot == "http://localhost:5000/api"
38+
@test instance.apiversion == 2.0
39+
@test haskey(instance.headers, "Authorization")
40+
@test instance.username == "test"
41+
@test instance.password == "test"
42+
43+
ENV["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
44+
end
45+
46+
@testset "using env variables" begin
47+
mlflow_tracking_username =
48+
haskey(ENV, "MLFLOW_TRACKING_USERNAME") ? ENV["MLFLOW_TRACKING_USERNAME"] : nothing
49+
mlflow_tracking_password =
50+
haskey(ENV, "MLFLOW_TRACKING_PASSWORD") ? ENV["MLFLOW_TRACKING_PASSWORD"] : nothing
51+
52+
ENV["MLFLOW_TRACKING_USERNAME"] = "test"
53+
ENV["MLFLOW_TRACKING_PASSWORD"] = "test"
54+
55+
@test_logs (:warn, "The provided apiroot will be ignored as MLFLOW_TRACKING_URI is set.") (:warn, "The provided username will be ignored as MLFLOW_TRACKING_USERNAME is set.") (:warn, "The provided password will be ignored as MLFLOW_TRACKING_PASSWORD is set.") MLFlow()
56+
57+
if !isnothing(mlflow_tracking_username)
58+
ENV["MLFLOW_TRACKING_USERNAME"] = mlflow_tracking_username
59+
else
60+
delete!(ENV, "MLFLOW_TRACKING_USERNAME")
61+
end
62+
if !isnothing(mlflow_tracking_password)
63+
ENV["MLFLOW_TRACKING_PASSWORD"] = mlflow_tracking_password
64+
else
65+
delete!(ENV, "MLFLOW_TRACKING_PASSWORD")
66+
end
67+
end
68+
69+
@testset "defining username, password and authorization header" begin
70+
encoded_credentials = Base64.base64encode("test:test")
71+
@test_throws ErrorException MLFlow(; username="test", password="test",
72+
headers=Dict("Authorization" => "Basic $encoded_credentials"))
73+
end
74+
end

0 commit comments

Comments
 (0)