Warn if using offline data buffer without Agent wrapper #1771
Conversation
alf/algorithms/sac_algorithm.py (Outdated)

```python
                       action, action_distribution):

    if isinstance(rollout_info, BasicRolloutInfo):
        rollout_info = rollout_info.rl
```
This should be put outside of this function. The general principle is, the algorithm should always receive what it's supposed to receive. In this case, this means that the rollout_info passed in should already be SacInfo.
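For example, roughly (a sketch with trimmed stand-in structures; the caller shown here is hypothetical, not the actual ALF code):

```python
from collections import namedtuple

# Trimmed stand-ins for the real ALF structures.
BasicRolloutInfo = namedtuple('BasicRolloutInfo', ['rl', 'rewards'])
SacInfo = namedtuple('SacInfo', ['action', 'reward'])

def _calc_critic_loss(rollout_info, action):
    # The algorithm receives what it is supposed to receive: a SacInfo.
    assert isinstance(rollout_info, SacInfo)
    ...

def train_step(rollout_info, action):
    # Unwrap at the call site, outside of the algorithm's function.
    if isinstance(rollout_info, BasicRolloutInfo):
        rollout_info = rollout_info.rl
    return _calc_critic_loss(rollout_info, action)
```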
+1
Thanks. Fixed.
alf/algorithms/sac_algorithm.py (Outdated)

```python
    if isinstance(rollout_info, BasicRolloutInfo):
        rollout_info = rollout_info.rl
    ...
    state: SacCriticState,
    rollout_info: SacInfo | BasicRLInfo, action,
```
Shouldn't this still always be `SacInfo`? If it's `BasicRLInfo`, the algorithm will crash.
Offline buffer data is stored as `BasicRLInfo`, which comprises just (s, a, r) data.
I guess I could convert `BasicRLInfo` into `SacInfo` with some fields empty? Not sure which is the better design. Let me know which one you think is cleaner and I can change it.
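If we went with the conversion, it might look roughly like this (field names are illustrative; the real `SacInfo` has a different field set):

```python
from collections import namedtuple

# Illustrative field sets, not the exact ALF definitions.
BasicRLInfo = namedtuple('BasicRLInfo', ['action', 'reward'])
SacInfo = namedtuple(
    'SacInfo', ['action', 'reward', 'actor', 'critic', 'alpha'])

def to_sac_info(basic):
    """Lift offline (s, a, r) data into a SacInfo, leaving the
    SAC-specific fields empty so downstream code can skip them."""
    return SacInfo(action=basic.action, reward=basic.reward,
                   actor=(), critic=(), alpha=())
```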
> Offline buffer data is stored as `BasicRLInfo`, which comprises just (s, a, r) data.
If you look at SAC's `train_step()`, it accesses `rollout_info.repr`. This means that SAC is currently incompatible with offline training.
I'm using a frozen encoder, so I'm not training a repr.
Wait, it seems that `repr` is stored in `BasicRolloutInfo`. Not sure how this code was running, then. I'll take a look.
It's using `elastic_namedtuple`, so any missing field returns `()`. Anyway, it's a little weird, but it works.
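Roughly this behavior, for illustration (not ALF's actual `elastic_namedtuple` implementation):

```python
class Elastic:
    """Record where any field that was never set reads as ()."""

    def __init__(self, **fields):
        self.__dict__.update(fields)

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails.
        return ()

info = Elastic(action=1.0, reward=0.5)
print(info.reward)  # 0.5
print(info.repr)    # () -- missing field, so accessing .repr doesn't crash
```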
Ideally, we should not include `BasicRLInfo` here, as it could confuse pure SAC users. The better alternative might be to comply with the `Agent` assumption and possibly extend it.
Removed the type hints. Also added a warning message advising users to use `Agent`. Before, the code would simply crash due to an interface conflict.
```python
    logging.WARNING,
    "Detected offline buffer training without Agent wrapper. "
    "For best compatibility, it is advised to use the Agent wrapper.",
    n=1)
```
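(The enclosing call is cut off in the diff above. As a self-contained reference, absl's `log_first_n` can emit such a warning a single time; whether that matches the actual helper used in this PR is an assumption.)

```python
from absl import logging

# Emit the warning only the first time this code path runs.
# Assumption: the PR may use a different ALF logging helper.
logging.log_first_n(
    logging.WARNING,
    "Detected offline buffer training without Agent wrapper. "
    "For best compatibility, it is advised to use the Agent wrapper.",
    1)
```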
This warning won't work. When using Agent, we still get rollout_info as BasicRolloutInfo.
It will. When using Agent, it properly feeds the nested BasicRLInfo to this function instead.
In other words, there was never a bug with hybrid RL training; it was just never meant to be used without `Agent`.
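A sketch of the dispatch being described (method shape is illustrative, not the actual `Agent` code):

```python
class Agent:
    """Wrapper that forwards only the nested RL info to its sub-algorithm."""

    def __init__(self, rl_algorithm):
        self._rl = rl_algorithm

    def train_step_offline(self, inputs, state, rollout_info):
        # Without Agent, SAC sees the whole BasicRolloutInfo and crashes
        # on the interface mismatch; Agent passes along rollout_info.rl.
        return self._rl.train_step_offline(inputs, state, rollout_info.rl)
```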
> It will. When using `Agent`, it properly feeds the nested `BasicRLInfo` to this function instead.

You're right. I didn't know `Agent` overrides this function itself.
This PR now warns users if they are using the offline data buffer without the `Agent` wrapper.