Creating variants of the _Data object to allow gradient tracking and avoid unnecessary casting to numpy #208

Open
mpvanderschelling opened this issue Oct 30, 2023 · 1 comment · Fixed by #211

@mpvanderschelling
Collaborator

The problem

At the moment, the ExperimentData object consists of input_data, output_data, jobs, and domain. These are all custom objects, and all of them except Domain are private:

  • domain: f3dasm.design.Domain (public!)
  • input_data: f3dasm._src.design._data._Data
  • output_data: f3dasm._src.design._data._Data
  • jobs: f3dasm._src.design._jobqueue._JobQueue

Focussing on input_data: any data (e.g. a pd.DataFrame, a numpy array, or a csv file) that is given to ExperimentData is converted to a _Data object. The _Data back-end is pandas, so internally the data is cast to something compatible with pandas data storage, that is, numpy.

For automatic differentiation tools this can be problematic, since the gradient needs to be 'tracked' through every operation. Any cast to numpy breaks that chain.
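To make the "broken chain" concrete, here is a toy sketch in plain Python. It is an analogy, not how autograd, JAX, or f3dasm work internally: the `Dual` class and both `square_*` functions are hypothetical names made up for this illustration. A value carries its derivative along, and a cast to a plain float keeps the value but drops the derivative, much like a cast to numpy would.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Toy forward-mode AD value: a number together with its derivative."""
    value: float
    deriv: float = 0.0

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(float(other))
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__

    def __float__(self):
        # Casting keeps the value but silently drops the derivative.
        return self.value


def square_tracked(x):
    # Operates on the tracked value directly, so the derivative survives.
    return x * x


def square_with_cast(x):
    # Stand-in for a data store that casts its input to a plain numeric type.
    stored = float(x)
    return Dual(stored) * Dual(stored)


x = Dual(3.0, 1.0)  # seed with dx/dx = 1
tracked = square_tracked(x)
cast = square_with_cast(x)
print("tracked:", tracked)
print("after cast:", cast)
```

Both results hold the same value, but only the tracked one still carries the derivative of x squared; after the cast the derivative is zero.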

In v1.4.3, we use autograd.numpy to track these gradients. For tensorflow optimizers, a conversion function provides a 'custom gradient' so that they still work despite the cast to numpy.

Additionally, optimized libraries incur overhead when converting back and forth between e.g. jax arrays and numpy arrays.

Proposal

Because the ExperimentData object depends only on _Data and not directly on a pandas DataFrame, we can create a variant of the _Data object for any underlying datatype (e.g. a dictionary of tensorflow tensors).
To do so, we need to implement all the methods of the _Data object for that particular datatype.
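A minimal sketch of what this could look like, assuming a much smaller interface than the real _Data object has. All names here (`AbstractData`, `RowData`, `ColumnData`, `last_row`, and the four methods) are hypothetical and only serve to show the pattern: one abstract interface, several storage variants, and calling code that never needs to know which variant it received.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class AbstractData(ABC):
    """Hypothetical minimal interface that ExperimentData would rely on."""

    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def names(self) -> List[str]: ...

    @abstractmethod
    def get_row(self, index: int) -> Dict[str, Any]: ...

    @abstractmethod
    def add_row(self, row: Dict[str, Any]) -> None: ...


class RowData(AbstractData):
    """Row-oriented variant backed by a list of dicts."""

    def __init__(self) -> None:
        self._rows: List[Dict[str, Any]] = []

    def __len__(self) -> int:
        return len(self._rows)

    def names(self) -> List[str]:
        return list(self._rows[0]) if self._rows else []

    def get_row(self, index: int) -> Dict[str, Any]:
        return dict(self._rows[index])

    def add_row(self, row: Dict[str, Any]) -> None:
        self._rows.append(dict(row))


class ColumnData(AbstractData):
    """Column-oriented variant backed by a dict of lists."""

    def __init__(self) -> None:
        self._columns: Dict[str, List[Any]] = {}

    def __len__(self) -> int:
        return len(next(iter(self._columns.values()), []))

    def names(self) -> List[str]:
        return list(self._columns)

    def get_row(self, index: int) -> Dict[str, Any]:
        return {name: col[index] for name, col in self._columns.items()}

    def add_row(self, row: Dict[str, Any]) -> None:
        if not self._columns:
            # The first row defines the column names.
            self._columns = {name: [] for name in row}
        for name, col in self._columns.items():
            col.append(row[name])  # values are stored as given, no casting


def last_row(data: AbstractData) -> Dict[str, Any]:
    """Code written against the interface works for every variant."""
    return data.get_row(len(data) - 1)


for backend in (RowData(), ColumnData()):
    backend.add_row({"x0": 0.1, "x1": 0.2})
    backend.add_row({"x0": 0.3, "x1": 0.4})
    print(type(backend).__name__, last_row(backend))
```

A jax- or tensorflow-backed variant would follow the same shape, replacing the inner lists with the library's own array type so that values never leave that library.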

Then, the user can choose upon creation of the ExperimentData object if they want to use the 'normal' backend (e.g. pandas/numpy) or any specialized backend (e.g. tensorflow, pytorch, jax).

This could also be inferred automatically when providing initial input_data.
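One possible way to do that inference, sketched under assumptions: the function `infer_backend`, the mapping `_BACKEND_BY_MODULE`, and the backend names are all hypothetical, and the exact module path of each library's array type may differ between versions. The idea is to look at the top-level module of the data's type, which avoids importing every optional library just to run an isinstance check.

```python
from typing import Any

# Hypothetical mapping from a type's top-level module to a backend name.
# jax arrays are assumed to be implemented in the jaxlib package.
_BACKEND_BY_MODULE = {
    "pandas": "pandas",
    "numpy": "numpy",
    "jax": "jax",
    "jaxlib": "jax",
    "torch": "torch",
    "tensorflow": "tensorflow",
}


def infer_backend(data: Any, default: str = "pandas") -> str:
    """Guess a backend name from the module that defines type(data)."""
    root_module = type(data).__module__.split(".")[0]
    return _BACKEND_BY_MODULE.get(root_module, default)


# A plain Python list is defined in 'builtins', so it gets the default.
print(infer_backend([[0.1, 0.2], [0.3, 0.4]]))
```

An explicit argument on ExperimentData could still override the inferred choice when the user wants a different backend than the one the input suggests.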

First steps

This issue will investigate whether this is feasible, starting with a _Data variant that works with a jax data format.

@mpvanderschelling mpvanderschelling self-assigned this Oct 30, 2023
@mpvanderschelling
Collaborator Author

@SNMS95, I created an issue that might be relevant for your application with f3dasm. Feel free to add anything here that could help address it!

@mpvanderschelling mpvanderschelling linked a pull request Oct 31, 2023 that will close this issue