Skip to content

Commit 78ce8b5

Browse files
committed
init
1 parent 583e2a1 commit 78ce8b5

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

.github/scripts/m1_script.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
22

3-
export BUILD_VERSION=0.4.0
3+
export TORCHRL_BUILD_VERSION=0.4.0
44

55
${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U

.github/workflows/wheels.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
run: |
3333
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
3434
python3 -mpip install wheel
35-
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
35+
TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
3636
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
3737
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
3838
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
@@ -72,7 +72,7 @@ jobs:
7272
shell: bash
7373
run: |
7474
python3 -mpip install wheel
75-
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
75+
TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
7676
- name: Upload wheel for the test-wheel job
7777
uses: actions/upload-artifact@v2
7878
with:

setup.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def get_version():
3232
version_txt = os.path.join(cwd, "version.txt")
3333
with open(version_txt, "r") as f:
3434
version = f.readline().strip()
35-
if os.getenv("BUILD_VERSION"):
36-
version = os.getenv("BUILD_VERSION")
35+
if os.getenv("TORCHRL_BUILD_VERSION"):
36+
version = os.getenv("TORCHRL_BUILD_VERSION")
3737
elif sha != "Unknown":
3838
version += "+" + sha[:7]
3939
return version
@@ -68,11 +68,13 @@ def write_version_file(version):
6868
f.write("git_version = {}\n".format(repr(sha)))
6969

7070

71-
def _get_pytorch_version(is_nightly):
71+
def _get_pytorch_version(is_nightly, is_local):
7272
# if "PYTORCH_VERSION" in os.environ:
7373
# return f"torch=={os.environ['PYTORCH_VERSION']}"
7474
if is_nightly:
7575
return "torch>=2.4.0.dev"
76+
elif is_local:
77+
return "torch"
7678
return "torch>=2.3.0"
7779

7880

@@ -178,10 +180,12 @@ def _main(argv):
178180
else:
179181
version = get_version()
180182
write_version_file(version)
183+
TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
181184
logging.info("Building wheel {}-{}".format(package_name, version))
182-
logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}")
185+
logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")
183186

184-
pytorch_package_dep = _get_pytorch_version(is_nightly)
187+
is_local = TORCHRL_BUILD_VERSION is None
188+
pytorch_package_dep = _get_pytorch_version(is_nightly, is_local)
185189
logging.info("-- PyTorch dependency:", pytorch_package_dep)
186190
# branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
187191
# tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])

0 commit comments

Comments
 (0)