Feat Sebulba recurrent IQL #1148

Open · Louay-Ben-nessir wants to merge 13 commits into base: develop

Conversation

Louay-Ben-nessir (Contributor)

What?

A recurrent IQL implementation using the Sebulba architecture.

Why?

An offline Sebulba base and support for non-JAX envs in Mava.

How?

Mixed the Sebulba structure from PPO with the learner code from Anakin IQL.

@sash-a (Contributor) left a comment:

I've looked through everything except the system file and it looks good, Sebulba utils especially! Just some relatively minor style changes

Resolved review threads: mava/configs/system/q_learning/rec_iql.yaml, mava/systems/q_learning/types.py, mava/utils/config.py

# todo: remove the ppo dependencies when we make sebulba for other systems
Contributor:

This is a good point though, maybe there's something we can do about it 🤔

Maybe a protocol that has action, obs, and reward? Not sure if there are any other common attributes.
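
For illustration, such a shared protocol could look roughly like the sketch below (the TransitionLike name and exact fields are assumptions, not part of this PR):

```python
from typing import Protocol

from chex import Array


class TransitionLike(Protocol):
    """Minimal interface shared by the PPO and Q-learning transition types."""

    obs: Array
    action: Array
    reward: Array
```

The Sebulba utils could then be annotated against TransitionLike instead of importing the PPO types directly.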

Resolved review threads: mava/utils/sebulba.py (5 threads)
Comment on lines +275 to +277
terminated = np.repeat(
terminated[..., np.newaxis], repeats=self.num_agents, axis=-1
) # (B,) --> (B, N)
Contributor:

Does this already happen for smax and lbf?

@sash-a (Contributor) left a comment:

Great work here! Really minor changes required. Happy to merge this pending some benchmarks

Review threads: mava/systems/q_learning/sebulba/rec_iql.py (2 threads)
target: Array,
) -> Tuple[Array, Metrics]:
# axes switched here to scan over time
hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc)
Contributor:

A general comment: I think this would be a lot easier to read if we used done to mean term_or_trunc, which I think is a reasonable change. We would have to make the change in Anakin as well though :/
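
For context on the snippet above, the axis switch just moves batch-major inputs (B, T, ...) to time-major (T, B, ...) so the RNN can scan over the time axis. A minimal sketch of that step (not Mava's actual helper):

```python
import jax
import jax.numpy as jnp


def switch_leading_axes(tree):
    # Swap (B, T, ...) -> (T, B, ...) on every array in the pytree so that
    # jax.lax.scan (and hence a ScannedRNN) iterates over the time dimension.
    return jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), tree)
```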

Review thread: mava/systems/q_learning/sebulba/rec_iql.py
"""

eps = jnp.maximum(
config.system.eps_min, 1 - (t / config.system.eps_decay) * (1 - config.system.eps_min)
Contributor:

Would be nice if we could set a different decay per actor, although I think that's out of scope for this PR. Maybe if you could make an issue to add some of the Ape-X DQN features, that would be great.

Contributor (Author):

I can easily add this in this PR 👀

Contributor:

I think rather leave it for now, no need to make this more complex
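
For reference, the per-actor exploration schedule from Ape-X DQN mentioned above assigns each actor a fixed epsilon; a rough sketch (the base and alpha values follow the Ape-X paper, the function name is made up):

```python
import jax.numpy as jnp


def apex_actor_epsilons(num_actors: int, base_eps: float = 0.4, alpha: float = 7.0) -> jnp.ndarray:
    # Actor i gets eps_i = base_eps ** (1 + alpha * i / (N - 1)), spreading the
    # fleet from highly exploratory (i = 0 gives 0.4) to nearly greedy actors.
    i = jnp.arange(num_actors)
    return base_eps ** (1 + alpha * i / jnp.maximum(num_actors - 1, 1))
```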

Resolved review threads: mava/systems/q_learning/sebulba/rec_iql.py (4 threads), mava/configs/system/q_learning/rec_iql.yaml
@SimonDuToit (Contributor) left a comment:

Great work Louay! Just two questions from my side as you can see in the comments.

- rewards = np.zeros((num_envs, num_agents), dtype=float)
- teminated = np.zeros(num_envs, dtype=float)
+ rewards = np.zeros((num_envs, self.num_agents), dtype=float)
+ terminated = np.zeros((num_envs, self.num_agents), dtype=float)
Contributor:

I assume with this change we would also need to change Sebulba PPO, since it currently does this same operation? We should decide whether it's generally better to do this in the system or in the wrapper.

Contributor:

Good point, I think we do this in the wrappers for the Anakin systems.
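
For illustration, the wrapper-side version of this broadcast could look roughly like the following (class and method names are hypothetical, not the PR's code):

```python
import numpy as np


class PerAgentDoneWrapper:
    """Hypothetical vector-env wrapper that tiles the env-level done flag per agent."""

    def __init__(self, env, num_agents: int):
        self._env = env
        self.num_agents = num_agents

    def step(self, actions):
        obs, reward, terminated, truncated, info = self._env.step(actions)
        # (B,) -> (B, N): every agent in an env shares that env's termination flag.
        terminated = np.repeat(terminated[..., np.newaxis], self.num_agents, axis=-1)
        return obs, reward, terminated, truncated, info
```

Doing it in the wrapper keeps the system code shape-agnostic, at the cost of every env wrapper having to agree on the (B, N) convention.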

Resolved review thread: mava/utils/sebulba/pipelines.py
@sash-a (Contributor) left a comment:

Looks great to me, just a few minor things 🙏

Review threads: mava/utils/sebulba/pipelines.py (5 threads), mava/systems/q_learning/sebulba/rec_iql.py (5 threads)
@sash-a (Contributor) left a comment:

🔥

3 participants