Skip to content

Commit ca69be0

Browse files
committed
Support setting metatable for Lua builtin types.
Closes #445
1 parent 16951e3 commit ca69be0

File tree

9 files changed

+215
-19
lines changed

9 files changed

+215
-19
lines changed

src/function.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::{mem, ptr, slice};
55
use crate::error::{Error, Result};
66
use crate::state::Lua;
77
use crate::table::Table;
8-
use crate::types::{Callback, MaybeSend, ValueRef};
8+
use crate::types::{Callback, LuaType, MaybeSend, ValueRef};
99
use crate::util::{
1010
assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard,
1111
};
@@ -588,6 +588,10 @@ impl IntoLua for WrappedAsyncFunction {
588588
}
589589
}
590590

591+
impl LuaType for Function {
592+
const TYPE_ID: c_int = ffi::LUA_TFUNCTION;
593+
}
594+
591595
#[cfg(test)]
592596
mod assertions {
593597
use super::*;

src/state.rs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::string::String;
1818
use crate::table::Table;
1919
use crate::thread::Thread;
2020
use crate::types::{
21-
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex,
21+
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex,
2222
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
2323
};
2424
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
@@ -1337,24 +1337,66 @@ impl Lua {
13371337
unsafe { self.lock().make_userdata(UserDataStorage::new(ud)) }
13381338
}
13391339

1340-
/// Sets the metatable for a Luau builtin vector type.
1341-
#[cfg(any(feature = "luau", doc))]
1342-
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
1343-
pub fn set_vector_metatable(&self, metatable: Option<Table>) {
1340+
/// Sets the metatable for a Lua builtin type.
1341+
///
1342+
/// The metatable will be shared by all values of the given type.
1343+
///
1344+
/// # Examples
1345+
///
1346+
/// Change metatable for Lua boolean type:
1347+
///
1348+
/// ```
1349+
/// # use mlua::{Lua, Result, Function};
1350+
/// # fn main() -> Result<()> {
1351+
/// # let lua = Lua::new();
1352+
/// let mt = lua.create_table()?;
1353+
/// mt.set("__tostring", lua.create_function(|_, b: bool| Ok(if b { 2 } else { 0 }))?)?;
1354+
/// lua.set_type_metatable::<bool>(Some(mt));
1355+
/// lua.load("assert(tostring(true) == '2')").exec()?;
1356+
/// # Ok(())
1357+
/// # }
1358+
/// ```
1359+
#[allow(private_bounds)]
1360+
pub fn set_type_metatable<T: LuaType>(&self, metatable: Option<Table>) {
13441361
let lua = self.lock();
13451362
let state = lua.state();
13461363
unsafe {
13471364
let _sg = StackGuard::new(state);
13481365
assert_stack(state, 2);
13491366

1350-
#[cfg(not(feature = "luau-vector4"))]
1351-
ffi::lua_pushvector(state, 0., 0., 0.);
1352-
#[cfg(feature = "luau-vector4")]
1353-
ffi::lua_pushvector(state, 0., 0., 0., 0.);
1367+
match T::TYPE_ID {
1368+
ffi::LUA_TBOOLEAN => {
1369+
ffi::lua_pushboolean(state, 0);
1370+
}
1371+
ffi::LUA_TLIGHTUSERDATA => {
1372+
ffi::lua_pushlightuserdata(state, ptr::null_mut());
1373+
}
1374+
ffi::LUA_TNUMBER => {
1375+
ffi::lua_pushnumber(state, 0.);
1376+
}
1377+
#[cfg(feature = "luau")]
1378+
ffi::LUA_TVECTOR => {
1379+
#[cfg(not(feature = "luau-vector4"))]
1380+
ffi::lua_pushvector(state, 0., 0., 0.);
1381+
#[cfg(feature = "luau-vector4")]
1382+
ffi::lua_pushvector(state, 0., 0., 0., 0.);
1383+
}
1384+
ffi::LUA_TSTRING => {
1385+
ffi::lua_pushstring(state, b"\0" as *const u8 as *const _);
1386+
}
1387+
ffi::LUA_TFUNCTION => match self.load("function() end").eval::<Function>() {
1388+
Ok(func) => lua.push_ref(&func.0),
1389+
Err(_) => return,
1390+
},
1391+
ffi::LUA_TTHREAD => {
1392+
ffi::lua_newthread(state);
1393+
}
1394+
_ => {}
1395+
}
13541396
match metatable {
13551397
Some(metatable) => lua.push_ref(&metatable.0),
13561398
None => ffi::lua_pushnil(state),
1357-
};
1399+
}
13581400
ffi::lua_setmetatable(state, -2);
13591401
}
13601402
}

src/string.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::borrow::Borrow;
22
use std::hash::{Hash, Hasher};
33
use std::ops::Deref;
4-
use std::os::raw::c_void;
4+
use std::os::raw::{c_int, c_void};
55
use std::string::String as StdString;
66
use std::{cmp, fmt, slice, str};
77

@@ -13,7 +13,7 @@ use {
1313

1414
use crate::error::{Error, Result};
1515
use crate::state::Lua;
16-
use crate::types::ValueRef;
16+
use crate::types::{LuaType, ValueRef};
1717

1818
/// Handle to an internal Lua string.
1919
///
@@ -327,6 +327,10 @@ impl<'a> IntoIterator for BorrowedBytes<'a> {
327327
}
328328
}
329329

330+
impl LuaType for String {
331+
const TYPE_ID: c_int = ffi::LUA_TSTRING;
332+
}
333+
330334
#[cfg(test)]
331335
mod assertions {
332336
use super::*;

src/table.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::collections::HashSet;
22
use std::fmt;
33
use std::marker::PhantomData;
4-
use std::os::raw::c_void;
4+
use std::os::raw::{c_int, c_void};
55
use std::string::String as StdString;
66

77
#[cfg(feature = "serialize")]
@@ -15,7 +15,7 @@ use crate::error::{Error, Result};
1515
use crate::function::Function;
1616
use crate::state::{LuaGuard, RawLua};
1717
use crate::traits::ObjectLike;
18-
use crate::types::{Integer, ValueRef};
18+
use crate::types::{Integer, LuaType, ValueRef};
1919
use crate::util::{assert_stack, check_stack, StackGuard};
2020
use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Nil, Value};
2121

@@ -961,6 +961,10 @@ impl Serialize for Table {
961961
}
962962
}
963963

964+
impl LuaType for Table {
965+
const TYPE_ID: c_int = ffi::LUA_TTABLE;
966+
}
967+
964968
#[cfg(feature = "serialize")]
965969
impl<'a> SerializableTable<'a> {
966970
#[inline]

src/thread.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::error::{Error, Result};
44
#[allow(unused)]
55
use crate::state::Lua;
66
use crate::state::RawLua;
7-
use crate::types::{ValueRef, VmState};
7+
use crate::types::{LuaType, ValueRef, VmState};
88
use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
99
use crate::value::{FromLuaMulti, IntoLuaMulti};
1010

@@ -372,6 +372,10 @@ impl PartialEq for Thread {
372372
}
373373
}
374374

375+
impl LuaType for Thread {
376+
const TYPE_ID: c_int = ffi::LUA_TTHREAD;
377+
}
378+
375379
#[cfg(feature = "async")]
376380
impl<A, R> AsyncThread<A, R> {
377381
#[inline]

src/types.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub type Integer = ffi::lua_Integer;
2727
/// Type of Lua floating point numbers.
2828
pub type Number = ffi::lua_Number;
2929

30-
// Represents different subtypes wrapped to AnyUserData
30+
// Represents different subtypes wrapped in AnyUserData
3131
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
3232
pub(crate) enum SubtypeId {
3333
None,
@@ -115,6 +115,22 @@ impl<T> MaybeSend for T {}
115115

116116
pub(crate) struct DestructedUserdata;
117117

118+
pub(crate) trait LuaType {
119+
const TYPE_ID: c_int;
120+
}
121+
122+
impl LuaType for bool {
123+
const TYPE_ID: c_int = ffi::LUA_TBOOLEAN;
124+
}
125+
126+
impl LuaType for Number {
127+
const TYPE_ID: c_int = ffi::LUA_TNUMBER;
128+
}
129+
130+
impl LuaType for LightUserData {
131+
const TYPE_ID: c_int = ffi::LUA_TLIGHTUSERDATA;
132+
}
133+
118134
mod app_data;
119135
mod registry_key;
120136
mod sync;

src/types/vector.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::fmt;
33
#[cfg(all(any(feature = "luau", doc), feature = "serialize"))]
44
use serde::ser::{Serialize, SerializeTupleStruct, Serializer};
55

6+
use super::LuaType;
7+
68
/// A Luau vector type.
79
///
810
/// By default vectors are 3-dimensional, but can be 4-dimensional
@@ -84,3 +86,12 @@ impl PartialEq<[f32; Self::SIZE]> for Vector {
8486
self.0 == *other
8587
}
8688
}
89+
90+
impl LuaType for Vector {
91+
#[cfg(feature = "luau")]
92+
const TYPE_ID: i32 = ffi::LUA_TVECTOR;
93+
94+
// This is a dummy value, as `Vector` is supported only by Luau
95+
#[cfg(not(feature = "luau"))]
96+
const TYPE_ID: i32 = ffi::LUA_TNONE;
97+
}

tests/luau.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ fn test_vector_metatable() -> Result<()> {
194194
)
195195
.eval::<Table>()?;
196196
vector_mt.set_metatable(Some(vector_mt.clone()));
197-
lua.set_vector_metatable(Some(vector_mt.clone()));
197+
lua.set_type_metatable::<Vector>(Some(vector_mt.clone()));
198198
lua.globals().set("Vector3", vector_mt)?;
199199

200200
let compiler = Compiler::new().set_vector_lib("Vector3").set_vector_ctor("new");

tests/types.rs

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::os::raw::c_void;
22

3-
use mlua::{Function, LightUserData, Lua, Result};
3+
use mlua::{Function, LightUserData, Lua, Number, Result, String as LuaString, Thread};
44

55
#[test]
66
fn test_lightuserdata() -> Result<()> {
@@ -24,3 +24,114 @@ fn test_lightuserdata() -> Result<()> {
2424

2525
Ok(())
2626
}
27+
28+
#[test]
29+
fn test_boolean_type_metatable() -> Result<()> {
30+
let lua = Lua::new();
31+
32+
let mt = lua.create_table()?;
33+
mt.set("__add", Function::wrap(|_, (a, b): (bool, bool)| Ok(a || b)))?;
34+
lua.set_type_metatable::<bool>(Some(mt));
35+
36+
lua.load(r#"assert(true + true == true)"#).exec().unwrap();
37+
lua.load(r#"assert(true + false == true)"#).exec().unwrap();
38+
lua.load(r#"assert(false + true == true)"#).exec().unwrap();
39+
lua.load(r#"assert(false + false == false)"#).exec().unwrap();
40+
41+
Ok(())
42+
}
43+
44+
#[test]
45+
fn test_lightuserdata_type_metatable() -> Result<()> {
46+
let lua = Lua::new();
47+
48+
let mt = lua.create_table()?;
49+
mt.set(
50+
"__add",
51+
Function::wrap(|_, (a, b): (LightUserData, LightUserData)| {
52+
Ok(LightUserData((a.0 as usize + b.0 as usize) as *mut c_void))
53+
}),
54+
)?;
55+
lua.set_type_metatable::<LightUserData>(Some(mt));
56+
57+
let res = lua
58+
.load(
59+
r#"
60+
local a, b = ...
61+
return a + b
62+
"#,
63+
)
64+
.call::<LightUserData>((
65+
LightUserData(42 as *mut c_void),
66+
LightUserData(100 as *mut c_void),
67+
))
68+
.unwrap();
69+
assert_eq!(res, LightUserData(142 as *mut c_void));
70+
71+
Ok(())
72+
}
73+
74+
#[test]
75+
fn test_number_type_metatable() -> Result<()> {
76+
let lua = Lua::new();
77+
78+
let mt = lua.create_table()?;
79+
mt.set("__call", Function::wrap(|_, (n1, n2): (f64, f64)| Ok(n1 * n2)))?;
80+
lua.set_type_metatable::<Number>(Some(mt));
81+
lua.load(r#"assert((1.5)(3.0) == 4.5)"#).exec().unwrap();
82+
lua.load(r#"assert((5)(5) == 25)"#).exec().unwrap();
83+
84+
Ok(())
85+
}
86+
87+
#[test]
88+
fn test_string_type_metatable() -> Result<()> {
89+
let lua = Lua::new();
90+
91+
let mt = lua.create_table()?;
92+
mt.set(
93+
"__add",
94+
Function::wrap(|_, (a, b): (LuaString, LuaString)| Ok(format!("{}{}", a.to_str()?, b.to_str()?))),
95+
)?;
96+
lua.set_type_metatable::<LuaString>(Some(mt));
97+
98+
lua.load(r#"assert(("foo" + "bar") == "foobar")"#).exec().unwrap();
99+
100+
Ok(())
101+
}
102+
103+
#[test]
104+
fn test_function_type_metatable() -> Result<()> {
105+
let lua = Lua::new();
106+
107+
let mt = lua.create_table()?;
108+
mt.set(
109+
"__index",
110+
Function::wrap(|_, (_, key): (Function, String)| Ok(format!("function.{key}"))),
111+
)?;
112+
lua.set_type_metatable::<Function>(Some(mt));
113+
114+
lua.load(r#"assert((function() end).foo == "function.foo")"#)
115+
.exec()
116+
.unwrap();
117+
118+
Ok(())
119+
}
120+
121+
#[test]
122+
fn test_thread_type_metatable() -> Result<()> {
123+
let lua = Lua::new();
124+
125+
let mt = lua.create_table()?;
126+
mt.set(
127+
"__index",
128+
Function::wrap(|_, (_, key): (Thread, String)| Ok(format!("thread.{key}"))),
129+
)?;
130+
lua.set_type_metatable::<Thread>(Some(mt));
131+
132+
lua.load(r#"assert((coroutine.create(function() end)).foo == "thread.foo")"#)
133+
.exec()
134+
.unwrap();
135+
136+
Ok(())
137+
}

0 commit comments

Comments
 (0)