
Support for manually modifying client/server learning rate #258

Open
marcociccone opened this issue Mar 1, 2022 · 1 comment

@marcociccone

Hi,
I'm playing around with the clients' learning rate, but I can't find a clean way of modifying it.

Basically, I need to change the learning rate according to a schedule based on the current round.
Is that possible?

Thanks

@jaehunro (Collaborator) commented Mar 7, 2022

Hi,

Thanks for trying out FedJAX and filing this issue!

I think this is mainly optax functionality, since FedJAX optimizers are thin wrappers around existing optax optimizers.

If you want a different client learning rate for each round of federated averaging, here are a few suggestions:

  1. Pass the round number as part of the server_state in the federated algorithm, and create a new client optimizer with a learning rate based on it at each round of federated training. (Caveat: potentially slow, since this effectively recompiles the client optimizer's apply function each round.) See the first sketch below.
  2. Optax has built-in support for learning rate decay, e.g., optax.exponential_decay. Some of these schedules can be used directly or tweaked to support your use case. See the second sketch below.
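Here is a minimal sketch of the first suggestion. The client optimizer is built directly from optax, and the `client_learning_rate` helper, the schedule values, and the loop structure are hypothetical stand-ins for your actual federated training loop:

```python
import optax

def client_learning_rate(round_num: int) -> float:
  # Hypothetical per-round schedule: halve the learning rate
  # every 100 rounds (values chosen only for illustration).
  return 0.1 * (0.5 ** (round_num // 100))

num_rounds = 1000  # assumed total number of federated rounds

for round_num in range(num_rounds):
  # Rebuild the client optimizer each round with a learning rate
  # derived from the round number carried in server_state.
  # Caveat: recreating the optimizer can effectively recompile its
  # apply function every round, which may be slow.
  client_optimizer = optax.sgd(
      learning_rate=client_learning_rate(round_num))
  # ... run one round of federated averaging with client_optimizer ...
```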
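And a sketch of the second suggestion. One nuance to be aware of: optax schedules advance per optimizer update step, not per federated round, so the schedule may need to be offset or tweaked to get round-level granularity. The hyperparameter values here are placeholders:

```python
import optax

# Exponentially decay the learning rate:
# lr(step) = 0.1 * 0.99 ** (step / 100).
schedule = optax.exponential_decay(
    init_value=0.1,        # placeholder initial learning rate
    transition_steps=100,  # placeholder decay period, in optimizer steps
    decay_rate=0.99)

# Any optax optimizer accepts a schedule wherever a fixed learning
# rate float is expected; the schedule is evaluated at every step.
client_optimizer = optax.sgd(learning_rate=schedule)
```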

Some potentially helpful links:

If you have more concrete examples of what you're trying to do, we'd be happy to look into it!
