Skip to content

Commit

Permalink
Optimize copy behavior of expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
psaegert committed Feb 22, 2025
1 parent 8c20dbd commit 383e2de
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 56 deletions.
182 changes: 182 additions & 0 deletions experimental/copy.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def numerify_special_constants(prefix_expression: list[str], inplace: bool = False, append: bool = False) -> list[str]:\n",
" # print(id(prefix_expression))\n",
"\n",
" if not inplace:\n",
" if not append:\n",
" modified_prefix_expression = prefix_expression.copy()\n",
" for i, token in enumerate(modified_prefix_expression):\n",
" if token == '2':\n",
" modified_prefix_expression[i] = '?'\n",
" else:\n",
" modified_prefix_expression = []\n",
" for i, token in enumerate(prefix_expression):\n",
" if token == '2':\n",
" modified_prefix_expression.append('?')\n",
" modified_prefix_expression.append(token)\n",
" else:\n",
" modified_prefix_expression = prefix_expression\n",
"\n",
" # print(id(modified_prefix_expression))\n",
"\n",
" modified_prefix_expression[0] = '?'\n",
" \n",
" # print(id(modified_prefix_expression))\n",
"\n",
" return modified_prefix_expression"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"139663546625344\n",
"139663546688832\n"
]
}
],
"source": [
"my_list = ['1', '2', '3']\n",
"print(id(my_list))\n",
"my_returned_list = numerify_special_constants(my_list, inplace=False)\n",
"print(id(my_returned_list))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"139663546619072\n",
"139663546584640\n"
]
}
],
"source": [
"my_list = ['1', '2', '3']\n",
"print(id(my_list))\n",
"my_returned_list = numerify_special_constants(my_list, inplace=False, append=True)\n",
"print(id(my_returned_list))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"139663546562240\n",
"139663546562240\n"
]
}
],
"source": [
"my_list = ['1', '2', '3']\n",
"print(id(my_list))\n",
"my_returned_list = numerify_special_constants(my_list, inplace=True)\n",
"print(id(my_returned_list))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"142 ns ± 3.21 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"numerify_special_constants(['1', '2', '3'], inplace=False)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"158 ns ± 2.51 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"numerify_special_constants(['1', '2', '3'], inplace=False, append=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"47.7 ns ± 0.747 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"numerify_special_constants(['1', '2', '3'], inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "flash-ansr",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion experimental/eval/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def sympy_simplify_wrapper(expression: list[str], ratio=None, debug=False):
expression = pool.expression_space.parse_expression(expression)
if debug: print(expression)

expression = numbers_to_num(expression)
expression = numbers_to_num(expression, inplcae=True)
if debug: print(expression)

return tuple(expression)
Expand Down
4 changes: 4 additions & 0 deletions src/flash_ansr/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,7 @@ def main(argv: str = None) -> None:
case _:
parser.print_help()
sys.exit(1)


if __name__ == '__main__':
main()
6 changes: 3 additions & 3 deletions src/flash_ansr/compat/convert_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def parse_data(self, test_set_df: pd.DataFrame, expression_space: ExpressionSpac

# Codify
prefix_expression_w_num = expression_space.operators_to_realizations(prefix_expression)
prefix_expression_w_constants, constants = num_to_constants(prefix_expression_w_num)
prefix_expression_w_constants, constants = num_to_constants(prefix_expression_w_num, inplace=True)
code_string = expression_space.prefix_to_infix(prefix_expression_w_constants, realization=True)
code = codify(code_string, expression_space.variables + constants)

Expand Down Expand Up @@ -157,7 +157,7 @@ def parse_data(self, test_set_df: pd.DataFrame, expression_space: ExpressionSpac

# Codify
prefix_expression_w_num = expression_space.operators_to_realizations(prefix_expression)
prefix_expression_w_constants, constants = num_to_constants(prefix_expression_w_num)
prefix_expression_w_constants, constants = num_to_constants(prefix_expression_w_num, inplace=True)
code_string = expression_space.prefix_to_infix(prefix_expression_w_constants, realization=True)
code = codify(code_string, expression_space.variables + constants)

Expand Down Expand Up @@ -220,7 +220,7 @@ def parse_data(self, test_set_df: pd.DataFrame, expression_space: ExpressionSpac

# Codify
prefix_expression_w_num = expression_space.operators_to_realizations(prefix_expression)
prefix_expression_w_constants, constants = num_to_constants(prefix_expression_w_num)
prefix_expression_w_constants, constants = num_to_constants(prefix_expression_w_num, inplace=True)
code_string = expression_space.prefix_to_infix(prefix_expression_w_constants, realization=True)
code = codify(code_string, expression_space.variables + constants)

Expand Down
2 changes: 1 addition & 1 deletion src/flash_ansr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def generate(
'n_rejected': [n_rejected],
'skeletons': skeleton,
'skeleton_hashes': skeleton_hash,
'expressions': substitude_constants(skeleton, values=literals),
'expressions': substitude_constants(skeleton, values=literals, inplace=True),
'constants': torch.tensor(literals, dtype=torch.float32),
'input_ids': input_ids,
'x_tensors': torch.tensor(x_support, dtype=torch.float32),
Expand Down
Loading

0 comments on commit 383e2de

Please sign in to comment.