Skip to content

Commit

Permalink
Update pyo3 to properly support py 3.13
Browse files Browse the repository at this point in the history
  • Loading branch information
vicky5124 committed Jan 30, 2025
1 parent 060dbb5 commit 0f9a7bf
Show file tree
Hide file tree
Showing 16 changed files with 417 additions and 146 deletions.
18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ websockets-rustls-native-roots = ["tokio-websockets/rustls-native-roots", "_rust
websockets-rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots", "_rustls-webpki-roots", "_websockets"]
websockets-native-tls = ["tokio-websockets/native-tls", "_native-tls", "_websockets"]

python = ["pyo3", "pyo3-asyncio", "pyo3-log", "pythonize", "log", "paste", "macro_rules_attribute", "parking_lot"]
python = ["pyo3", "pyo3-async-runtimes", "pyo3-log", "pythonize", "log", "paste", "macro_rules_attribute", "parking_lot"]

[package.metadata.docs.rs]
features = ["tungstenite-rustls-webpki-roots", "twilight", "serenity", "songbird", "macros", "python"]
Expand Down Expand Up @@ -121,21 +121,21 @@ version = "0.16"
optional = true

[dependencies.pyo3]
version = "0.20"
features = ["extension-module"]
version = "0.23"
features = ["extension-module", "py-clone"]
optional = true

[dependencies.pythonize]
version = "0.20"
version = "0.23"
optional = true

[dependencies.pyo3-asyncio]
version = "0.20"
[dependencies.pyo3-async-runtimes]
version = "0.23"
features = ["attributes", "tokio-runtime"]
optional = true

[dependencies.pyo3-log]
version = "0.9"
version = "0.12"
optional = true

[dependencies.log]
Expand All @@ -157,8 +157,8 @@ optional = true

[dependencies.macros-dep]
package = "lavalink_rs_macros"
#version = "0.2"
path = "./lavalink_rs_macros"
version = "0.2"
#path = "./lavalink_rs_macros"
optional = true


Expand Down
249 changes: 249 additions & 0 deletions reconnect.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
diff --git a/examples/py_hikari_lightbulb/lavalink_voice.py b/examples/py_hikari_lightbulb/lavalink_voice.py
index 75ce9bd..ee81b75 100644
--- a/examples/py_hikari_lightbulb/lavalink_voice.py
+++ b/examples/py_hikari_lightbulb/lavalink_voice.py
@@ -11,20 +11,17 @@ from lavalink_rs.model.player import ConnectionInfo


class LavalinkVoice(VoiceConnection):
- __slots__ = [
+ __slots__ = (
"lavalink",
"player",
- "__channel_id",
- "__guild_id",
"__session_id",
- "__is_alive",
- "__shard_id",
- "__on_close",
"__owner",
- ]
+ "__should_disconnect",
+ )
lavalink: LavalinkClient
player: PlayerContext

+
def __init__(
self,
lavalink_client: LavalinkClient,
@@ -33,68 +30,56 @@ class LavalinkVoice(VoiceConnection):
channel_id: hikari.Snowflake,
guild_id: hikari.Snowflake,
session_id: str,
- is_alive: bool,
shard_id: int,
owner: VoiceComponent,
- on_close: t.Any,
) -> None:
+ super().__init__(channel_id, guild_id, shard_id)
+
self.player = player
self.lavalink = lavalink_client

- self.__channel_id = channel_id
- self.__guild_id = guild_id
self.__session_id = session_id
- self.__is_alive = is_alive
- self.__shard_id = shard_id
self.__owner = owner
- self.__on_close = on_close
-
- @property
- def channel_id(self) -> hikari.Snowflake:
- """Return the ID of the voice channel this voice connection is in."""
- return self.__channel_id

- @property
- def guild_id(self) -> hikari.Snowflake:
- """Return the ID of the guild this voice connection is in."""
- return self.__guild_id
+ self.__should_disconnect: bool = True

- @property
- def is_alive(self) -> bool:
- """Return `builtins.True` if the connection is alive."""
- return self.__is_alive

- @property
- def shard_id(self) -> int:
- """Return the ID of the shard that requested the connection."""
- return self.__shard_id
+ @staticmethod
+ async def reconnect(player: PlayerContext, endpoint: str, token: str, session_id: str) -> None:
+ update_player = UpdatePlayer()
+ connection_info = ConnectionInfo(
+ endpoint, token, session_id
+ )
+ connection_info.fix()
+ update_player.voice = connection_info

- @property
- def owner(self) -> VoiceComponent:
- """Return the component that is managing this connection."""
- return self.__owner
+ await player.update_player(update_player, True)

async def disconnect(self) -> None:
"""Signal the process to shut down."""
- self.__is_alive = False
- await self.lavalink.delete_player(self.__guild_id)
- await self.__on_close(self)
-
- async def join(self) -> None:
- """Wait for the process to halt before continuing."""
+ if self.__should_disconnect:
+ await self.lavalink.delete_player(self.guild_id)
+ else:
+ self.__should_disconnect = True

async def notify(self, event: hikari.VoiceEvent) -> None:
"""Submit an event to the voice connection to be processed."""
- if isinstance(event, hikari.VoiceServerUpdateEvent):
- # Handle the bot being moved frome one channel to another
- assert event.raw_endpoint
- update_player = UpdatePlayer()
- connection_info = ConnectionInfo(
- event.raw_endpoint, event.token, self.__session_id
- )
- connection_info.fix()
- update_player.voice = connection_info
- await self.player.update_player(update_player, True)
+ if not isinstance(event, hikari.VoiceServerUpdateEvent):
+ return
+
+ # TODO handle this better
+ # https://discord.com/developers/docs/topics/gateway-events#voice-server-update
+ assert event.raw_endpoint
+
+ # Handle the bot being moved frome one channel to another
+ await LavalinkVoice.reconnect(self.player, event.raw_endpoint, event.token, self.__session_id)
+
+
+ @property
+ def owner(self) -> VoiceComponent:
+ """Return the component that is managing this connection."""
+ return self.__owner
+

@classmethod
async def connect(
@@ -106,9 +91,19 @@ class LavalinkVoice(VoiceConnection):
player_data: t.Any,
deaf: bool = True,
) -> LavalinkVoice:
+ try:
+ conn = client.voice.connections.get(guild_id)
+ if conn:
+ assert isinstance(conn, LavalinkVoice)
+ conn.__should_disconnect = False
+ await client.voice.disconnect(guild_id)
+ except:
+ pass
+
voice: LavalinkVoice = await client.voice.connect_to(
guild_id,
channel_id,
+ disconnect_existing=False,
voice_connection_type=LavalinkVoice,
lavalink_client=lavalink_client,
player_data=player_data,
@@ -117,13 +112,13 @@ class LavalinkVoice(VoiceConnection):

return voice

+
@classmethod
async def initialize(
cls,
channel_id: hikari.Snowflake,
endpoint: str,
guild_id: hikari.Snowflake,
- on_close: t.Any,
owner: VoiceComponent,
session_id: str,
shard_id: int,
@@ -132,14 +127,19 @@ class LavalinkVoice(VoiceConnection):
**kwargs: t.Any,
) -> LavalinkVoice:
del user_id
- lavalink_client = kwargs["lavalink_client"]
-
- player = await lavalink_client.create_player_context(
- guild_id, endpoint, token, session_id
- )

+ lavalink_client = kwargs["lavalink_client"]
player_data = kwargs["player_data"]

+ player = lavalink_client.get_player_context(guild_id)
+
+ if player:
+ await LavalinkVoice.reconnect(player, endpoint, token, session_id)
+ else:
+ player = await lavalink_client.create_player_context(
+ guild_id, endpoint, token, session_id
+ )
+
if player_data:
player.data = player_data

@@ -149,10 +149,8 @@ class LavalinkVoice(VoiceConnection):
channel_id=channel_id,
guild_id=guild_id,
session_id=session_id,
- is_alive=True,
shard_id=shard_id,
owner=owner,
- on_close=on_close,
)

return self
diff --git a/examples/py_hikari_lightbulb/plugins/music_basic.py b/examples/py_hikari_lightbulb/plugins/music_basic.py
index da463c3..1cb43e1 100644
--- a/examples/py_hikari_lightbulb/plugins/music_basic.py
+++ b/examples/py_hikari_lightbulb/plugins/music_basic.py
@@ -33,27 +33,13 @@ async def _join(ctx: Context) -> t.Optional[hikari.Snowflake]:

channel_id = voice_state.channel_id

- voice = ctx.bot.voice.connections.get(ctx.guild_id)
-
- if not voice:
- await LavalinkVoice.connect(
- ctx.guild_id,
- channel_id,
- ctx.bot,
- ctx.bot.lavalink,
- (ctx.channel_id, ctx.bot.rest),
- )
- else:
- assert isinstance(voice, LavalinkVoice)
-
- await LavalinkVoice.connect(
- ctx.guild_id,
- channel_id,
- ctx.bot,
- ctx.bot.lavalink,
- (ctx.channel_id, ctx.bot.rest),
- old_voice=voice,
- )
+ await LavalinkVoice.connect(
+ ctx.guild_id,
+ channel_id,
+ ctx.bot,
+ ctx.bot.lavalink,
+ (ctx.channel_id, ctx.bot.rest),
+ )

return channel_id

@@ -96,7 +82,7 @@ async def leave(ctx: Context) -> None:
await ctx.respond("Not in a voice channel")
return None

- await voice.disconnect()
+ await ctx.bot.voice.disconnect(ctx.guild_id)

await ctx.respond("Left the voice channel")

12 changes: 6 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use tokio::sync::Mutex;
pub struct LavalinkClient {
pub nodes: Vec<Arc<node::Node>>,
pub players: Arc<DashMap<GuildId, (ArcSwapOption<PlayerContext>, Arc<node::Node>)>>,
pub events: events::Events,
pub events: Arc<events::Events>,
tx: UnboundedSender<client::ClientMessage>,
user_id: UserId,
user_data: Arc<dyn std::any::Any + Send + Sync>,
Expand Down Expand Up @@ -165,7 +165,7 @@ impl LavalinkClient {
user_id: built_nodes[0].user_id,
nodes: built_nodes,
players: Arc::new(DashMap::new()),
events,
events: Arc::new(events),
tx,
user_data,
strategy,
Expand Down Expand Up @@ -270,21 +270,21 @@ impl LavalinkClient {

Python::with_gil(|py| {
let func = func.into_py(py);
let current_loop = pyo3_asyncio::tokio::get_current_loop(py).unwrap();
let current_loop = pyo3_async_runtimes::tokio::get_current_loop(py).unwrap();

let client = client.clone();
let client2 = client.clone();

pyo3_asyncio::tokio::future_into_py_with_locals(
pyo3_async_runtimes::tokio::future_into_py_with_locals(
py,
pyo3_asyncio::TaskLocals::new(current_loop),
pyo3_async_runtimes::TaskLocals::new(current_loop),
async move {
let future = Python::with_gil(|py| {
let coro = func
.call(py, (client.into_py(py), guild_id.into_py(py)), None)
.unwrap();

pyo3_asyncio::tokio::into_future(coro.downcast(py).unwrap())
pyo3_async_runtimes::tokio::into_future(coro.into_bound(py))
})
.unwrap();

Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ mod python;

#[cfg(feature = "python")]
#[pymodule]
fn lavalink_rs(py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn lavalink_rs(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
let handle = pyo3_log::Logger::new(py, pyo3_log::Caching::LoggersAndLevels)?
.filter(log::LevelFilter::Trace)
.install()
Expand All @@ -90,7 +90,8 @@ fn lavalink_rs(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(python::model::model))?;

let sys = PyModule::import(py, "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
let raw_modules = sys.getattr("modules")?;
let sys_modules: &Bound<'_, PyDict> = raw_modules.downcast()?;
sys_modules.set_item("lavalink_rs.model", m.getattr("model")?)?;

Ok(())
Expand Down
Loading

0 comments on commit 0f9a7bf

Please sign in to comment.