From efff5bca9d3e0f4d88b894cc15b9fcac4ca5cc7e Mon Sep 17 00:00:00 2001 From: Anuj Agrawal Date: Thu, 14 Nov 2024 22:49:29 +0530 Subject: [PATCH] Added unit tests for safe lua lifted libraries Signed-off-by: Anuj Agrawal --- pkg/util/lifted/lua/loadlib_safe_test.go | 73 +++++++ pkg/util/lifted/lua/oslib_safe_test.go | 238 ++++++++++++++++++----- 2 files changed, 264 insertions(+), 47 deletions(-) create mode 100644 pkg/util/lifted/lua/loadlib_safe_test.go diff --git a/pkg/util/lifted/lua/loadlib_safe_test.go b/pkg/util/lifted/lua/loadlib_safe_test.go new file mode 100644 index 000000000000..6cb2427ae6f3 --- /dev/null +++ b/pkg/util/lifted/lua/loadlib_safe_test.go @@ -0,0 +1,73 @@ +/* +Copyright 2024 The Karmada Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package lua + +import ( + "testing" + + "github.com/stretchr/testify/assert" + lua "github.com/yuin/gopher-lua" +) + +func TestLoLoaderPreload(t *testing.T) { + L := lua.NewState() + defer L.Close() + OpenPackage(L) + + t.Run("nonexistent module", func(t *testing.T) { + L.Push(L.NewFunction(func(L *lua.LState) int { + L.Push(lua.LString("nonexistent")) + return loLoaderPreload(L) + })) + assert.NoError(t, L.PCall(0, 1, nil)) + assert.Equal(t, "no field package.preload['nonexistent']", L.ToString(-1)) + }) + + t.Run("existing module", func(t *testing.T) { + L.GetField(L.GetField(L.Get(lua.EnvironIndex), "package"), "preload").(*lua.LTable).RawSetString("testmod", L.NewFunction(func(_ *lua.LState) int { return 0 })) + L.Push(L.NewFunction(func(L *lua.LState) int { + L.Push(lua.LString("testmod")) + return loLoaderPreload(L) + })) + assert.NoError(t, L.PCall(0, 1, nil)) + assert.Equal(t, lua.LTFunction, L.Get(-1).Type()) + }) +} + +func TestLoLoadLib(t *testing.T) { + L := lua.NewState() + defer L.Close() + + L.Push(L.NewFunction(func(L *lua.LState) int { + return loLoadLib(L) + })) + err := L.PCall(0, 0, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "loadlib is not supported") +} + +func TestOpenPackage(t *testing.T) { + L := lua.NewState() + defer L.Close() + + assert.Equal(t, 1, OpenPackage(L)) + pkg := L.Get(-1) + assert.Equal(t, lua.LTTable, pkg.Type()) + assert.Equal(t, lua.LTTable, L.GetField(pkg, "preload").Type()) + assert.Equal(t, lua.LTTable, L.GetField(pkg, "loaders").Type()) + assert.Equal(t, lua.LTTable, L.GetField(pkg, "loaded").Type()) +} diff --git a/pkg/util/lifted/lua/oslib_safe_test.go b/pkg/util/lifted/lua/oslib_safe_test.go index d254e38c0da4..22e85a8e2dab 100644 --- a/pkg/util/lifted/lua/oslib_safe_test.go +++ b/pkg/util/lifted/lua/oslib_safe_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" lua "github.com/yuin/gopher-lua" ) @@ -27,36 +28,121 @@ func TestOpenSafeOs(t *testing.T) { L := lua.NewState() defer L.Close() - expect := 1 + assert.Equal(t, 1, OpenSafeOs(L)) - actual := OpenSafeOs(L) - - if actual != expect { - t.Errorf("OpenSafeOs returned %v, expected %v", actual, expect) - } + // Test registered functions + osTable := L.GetGlobal(lua.OsLibName) + assert.Equal(t, lua.LTTable, osTable.Type()) + assert.NotNil(t, L.GetField(osTable, "time")) + assert.NotNil(t, L.GetField(osTable, "date")) } func TestSafeOsLoader(t *testing.T) { L := lua.NewState() defer L.Close() - expect := 1 - - actual := SafeOsLoader(L) + assert.Equal(t, 1, SafeOsLoader(L)) - if actual != expect { - t.Errorf("SafeOsLoader returned %v, expected %v", actual, expect) - } + // Verify module table is returned + result := L.Get(-1) + assert.Equal(t, lua.LTTable, result.Type()) } func TestOsTime(t *testing.T) { L := lua.NewState() defer L.Close() - actual := osTime(L) - if actual != 1 { - t.Errorf("osTime returned %v, expected %v", actual, 1) - } + t.Run("no args", func(t *testing.T) { + assert.Equal(t, 1, osTime(L)) + result := L.Get(-1) + assert.Equal(t, lua.LTNumber, result.Type()) + + now := time.Now().Unix() + timestamp := int64(result.(lua.LNumber)) + assert.InDelta(t, now, timestamp, 1.0) + }) + + t.Run("with table", func(t *testing.T) { + L.Push(L.NewFunction(func(L *lua.LState) int { + tbl := L.NewTable() + tbl.RawSetString("year", lua.LNumber(2024)) + tbl.RawSetString("month", lua.LNumber(3)) + tbl.RawSetString("day", lua.LNumber(14)) + tbl.RawSetString("hour", lua.LNumber(12)) + tbl.RawSetString("min", lua.LNumber(30)) + tbl.RawSetString("sec", lua.LNumber(45)) + tbl.RawSetString("isdst", lua.LBool(false)) + L.Push(tbl) + return osTime(L) + })) + + assert.NoError(t, L.PCall(0, 1, nil)) + result := L.Get(-1) + assert.Equal(t, lua.LTNumber, result.Type()) + + expectedTime := time.Date(2024, 3, 14, 12, 30, 45, 0, time.Local) + assert.Equal(t, expectedTime.Unix(), int64(result.(lua.LNumber))) + }) +} + +func TestOsDate(t *testing.T) { + L := lua.NewState() + defer L.Close() + + t.Run("default format", func(t *testing.T) { + L.Push(L.NewFunction(func(L *lua.LState) int { + return osDate(L) + })) + assert.NoError(t, L.PCall(0, 1, nil)) + result := L.Get(-1) + assert.Equal(t, lua.LTString, result.Type()) + }) + + t.Run("UTC time", func(t *testing.T) { + L.Push(L.NewFunction(func(L *lua.LState) int { + L.Push(lua.LString("!%Y-%m-%d")) + return osDate(L) + })) + assert.NoError(t, L.PCall(0, 1, nil)) + result := L.Get(-1) + assert.Equal(t, lua.LTString, result.Type()) + + now := time.Now().UTC() + expected := now.Format("2006-01-02") + assert.Equal(t, expected, string(result.(lua.LString))) + }) + + t.Run("custom timestamp", func(t *testing.T) { + timestamp := time.Date(2024, 3, 14, 15, 30, 45, 0, time.Local) + L.Push(L.NewFunction(func(L *lua.LState) int { + L.Push(lua.LString("%Y-%m-%d %H:%M:%S")) + L.Push(lua.LNumber(timestamp.Unix())) + return osDate(L) + })) + assert.NoError(t, L.PCall(0, 1, nil)) + result := L.Get(-1) + assert.Equal(t, "2024-03-14 15:30:45", string(result.(lua.LString))) + }) + + t.Run("table format", func(t *testing.T) { + timestamp := time.Now().Unix() + L.Push(L.NewFunction(func(L *lua.LState) int { + L.Push(lua.LString("*t")) + L.Push(lua.LNumber(timestamp)) + return osDate(L) + })) + assert.NoError(t, L.PCall(0, 1, nil)) + + result := L.Get(-1).(*lua.LTable) + tm := time.Unix(timestamp, 0) + assert.Equal(t, float64(tm.Year()), float64(result.RawGetString("year").(lua.LNumber))) + assert.Equal(t, float64(tm.Month()), float64(result.RawGetString("month").(lua.LNumber))) + assert.Equal(t, float64(tm.Day()), float64(result.RawGetString("day").(lua.LNumber))) + assert.Equal(t, float64(tm.Hour()), float64(result.RawGetString("hour").(lua.LNumber))) + assert.Equal(t, float64(tm.Minute()), float64(result.RawGetString("min").(lua.LNumber))) + assert.Equal(t, float64(tm.Second()), float64(result.RawGetString("sec").(lua.LNumber))) + assert.Equal(t, float64(tm.Weekday()+1), float64(result.RawGetString("wday").(lua.LNumber))) + }) } func TestGetIntField(t *testing.T) { @@ -64,33 +150,40 @@ func TestGetIntField(t *testing.T) { tb.RawSetString("min", lua.LNumber(15)) tb.RawSetString("day", lua.LString("a")) - // Test with valid key - expected := 15 - if v := getIntField(tb, "min", 0); v != expected { - t.Errorf("getIntField(tb, \"min\", 0) returned %d, expected %d", v, expected) - } + t.Run("valid number", func(t *testing.T) { + assert.Equal(t, 15, getIntField(tb, "min", 0)) + }) - // Test with non-number value - expected = 0 - if v := getIntField(tb, "day", 0); v != expected { - t.Errorf("getIntField(tb, \"day\", 0) returned %d, expected %d", v, expected) - } + t.Run("non-number value", func(t *testing.T) { + assert.Equal(t, 0, getIntField(tb, "day", 0)) + }) + + t.Run("missing key", func(t *testing.T) { + assert.Equal(t, 42, getIntField(tb, "nonexistent", 42)) + }) } func TestGetBoolField(t *testing.T) { tb := &lua.LTable{} - tb.RawSetString("min", lua.LNumber(15)) - tb.RawSetString("isdst", lua.LBool(false)) + tb.RawSetString("flag1", lua.LBool(true)) + tb.RawSetString("flag2", lua.LBool(false)) + tb.RawSetString("notbool", lua.LNumber(1)) - // Test with valid key - if v := getBoolField(tb, "isdst", false); v { - t.Errorf("getBoolField(tb, \"isdst\", false) returned %v, expected %v", v, false) - } + t.Run("true value", func(t *testing.T) { + assert.True(t, getBoolField(tb, "flag1", false)) + }) - // Test with non-number value - if v := getBoolField(tb, "min", true); !v { - t.Errorf("getBoolField(tb, \"min\", true) returned %v, expected %v", v, true) - } + t.Run("false value", func(t *testing.T) { + assert.False(t, getBoolField(tb, "flag2", true)) + }) + + t.Run("non-bool value", func(t *testing.T) { + assert.True(t, getBoolField(tb, "notbool", true)) + }) + + t.Run("missing key", func(t *testing.T) { + assert.True(t, getBoolField(tb, "nonexistent", true)) + }) } func TestStrftime(t *testing.T) { @@ -101,24 +194,75 @@ func TestStrftime(t *testing.T) { want string }{ { - name: "character in cDateFlagToGo", - time: time.Date(2022, time.February, 16, 15, 45, 27, 0, time.UTC), - cfmt: "%Y/%m/%d %H:%M:%S", - want: "2022/02/16 15:45:27", + name: "basic format", + time: time.Date(2024, 3, 14, 15, 30, 45, 0, time.UTC), + cfmt: "%Y-%m-%d %H:%M:%S", + want: "2024-03-14 15:30:45", + }, + { + name: "escaped percent", + time: time.Date(2024, 3, 14, 15, 30, 45, 0, time.UTC), + cfmt: "%%Y", + want: "%Y", + }, + { + name: "weekday", + time: time.Date(2024, 3, 14, 15, 30, 45, 0, time.UTC), + cfmt: "%w", + want: "4", }, { - name: "character not in cDateFlagToGo", - time: time.Date(2022, time.February, 16, 15, 45, 27, 0, time.FixedZone("", -8*60*60)), - cfmt: "%A, %w %B %Y %I:%M:%S %p %Z %e", - want: "Wednesday, 3 February 2022 03:45:27 PM -0800 %e", + name: "month names", + time: time.Date(2024, 3, 14, 15, 30, 45, 0, time.UTC), + cfmt: "%B %b", + want: "March Mar", + }, + { + name: "12-hour clock", + time: time.Date(2024, 3, 14, 15, 30, 45, 0, time.UTC), + cfmt: "%I %p", + want: "03 PM", + }, + { + name: "timezone", + time: time.Date(2024, 3, 14, 15, 30, 45, 0, time.UTC), + cfmt: "%z %Z", + want: "+0000 UTC", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := strftime(tt.time, tt.cfmt); got != tt.want { - t.Errorf("strftime(%v, %q) got %q, want %q", tt.time, tt.cfmt, got, tt.want) - } + assert.Equal(t, tt.want, strftime(tt.time, tt.cfmt)) }) } } + +func TestFlagScanner(t *testing.T) { + t.Run("basic scanning", func(t *testing.T) { + scanner := newFlagScanner('%', "", "", "test%%string") + + // Scan all characters before String() is valid + for c, eos := scanner.Next(); !eos; c, eos = scanner.Next() { + scanner.AppendChar(c) + } + + assert.Equal(t, "test%string", scanner.String()) + }) + + t.Run("empty string", func(t *testing.T) { + scanner := newFlagScanner('%', "", "", "") + c, eos := scanner.Next() + assert.Equal(t, byte(0), c) + assert.True(t, eos) + }) + + t.Run("single flag", func(t *testing.T) { + scanner := newFlagScanner('%', "<", ">", "%d") + scanner.Next() // Skip first % + c, eos := scanner.Next() + assert.Equal(t, byte('d'), c) + assert.False(t, eos) + assert.True(t, scanner.HasFlag) + }) +}