Skip to content

Conversation

@QuantuMope
Copy link
Contributor

@QuantuMope QuantuMope commented May 12, 2025

This PR now warns users if they are using the offline data buffer without the Agent wrapper.

@QuantuMope QuantuMope requested a review from emailweixu May 12, 2025 21:14
action, action_distribution):

if isinstance(rollout_info, BasicRolloutInfo):
rollout_info = rollout_info.rl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be put outside of this function. The general principle is, the algorithm should always receive what it's supposed to receive. In this case, this means that the rollout_info passed in should already be SacInfo.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be put outside of this function. The general principle is, the algorithm should always receive what it's supposed to receive. In this case, this means that the rollout_info passed in should already be SacInfo.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Fixed.

if isinstance(rollout_info, BasicRolloutInfo):
rollout_info = rollout_info.rl
state: SacCriticState,
rollout_info: SacInfo | BasicRLInfo, action,
Copy link
Collaborator

@hnyu hnyu May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still should always be SacInfo? If it's BasicRLInfo, the algorithm will crash.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Offline buffer data is stored as BasicRLInfo which comprises of just (s,a,r) data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess, I could convert BasicRLInfo into SacInfo with some fields empty? Not sure which is a better design. Lmk which one you think is cleaner and I can change.

Copy link
Collaborator

@hnyu hnyu May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Offline buffer data is stored as BasicRLInfo which comprises of just (s,a,r) data.

If you look at SAC's train_step(), it will get access to rollout_info.repr. This means that SAC is currently incompatible with offline training.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using a frozen encoder so not training a repr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, it seems that repr is stored in BasicRolloutInfo. Not sure how this code was running then. I'll take a look

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's using elastic_namedtuple so any missing field returns (). Anyway, a little weird but it works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we should not include BasicRLInfo here as it could confuse the pure sac users. The better alternative might be comply with the Agent assumption and possibly extend it.

Copy link
Contributor Author

@QuantuMope QuantuMope May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed typehints. Also added a warning message advising users to use Agent. Before the code would simply crash due to interface conflict.

hnyu
hnyu previously approved these changes May 12, 2025
@QuantuMope QuantuMope changed the title [Bug Fix] Enable hybrid SAC training Warn if using offline data buffer without Agent wrapper May 13, 2025
logging.WARNING,
"Detected offline buffer training without Agent wrapper. "
"For best compatibility, it is advised to use the Agent wrapper.",
n=1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning won't work. When using Agent, we still get rollout_info as BasicRolloutInfo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will. When using Agent, it properly feeds the nested BasicRLInfo to this function instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words, there was never a bug with hybrid RL training, just that it was never meant to be used without using Agent.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will. When using Agent, it properly feeds the nested BasicRLInfo to this function instead.

you're right. I didn't know Agent overwrites this function itself.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants