Skip to content

Commit 3531940

Browse files
authored
Feature/support existing environments (#1891)
* Add option to ModelParameters to use an exisiting python environment * Don't delete environment in tests * Format code with black formatter * Update user guide * Expand path before creating the environment * Undo catching exception * Add unittests for existing environments * Format code
1 parent af365e1 commit 3531940

File tree

8 files changed

+135
-19
lines changed

8 files changed

+135
-19
lines changed

docs/user-guide/custom.md

+17-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ In these cases, to load your custom runtime, MLServer will need access to these
215215
dependencies.
216216

217217
It is possible to load this custom set of dependencies by providing them
218-
through an [environment tarball](../examples/conda/README), whose path can be
218+
through an [environment tarball](../examples/conda/README) or by giving a
219+
path to an already exisiting python environment. Both paths can be
219220
specified within your `model-settings.json` file.
220221

221222
```{warning}
@@ -277,6 +278,21 @@ Note that, in the folder layout above, we are assuming that:
277278
}
278279
```
279280

281+
If you want to use an already exisiting python environment, you can use the parameter `environment_path` of your `model-settings.json`:
282+
283+
```
284+
---
285+
emphasize-lines: 5
286+
---
287+
{
288+
"model": "sum-model",
289+
"implementation": "models.MyCustomRuntime",
290+
"parameters": {
291+
"environment_path": "~/micromambda/envs/my-conda-environment"
292+
}
293+
}
294+
```
295+
280296
## Building a custom MLServer image
281297

282298
```{note}

mlserver/env.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import multiprocessing
33
import os
4+
import shutil
45
import sys
56
import tarfile
67
import glob
@@ -18,7 +19,7 @@ def _extract_env(tarball_path: str, env_path: str) -> None:
1819
tarball.extractall(path=env_path)
1920

2021

21-
def _compute_hash(tarball_path: str) -> str:
22+
def _compute_hash_of_file(tarball_path: str) -> str:
2223
"""
2324
From Python 3.11's implementation of `hashlib.file_digest()`:
2425
https://github.com/python/cpython/blob/3.11/Lib/hashlib.py#L257
@@ -39,9 +40,20 @@ def _compute_hash(tarball_path: str) -> str:
3940
return h.hexdigest()
4041

4142

42-
async def compute_hash(tarball_path: str) -> str:
43+
def _compute_hash_of_string(string: str) -> str:
44+
h = hashlib.sha256()
45+
h.update(string.encode())
46+
return h.hexdigest()
47+
48+
49+
async def compute_hash_of_file(tarball_path: str) -> str:
4350
loop = asyncio.get_running_loop()
44-
return await loop.run_in_executor(None, _compute_hash, tarball_path)
51+
return await loop.run_in_executor(None, _compute_hash_of_file, tarball_path)
52+
53+
54+
async def compute_hash_of_string(string: str) -> str:
55+
loop = asyncio.get_running_loop()
56+
return await loop.run_in_executor(None, _compute_hash_of_string, string)
4557

4658

4759
class Environment:
@@ -51,7 +63,8 @@ class Environment:
5163
environment.
5264
"""
5365

54-
def __init__(self, env_path: str, env_hash: str):
66+
def __init__(self, env_path: str, env_hash: str, delete_env: bool = True):
67+
self._delete_env = delete_env
5568
self._env_path = env_path
5669
self.env_hash = env_hash
5770

@@ -67,7 +80,7 @@ async def from_tarball(
6780
await loop.run_in_executor(None, _extract_env, tarball_path, env_path)
6881

6982
if not env_hash:
70-
env_hash = await compute_hash(tarball_path)
83+
env_hash = await compute_hash_of_file(tarball_path)
7184

7285
return cls(env_path, env_hash)
7386

@@ -136,3 +149,8 @@ def __exit__(self, *exc_details) -> None:
136149
multiprocessing.set_executable(sys.executable)
137150
sys.path = self._prev_sys_path
138151
os.environ["PATH"] = self._prev_bin_path
152+
153+
def __del__(self) -> None:
154+
logger.info("Cleaning up environment")
155+
if self._delete_env:
156+
shutil.rmtree(self._env_path)

mlserver/parallel/registry.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import os
3-
import shutil
43
import signal
54

65
from typing import Optional, Dict, List
@@ -9,7 +8,7 @@
98
from ..utils import to_absolute_path
109
from ..model import MLModel
1110
from ..settings import Settings
12-
from ..env import Environment, compute_hash
11+
from ..env import Environment, compute_hash_of_file, compute_hash_of_string
1312
from ..registry import model_initialiser
1413

1514
from .errors import EnvironmentNotFound
@@ -76,11 +75,52 @@ async def _handle_worker_stop(self, signum, frame):
7675
)
7776

7877
async def _get_or_create(self, model: MLModel) -> InferencePool:
78+
if (
79+
model.settings.parameters is not None
80+
and model.settings.parameters.environment_path
81+
):
82+
pool = await self._get_or_create_with_existing_env(
83+
model.settings.parameters.environment_path
84+
)
85+
else:
86+
pool = await self._get_or_create_with_tarball(model)
87+
return pool
88+
89+
async def _get_or_create_with_existing_env(
90+
self, environment_path: str
91+
) -> InferencePool:
92+
"""
93+
Creates or returns the InferencePool for a model that uses an existing
94+
python environment.
95+
"""
96+
expanded_environment_path = os.path.abspath(
97+
os.path.expanduser(os.path.expandvars(environment_path))
98+
)
99+
logger.info(f"Using environment {expanded_environment_path}")
100+
env_hash = await compute_hash_of_string(expanded_environment_path)
101+
if env_hash in self._pools:
102+
return self._pools[env_hash]
103+
env = Environment(
104+
env_path=expanded_environment_path,
105+
env_hash=env_hash,
106+
delete_env=False,
107+
)
108+
pool = InferencePool(
109+
self._settings, env=env, on_worker_stop=self._on_worker_stop
110+
)
111+
self._pools[env_hash] = pool
112+
return pool
113+
114+
async def _get_or_create_with_tarball(self, model: MLModel) -> InferencePool:
115+
"""
116+
Creates or returns the InferencePool for a model that uses a
117+
tarball as python environment.
118+
"""
79119
env_tarball = _get_env_tarball(model)
80120
if not env_tarball:
81121
return self._default_pool
82122

83-
env_hash = await compute_hash(env_tarball)
123+
env_hash = await compute_hash_of_file(env_tarball)
84124
if env_hash in self._pools:
85125
return self._pools[env_hash]
86126

@@ -223,5 +263,3 @@ async def _close_pool(self, env_hash: Optional[str] = None):
223263

224264
if env_hash:
225265
del self._pools[env_hash]
226-
env_path = self._get_env_path(env_hash)
227-
shutil.rmtree(env_path)

mlserver/settings.py

+4
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,10 @@ class ModelParameters(BaseSettings):
305305
version: Optional[str] = None
306306
"""Version of the model."""
307307

308+
environment_path: Optional[str] = None
309+
"""Path to a directory that contains the python environment to be used
310+
to load this model."""
311+
308312
environment_tarball: Optional[str] = None
309313
"""Path to the environment tarball which should be used to load this
310314
model."""

tests/conftest.py

-4
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,6 @@ async def env(env_tarball: str, tmp_path: str) -> Environment:
9797
env = await Environment.from_tarball(env_tarball, str(tmp_path))
9898
yield env
9999

100-
# Envs can be quite heavy, so let's make sure we're clearing them up once
101-
# the test finishes
102-
shutil.rmtree(tmp_path)
103-
104100

105101
@pytest.fixture(autouse=True)
106102
def logger(settings: Settings):

tests/parallel/conftest.py

+15
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,21 @@ def env_model_settings(env_tarball: str) -> ModelSettings:
161161
)
162162

163163

164+
@pytest.fixture
165+
def existing_env_model_settings(env_tarball: str, tmp_path) -> ModelSettings:
166+
from mlserver.env import _extract_env
167+
168+
env_path = str(tmp_path)
169+
170+
_extract_env(env_tarball, env_path)
171+
model_settings = ModelSettings(
172+
name="exising_env_model",
173+
implementation=EnvModel,
174+
parameters=ModelParameters(environment_path=env_path),
175+
)
176+
yield model_settings
177+
178+
164179
@pytest.fixture
165180
async def worker_with_env(
166181
settings: Settings,

tests/parallel/test_registry.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import asyncio
44

5-
from mlserver.env import Environment, compute_hash
5+
from mlserver.env import Environment, compute_hash_of_file
66
from mlserver.model import MLModel
77
from mlserver.settings import Settings, ModelSettings
88
from mlserver.types import InferenceRequest
@@ -30,6 +30,19 @@ async def env_model(
3030
await inference_pool_registry.unload_model(model)
3131

3232

33+
@pytest.fixture
34+
async def existing_env_model(
35+
inference_pool_registry: InferencePoolRegistry,
36+
existing_env_model_settings: ModelSettings,
37+
) -> MLModel:
38+
env_model = EnvModel(existing_env_model_settings)
39+
model = await inference_pool_registry.load_model(env_model)
40+
41+
yield model
42+
43+
await inference_pool_registry.unload_model(model)
44+
45+
3346
def test_set_environment_hash(sum_model: MLModel):
3447
env_hash = "0e46fce1decb7a89a8b91c71d8b6975630a17224d4f00094e02e1a732f8e95f3"
3548
_set_environment_hash(sum_model, env_hash)
@@ -90,6 +103,22 @@ async def test_load_model_with_env(
90103
assert sklearn_version == "1.0.2"
91104

92105

106+
async def test_load_model_with_existing_env(
107+
inference_pool_registry: InferencePoolRegistry,
108+
existing_env_model: MLModel,
109+
inference_request: InferenceRequest,
110+
):
111+
response = await existing_env_model.predict(inference_request)
112+
113+
assert len(response.outputs) == 1
114+
115+
# Note: These versions come from the `environment.yml` found in
116+
# `./tests/testdata/environment.yaml`
117+
assert response.outputs[0].name == "sklearn_version"
118+
[sklearn_version] = StringCodec.decode_output(response.outputs[0])
119+
assert sklearn_version == "1.0.2"
120+
121+
93122
async def test_load_creates_pool(
94123
inference_pool_registry: InferencePoolRegistry,
95124
env_model_settings: MLModel,
@@ -124,7 +153,7 @@ async def test_load_reuses_env_folder(
124153
new_model = EnvModel(env_model_settings)
125154

126155
# Make sure there's already existing env
127-
env_hash = await compute_hash(env_tarball)
156+
env_hash = await compute_hash_of_file(env_tarball)
128157
env_path = inference_pool_registry._get_env_path(env_hash)
129158
await Environment.from_tarball(env_tarball, env_path, env_hash)
130159

tests/test_env.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from typing import Tuple
77

8-
from mlserver.env import Environment, compute_hash
8+
from mlserver.env import Environment, compute_hash_of_file
99

1010

1111
@pytest.fixture
@@ -15,7 +15,7 @@ def expected_python_folder(env_python_version: Tuple[int, int]) -> str:
1515

1616

1717
async def test_compute_hash(env_tarball: str):
18-
env_hash = await compute_hash(env_tarball)
18+
env_hash = await compute_hash_of_file(env_tarball)
1919
assert len(env_hash) == 64
2020

2121

0 commit comments

Comments
 (0)