Skip to content

Commit 8c8c14e

Browse files
authored
Remove tests requiring genbmm (#97)
* update tests * style * update setup * update
1 parent c875d93 commit 8c8c14e

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

setup.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from setuptools import setup
22

3+
with open("README.md", "r", encoding="utf-8") as fh:
4+
long_description = fh.read()
5+
36
setup(
47
name="torch_struct",
58
version="0.5",
@@ -9,9 +12,12 @@
912
"torch_struct",
1013
"torch_struct.semirings",
1114
],
15+
long_description=long_description,
1216
package_data={"torch_struct": []},
13-
url="https://github.com/harvardnlp/pytorch_struct",
17+
long_description_content_type="text/markdown",
18+
url="https://github.com/harvardnlp/pytorch-struct",
1419
install_requires=["torch"],
1520
setup_requires=["pytest-runner"],
1621
tests_require=["pytest"],
22+
python_requires='>=3.6',
1723
)

torch_struct/semirings/checkpoint.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import torch
22

3+
has_genbmm = False
34
try:
45
import genbmm
56
from genbmm import BandedMatrix
7+
has_genbmm = True
68
except ImportError:
79
pass
810

@@ -52,7 +54,7 @@ def backward(ctx, grad_output):
5254
class _CheckpointSemiring(cls):
5355
@staticmethod
5456
def matmul(a, b):
55-
if isinstance(a, genbmm.BandedMatrix):
57+
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
5658
lu = a.lu + b.lu
5759
ld = a.ld + b.ld
5860
c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld)

0 commit comments

Comments
 (0)