76
76
float32 : [0, 1]
77
77
install-numba : [0]
78
78
install-jax : [0]
79
+ install-torch : [0]
79
80
part :
80
81
- " tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
81
82
- " tests/scan"
@@ -116,6 +117,11 @@ jobs:
116
117
fast-compile : 0
117
118
float32 : 0
118
119
part : " tests/link/jax"
120
+ - install-torch : 1
121
+ python-version : " 3.10"
122
+ fast-compile : 0
123
+ float32 : 0
124
+ part : " tests/link/pytorch"
119
125
steps :
120
126
- uses : actions/checkout@v4
121
127
with :
@@ -142,9 +148,12 @@ jobs:
142
148
- name : Install dependencies
143
149
shell : micromamba-shell {0}
144
150
run : |
151
+
145
152
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
146
153
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
147
154
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
155
+ if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
156
+
148
157
pip install -e ./
149
158
micromamba list && pip freeze
150
159
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -153,6 +162,7 @@ jobs:
153
162
PYTHON_VERSION : ${{ matrix.python-version }}
154
163
INSTALL_NUMBA : ${{ matrix.install-numba }}
155
164
INSTALL_JAX : ${{ matrix.install-jax }}
165
+ INSTALL_TORCH : ${{ matrix.install-torch}}
156
166
157
167
- name : Run tests
158
168
shell : micromamba-shell {0}
@@ -199,7 +209,7 @@ jobs:
199
209
- name : Install dependencies
200
210
shell : micromamba-shell {0}
201
211
run : |
202
- micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
212
+ micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
203
213
pip install -e ./
204
214
micromamba list && pip freeze
205
215
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -268,3 +278,4 @@ jobs:
268
278
directory : ./coverage/
269
279
fail_ci_if_error : true
270
280
token : ${{ secrets.CODECOV_TOKEN }}
281
+
0 commit comments