Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for SetTransformer++, ScheduleFree, Complexity Control. Improve stability, 2D Analysis, Small Fixes #19

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
96815df
Improve stability, 2D Analysis
psaegert Feb 19, 2025
75e95df
Merge branch 'main' into dev
psaegert Feb 19, 2025
d1054c8
Add SetTransformer++
psaegert Feb 19, 2025
366027c
Merge branch 'dev' of github.com:psaegert/flash-ansr into dev
psaegert Feb 19, 2025
7fab8db
Improve isfinite check
psaegert Feb 19, 2025
427689b
Clean up notebook, re-add force remove option
psaegert Feb 19, 2025
69c19b3
Merge branch 'dev' of github.com:psaegert/flash-ansr into dev
psaegert Feb 19, 2025
63e95ae
Fix v7.11 config
psaegert Feb 19, 2025
d795b79
Add support for schedulefree, add v11.0
psaegert Feb 19, 2025
e526d45
Remove amsgrad option from schedulefree AdamW
psaegert Feb 19, 2025
e746bac
Fix string type lr in config
psaegert Feb 19, 2025
db69826
Remove valid check for flash simplification
psaegert Feb 19, 2025
52159cc
Improve config loading for nested configs
psaegert Feb 21, 2025
da30f31
Implement preprocessing for batched generation
psaegert Feb 21, 2025
7ee7500
Add suppport for single instance preprocessing
psaegert Feb 21, 2025
411caa2
Add support for collating single data instances
psaegert Feb 22, 2025
8c20dbd
Add support for complexity control during training,
psaegert Feb 22, 2025
383e2de
Optimize copy behavior of expressions
psaegert Feb 22, 2025
99639f0
Fix saving configs recursively
psaegert Feb 24, 2025
947750e
Begin adding support for complexity in inference
psaegert Feb 24, 2025
1a450a5
Add support for complexity control in beam search
psaegert Feb 25, 2025
33b03c8
Remove profiling files
psaegert Feb 25, 2025
922a32d
Add table of contents, visual abstract, resutls to README.md
psaegert Feb 25, 2025
11c88fb
Replace markdown table with png
psaegert Feb 25, 2025
ae02dac
Improve README
psaegert Feb 25, 2025
02020ae
Fix s not capital in NeSymReS
psaegert Feb 25, 2025
3ac939e
Fix capital S in results.png
psaegert Feb 25, 2025
ac9ae79
Fiy typo in experimental/eval/simplify.py
psaegert Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

[![pytest](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml)
[![quality checks](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml)
[![CodeQL Advanced](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yml)
[![CodeQL Advanced](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml)

</div>

Expand Down
2 changes: 2 additions & 0 deletions configs/v7.11/dataset_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
skeleton_pool: './skeleton_pool_train.yaml'
padding: 'zero'
2 changes: 2 additions & 0 deletions configs/v7.11/dataset_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
skeleton_pool: './skeleton_pool_val.yaml'
padding: 'zero'
14 changes: 14 additions & 0 deletions configs/v7.11/evaluation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
n_support: 512
beam_width: 32
n_restarts: 8
max_len: 32
refiner_method: 'curve_fit_lm'
numeric_head: False
equivalence_pruning: True
pointwise_close_criterion: 0.95
pointwise_close_accuracy_rtol: 0.05
pointwise_close_accuracy_atol: 0.001
r2_close_criterion: 0.95
refiner_p0_noise: 'uniform'
refiner_p0_noise_kwargs: {'low': -5, 'high': 5}
device: cuda
256 changes: 256 additions & 0 deletions configs/v7.11/expression_space.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
operators:
"+":
realization: "+"
alias: ["add", "plus"]
inverse: "-"
arity: 2
weight: 10
precedence: 1
commutative: true
symmetry: 0
positive: false
monotonicity: 0
"-":
realization: "-"
alias: ["sub", "minus"]
inverse: "+"
arity: 2
weight: 5
precedence: 1
commutative: false
symmetry: 0
positive: false
monotonicity: 0
neg:
realization: "nsrops.neg"
alias: ["negative"]
inverse: "neg"
arity: 1
weight: 5
precedence: 2.5
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: -1
"*":
realization: "*"
alias: ["mul", "times"]
inverse: "/"
arity: 2
weight: 10
precedence: 2
commutative: true
symmetry: 0
positive: false
monotonicity: 0
"/":
realization: "/"
alias: ["div", "divide"]
inverse: "*"
arity: 2
weight: 5
precedence: 2
commutative: false
symmetry: 0
positive: false
monotonicity: 0
abs:
realization: "abs"
alias: ["absolute"]
inverse: null
arity: 1
weight: 4
precedence: 3
commutative: false
symmetry: 1 # symmetric
positive: true
monotonicity: 0
inv:
realization: "nsrops.inv"
alias: ["inverse"]
inverse: "inv"
arity: 1
weight: 4
precedence: 4
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: -1
pow2:
realization: "nsrops.pow2"
alias: ["square"]
inverse: null
arity: 1
weight: 4
precedence: 3
commutative: false
symmetry: 1 # symmetric
positive: true
monotonicity: 0
pow3:
realization: "nsrops.pow3"
alias: ["cube"]
inverse: "pow1_3"
arity: 1
weight: 2
precedence: 3
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 1
pow4:
realization: "nsrops.pow4"
alias: []
inverse: null
arity: 1
weight: 1
precedence: 3
commutative: false
symmetry: 1
positive: true
monotonicity: 0
pow5:
realization: "nsrops.pow5"
alias: []
inverse: "pow1_5"
arity: 1
weight: 1
precedence: 3
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 1
pow1_2:
realization: "nsrops.pow1_2"
alias: ["sqrt"]
inverse: null
arity: 1
weight: 4
precedence: 3
commutative: false
symmetry: 0
positive: true
monotonicity: 1
pow1_3:
realization: "nsrops.pow1_3"
alias: []
inverse: null
arity: 1
weight: 2
precedence: 3
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 1
pow1_4:
realization: "nsrops.pow1_4"
alias: []
inverse: null
arity: 1
weight: 1
precedence: 3
commutative: false
symmetry: 0
positive: true
monotonicity: 1
pow1_5:
realization: "nsrops.pow1_5"
alias: []
inverse: null
arity: 1
weight: 1
precedence: 3
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 1
sin:
realization: "numpy.sin"
alias: []
inverse: "asin"
arity: 1
weight: 4
precedence: 2
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 0
cos:
realization: "numpy.cos"
alias: []
inverse: "acos"
arity: 1
weight: 4
precedence: 2
commutative: false
symmetry: 1 # symmetric
positive: false
monotonicity: 0
tan:
realization: "numpy.tan"
alias: []
inverse: "atan"
arity: 1
weight: 4
precedence: 2
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 0
asin:
realization: "numpy.arcsin"
alias: ["arcsin"]
inverse: "sin"
arity: 1
weight: 2
precedence: 2
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 1
acos:
realization: "numpy.arccos"
alias: ["arccos"]
inverse: "cos"
arity: 1
weight: 2
precedence: 2
commutative: false
symmetry: 0
positive: true
monotonicity: 1
atan:
realization: "numpy.arctan"
alias: ["arctan"]
inverse: "tan"
arity: 1
weight: 2
precedence: 2
commutative: false
symmetry: -1 # anti-symmetric
positive: false
monotonicity: 1
exp:
realization: "numpy.exp"
alias: []
inverse: "log"
arity: 1
weight: 4
precedence: 3
commutative: false
symmetry: 0
positive: true
monotonicity: 1
log:
realization: "numpy.log"
alias: ["ln"]
inverse: "exp"
arity: 1
weight: 4
precedence: 2
commutative: false
symmetry: 0
positive: false
monotonicity: 1

variables: 3
25 changes: 25 additions & 0 deletions configs/v7.11/nsr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
encoder_max_n_variables: 4 # includes the dependent variable
size: 512

pre_encoder_input_type: "ieee-754"
pre_encoder_support_nan: False

encoder: "SetTransformer"
encoder_kwargs:
hidden_size: 512
n_enc_isab: 5
n_dec_sab: 2
n_induce: 64
n_heads: 8
layer_norm: False
n_seeds: 64

decoder_n_heads: 8
decoder_ff_size: 512
decoder_dropout: 0.1
decoder_n_layers: 5

learnable_positional_embeddings: False
max_input_length: null

expression_space: './expression_space.yaml'
38 changes: 38 additions & 0 deletions configs/v7.11/skeleton_pool_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
expression_space: './expression_space.yaml'
holdout_pools: [
# "{{ROOT}}/data/ansr-data/v7.11/skeleton_pool_val/",
"{{ROOT}}/data/ansr-data/test_set/soose_nc/skeleton_pool/",
"{{ROOT}}/data/ansr-data/test_set/feynman/skeleton_pool/",
"{{ROOT}}/data/ansr-data/test_set/nguyen/skeleton_pool/",
"{{ROOT}}/data/ansr-data/test_set/pool_15/skeleton_pool/"
]

sample_strategy:
n_operator_distribution: "length_exponential"
min_operators: 0
max_operators: 15
power: 1
lambda: 1
max_length: 31
max_tries: 1
independent_dimensions: True

allow_nan: False
simplify: True

literal_prior: 'uniform'
literal_prior_kwargs:
low: -5
high: 5

support_prior: "uniform_intervals"
support_prior_kwargs:
low: -10
high: 10

n_support_prior: "uniform"
n_support_prior_kwargs:
low: 16
high: 512
min_value: 16
max_value: 512
37 changes: 37 additions & 0 deletions configs/v7.11/skeleton_pool_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
expression_space: './expression_space.yaml'
holdout_pools: [
"{{ROOT}}/data/ansr-data/test_set/soose_nc/skeleton_pool/",
"{{ROOT}}/data/ansr-data/test_set/feynman/skeleton_pool/",
"{{ROOT}}/data/ansr-data/test_set/nguyen/skeleton_pool/",
"{{ROOT}}/data/ansr-data/test_set/pool_15/skeleton_pool/"
]

sample_strategy:
n_operator_distribution: "length_exponential"
min_operators: 0
max_operators: 10
power: 0.8
lambda: 1
max_length: 21
max_tries: 1
independent_dimensions: True

allow_nan: False
simplify: True

literal_prior: 'uniform'
literal_prior_kwargs:
low: -5
high: 5

support_prior: "uniform_intervals"
support_prior_kwargs:
low: -10
high: 10

n_support_prior: "uniform"
n_support_prior_kwargs:
low: 16
high: 512
min_value: 16
max_value: 512
Loading