diff --git a/.gitignore b/.gitignore index 0ed0c8d..25690cf 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ /src/flash_ansr/models/encoders/deep_sets2.py /src/flash_ansr/models/encoders/point_net.py /src/flash_ansr/models/encoders/point_net2.py -/src/flash_ansr/models/encoders/set_transformer2.py # Models & Data /models @@ -16,6 +15,9 @@ cloud.md # Wandb wandb +# Profiling +.prof + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index f43d8f1..2344af8 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,55 @@ -
FlashANSR(beam_width=32,\n", + " expression_space=<flash_ansr.expressions.expression_space.ExpressionSpace object at 0x7fe04f2140d0>,\n", + " flash_ansr_transformer=FlashANSRTransformer(\n", + " (pre_encoder): PreEncoder()\n", + " (encoder): SetTransformer(\n", + " (enc): Sequential(\n", + " (0): ISAB(\n", + " (mab0): MAB(\n", + " (W_q): Linear(in_features=512, out_features=512, bias=True)\n", + " (W_k): Linear(in_features=64, out_features=...\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): GELU(approximate='none')\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=33, bias=True)\n", + " )\n", + " (num_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): GELU(approximate='none')\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=1, bias=True)\n", + " )\n", + "),\n", + " n_restarts=8, verbose=True)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
FlashANSR(beam_width=32,\n", + " expression_space=<flash_ansr.expressions.expression_space.ExpressionSpace object at 0x7fe04f2140d0>,\n", + " flash_ansr_transformer=FlashANSRTransformer(\n", + " (pre_encoder): PreEncoder()\n", + " (encoder): SetTransformer(\n", + " (enc): Sequential(\n", + " (0): ISAB(\n", + " (mab0): MAB(\n", + " (W_q): Linear(in_features=512, out_features=512, bias=True)\n", + " (W_k): Linear(in_features=64, out_features=...\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): GELU(approximate='none')\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=33, bias=True)\n", + " )\n", + " (num_out): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): GELU(approximate='none')\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=1, bias=True)\n", + " )\n", + "),\n", + " n_restarts=8, verbose=True)
\n", + " | log_prob | \n", + "fvu | \n", + "score | \n", + "expression | \n", + "complexity | \n", + "target_complexity | \n", + "numeric_prediction | \n", + "beam | \n", + "function | \n", + "refiner | \n", + "beam_id | \n", + "fit_constants | \n", + "fit_covariances | \n", + "fit_loss | \n", + "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", + "-1.060831 | \n", + "1.152357e-15 | \n", + "0.0007 | \n", + "[+, x1, +, x2, *, <num>, x3] | \n", + "7 | \n", + "None | \n", + "None | \n", + "[1, 7, 30, 7, 31, 10, 6, 32, 2] | \n", + "<function <lambda> at 0x7fe04ea5b600> | \n", + "Refiner(expression=['+', 'x1', '+', 'x2', '*',... | \n", + "0 | \n", + "[2.9999999960488646] | \n", + "[[1.3347598938741106e-16]] | \n", + "4.053961e-13 | \n", + "
1 | \n", + "-1.060831 | \n", + "1.152357e-15 | \n", + "0.0007 | \n", + "[+, x1, +, x2, *, <num>, x3] | \n", + "7 | \n", + "None | \n", + "None | \n", + "[1, 7, 30, 7, 31, 10, 6, 32, 2] | \n", + "<function <lambda> at 0x7fe04ea5b600> | \n", + "Refiner(expression=['+', 'x1', '+', 'x2', '*',... | \n", + "0 | \n", + "[2.9999999960488646] | \n", + "[[1.33475989113999e-16]] | \n", + "4.053961e-13 | \n", + "
2 | \n", + "-1.060831 | \n", + "1.152357e-15 | \n", + "0.0007 | \n", + "[+, x1, +, x2, *, <num>, x3] | \n", + "7 | \n", + "None | \n", + "None | \n", + "[1, 7, 30, 7, 31, 10, 6, 32, 2] | \n", + "<function <lambda> at 0x7fe04ea5b600> | \n", + "Refiner(expression=['+', 'x1', '+', 'x2', '*',... | \n", + "0 | \n", + "[2.9999999960488646] | \n", + "[[1.3347598871570362e-16]] | \n", + "4.053961e-13 | \n", + "
3 | \n", + "-1.060831 | \n", + "1.152357e-15 | \n", + "0.0007 | \n", + "[+, x1, +, x2, *, <num>, x3] | \n", + "7 | \n", + "None | \n", + "None | \n", + "[1, 7, 30, 7, 31, 10, 6, 32, 2] | \n", + "<function <lambda> at 0x7fe04ea5b600> | \n", + "Refiner(expression=['+', 'x1', '+', 'x2', '*',... | \n", + "0 | \n", + "[2.9999999960488646] | \n", + "[[1.3347598928734102e-16]] | \n", + "4.053961e-13 | \n", + "
4 | \n", + "-1.060831 | \n", + "1.152357e-15 | \n", + "0.0007 | \n", + "[+, x1, +, x2, *, <num>, x3] | \n", + "7 | \n", + "None | \n", + "None | \n", + "[1, 7, 30, 7, 31, 10, 6, 32, 2] | \n", + "<function <lambda> at 0x7fe04ea5b600> | \n", + "Refiner(expression=['+', 'x1', '+', 'x2', '*',... | \n", + "0 | \n", + "[2.9999999960488646] | \n", + "[[1.3347599027477214e-16]] | \n", + "4.053961e-13 | \n", + "
... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
237 | \n", + "-6.061850 | \n", + "3.152573e-15 | \n", + "0.0013 | \n", + "[+, <num>, +, <num>, +, x1, +, x2, +, x3, +, x... | \n", + "13 | \n", + "None | \n", + "None | \n", + "[1, 7, 6, 7, 6, 7, 30, 7, 31, 7, 32, 7, 32, 32... | \n", + "<function <lambda> at 0x7fe04ea58c20> | \n", + "Refiner(expression=['+', '<num>', '+', '<num>'... | \n", + "30 | \n", + "[-0.47078499131995105, 0.47078492063308375] | \n", + "[[0.07722945023551932, -0.07722946672644812], ... | \n", + "1.109068e-12 | \n", + "
238 | \n", + "-6.061850 | \n", + "3.152573e-15 | \n", + "0.0013 | \n", + "[+, <num>, +, <num>, +, x1, +, x2, +, x3, +, x... | \n", + "13 | \n", + "None | \n", + "None | \n", + "[1, 7, 6, 7, 6, 7, 30, 7, 31, 7, 32, 7, 32, 32... | \n", + "<function <lambda> at 0x7fe04ea58c20> | \n", + "Refiner(expression=['+', '<num>', '+', '<num>'... | \n", + "30 | \n", + "[0.5407137010650169, -0.5407137717518034] | \n", + "[[1.468625096212761, -1.4686252416604124], [-1... | \n", + "1.109068e-12 | \n", + "
239 | \n", + "-6.061850 | \n", + "3.152573e-15 | \n", + "0.0013 | \n", + "[+, <num>, +, <num>, +, x1, +, x2, +, x3, +, x... | \n", + "13 | \n", + "None | \n", + "None | \n", + "[1, 7, 6, 7, 6, 7, 30, 7, 31, 7, 32, 7, 32, 32... | \n", + "<function <lambda> at 0x7fe04ea58c20> | \n", + "Refiner(expression=['+', '<num>', '+', '<num>'... | \n", + "30 | \n", + "[0.3074403724255982, -0.30744044311224167] | \n", + "[[0.01270703826740337, -0.012707037074853388],... | \n", + "1.109068e-12 | \n", + "
240 | \n", + "-6.061850 | \n", + "3.152573e-15 | \n", + "0.0013 | \n", + "[+, <num>, +, <num>, +, x1, +, x2, +, x3, +, x... | \n", + "13 | \n", + "None | \n", + "None | \n", + "[1, 7, 6, 7, 6, 7, 30, 7, 31, 7, 32, 7, 32, 32... | \n", + "<function <lambda> at 0x7fe04ea58c20> | \n", + "Refiner(expression=['+', '<num>', '+', '<num>'... | \n", + "30 | \n", + "[0.31953753768780313, -0.3195376083731668] | \n", + "[[0.060332209908029115, -0.06033219583661418],... | \n", + "1.109068e-12 | \n", + "
241 | \n", + "-6.257017 | \n", + "6.773130e-15 | \n", + "0.0013 | \n", + "[+, x1, +, x1, +, x2, +, x3, +, x3, -, x3, x1] | \n", + "13 | \n", + "None | \n", + "None | \n", + "[1, 7, 30, 7, 30, 7, 31, 7, 32, 7, 32, 8, 32, ... | \n", + "<function <lambda> at 0x7fe04ea5b2e0> | \n", + "Refiner(expression=['+', 'x1', '+', 'x1', '+',... | \n", + "31 | \n", + "[] | \n", + "[] | \n", + "2.382770e-12 | \n", + "
242 rows × 14 columns
\n", + "