-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 03bbcca
Showing
9 changed files
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
FROM imaging-server-kit:3.9 | ||
|
||
COPY . . | ||
|
||
RUN python -m pip install -r requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
![EPFL Center for Imaging logo](https://imaging.epfl.ch/resources/logo-for-gitlab.svg) | ||
# Rembg API in docker | ||
|
||
Implementation of a web API server for [rembg](https://github.com/danielgatis/rembg). | ||
|
||
Author: EPFL Center for Imaging |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import List, Literal, Tuple, Type | ||
from pathlib import Path | ||
import skimage.io | ||
import numpy as np | ||
from pydantic import BaseModel, Field, validator | ||
import imaging_server_kit as serverkit | ||
import rembg | ||
|
||
|
||
class Parameters(BaseModel): | ||
image: str = Field( | ||
title="Image", | ||
description="Input image (2D grayscale, RGB)", | ||
json_schema_extra={"widget_type": "image"}, | ||
) | ||
rembg_model_name: Literal["silueta", "isnet", "u2net", "u2netp", "sam"] = Field( | ||
default="silueta", | ||
title="Model", | ||
description="The model used for background removal.", | ||
json_schema_extra={"widget_type": "dropdown"}, | ||
) | ||
|
||
@validator("image", pre=False, always=True) | ||
def decode_image_array(cls, v) -> np.ndarray: | ||
image_array = serverkit.decode_contents(v) | ||
if image_array.ndim not in [2, 3]: | ||
raise ValueError("Array has the wrong dimensionality.") | ||
return image_array | ||
|
||
|
||
class Server(serverkit.Server): | ||
def __init__( | ||
self, | ||
algorithm_name: str = "rembg", | ||
parameters_model: Type[BaseModel] = Parameters, | ||
): | ||
super().__init__(algorithm_name, parameters_model) | ||
|
||
self.sessions: dict[str, rembg.sessions.BaseSession] = {} | ||
|
||
def run_algorithm( | ||
self, image: np.ndarray, rembg_model_name: str = "silueta", **kwargs | ||
) -> List[Tuple]: | ||
"""Binary segmentation using rembg.""" | ||
|
||
session = self.sessions.setdefault( | ||
rembg_model_name, rembg.new_session(rembg_model_name) | ||
) | ||
|
||
if rembg_model_name == "sam": | ||
x0, y0, x1, y1 = 0, 0, image.shape[0], image.shape[1] | ||
|
||
prompt = [ | ||
{ | ||
"type": "rectangle", | ||
"data": [y0, x0, y1, x1], | ||
"label": 2, # `label` is irrelevant for SAM in bounding boxes mode | ||
} | ||
] | ||
|
||
segmentation = rembg.remove( | ||
data=image, | ||
session=session, | ||
only_mask=True, | ||
post_process_mask=True, | ||
sam_prompt=prompt, | ||
**kwargs, | ||
) | ||
segmentation = segmentation == 0 # Invert it (for some reason) | ||
|
||
else: | ||
segmentation = rembg.remove( | ||
data=image, | ||
session=session, | ||
only_mask=True, | ||
post_process_mask=True, | ||
**kwargs, | ||
) | ||
segmentation = segmentation == 255 | ||
|
||
segmentation_params = { | ||
"name": f"{rembg_model_name}_result", | ||
} | ||
|
||
return [ | ||
(segmentation, segmentation_params, "labels"), | ||
] | ||
|
||
def load_sample_images(self) -> List["np.ndarray"]: | ||
"""Load one or multiple sample images.""" | ||
image_dir = Path(__file__).parent / "sample_images" | ||
images = [skimage.io.imread(image_path) for image_path in image_dir.glob("*")] | ||
return images | ||
|
||
|
||
server = Server() | ||
app = server.app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package_name: "rembg" | ||
project_url: "https://github.com/danielgatis/rembg" | ||
serverkit_repo_url: "https://github.com/Imaging-Server-Kit/serverkit-rembg" | ||
serverkit_author: "EPFL Center for Imaging" | ||
project_name: "Rembg" | ||
description: "A tool to remove images background." | ||
used_for: | ||
- "Segmentation" | ||
tags: | ||
- "Deep learning" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
rembg |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
""" | ||
Test suite for the run_algorithm function, which is tested both server-side and client-side. | ||
""" | ||
|
||
import sys | ||
import os | ||
import subprocess | ||
import time | ||
|
||
import numpy as np | ||
from imaging_server_kit.client import Client | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(__file__))) | ||
|
||
from main import Server | ||
|
||
|
||
def run_algorithm_server_side(): | ||
# Create a Server instance | ||
server = Server() | ||
|
||
# Get the algorithm parameters | ||
algo_params_schema = server.parameters_model.model_json_schema() | ||
required_params = algo_params_schema.get("required") | ||
algo_params = algo_params_schema.get("properties") | ||
|
||
# Get default parameter values for all non-requried parameters (which do not have defaults) | ||
default_params = {} | ||
for param, param_values in algo_params.items(): | ||
if param not in required_params: | ||
default_params[param] = param_values.get("default") | ||
|
||
# Add values to the required parameters, for example an input image | ||
sample_image = server.load_sample_images()[0] # First sample image | ||
default_params["image"] = sample_image | ||
|
||
# All parameters must have values set to run the algorithm | ||
params_missing = set(default_params.keys()) - set(algo_params.keys()) | ||
assert ( | ||
len(params_missing) == 0 | ||
), f"Values are missing for required parameters: {params_missing}" | ||
|
||
# Run the algorithm | ||
algo_output = server.run_algorithm(**default_params) | ||
|
||
# Examine the output (add relevant assert statements to test the algorithm) | ||
for data, data_params, data_type in algo_output: | ||
if data_type == "image": | ||
assert isinstance( | ||
data, np.ndarray | ||
), "Algorithm did not output a Numpy array." | ||
elif data_type == "labels": | ||
assert isinstance( | ||
data, np.ndarray | ||
), "Algorithm did not output a Numpy array." | ||
|
||
return algo_output | ||
|
||
|
||
def run_algorithm_client_side(): | ||
# Start the FastAPI server using uvicorn | ||
server_process = subprocess.Popen( | ||
["uvicorn", "main:app", "--host", "127.0.0.1", "--port", "8000"] | ||
) | ||
time.sleep(2) # Wait for the server to start | ||
|
||
try: | ||
# Connect to the server | ||
client = Client("http://localhost:8000") | ||
|
||
# A single algorithm should be available | ||
assert ( | ||
len(client.algorithms) > 0 | ||
), f"No algorithm available: {client.algorithms}" | ||
assert ( | ||
len(client.algorithms) < 2 | ||
), f"More than one algorithm available: {client.algorithms}" | ||
|
||
# Get the algorithm parameters | ||
algo_params_schema = client.get_algorithm_parameters() | ||
required_params = algo_params_schema.get("required") | ||
algo_params = algo_params_schema.get("properties") | ||
|
||
# Get default parameter values for all non-requried parameters (which do not have defaults) | ||
default_params = {} | ||
for param, param_values in algo_params.items(): | ||
if param not in required_params: | ||
default_params[param] = param_values.get("default") | ||
|
||
# Add values to the required parameters, for example an input image | ||
default_params["image"] = client.get_sample_images( | ||
first_only=True | ||
) # First sample image | ||
|
||
# All parameters must have values set to run the algorithm | ||
default_params_set = set(default_params.keys()) | ||
algo_params_set = set(algo_params.keys()) | ||
params_missing = algo_params_set.difference(default_params_set) | ||
|
||
assert ( | ||
len(params_missing) == 0 | ||
), f"Values are missing for required parameters: {params_missing}" | ||
|
||
# Run the algorithm | ||
algo_output = client.run_algorithm(**default_params) | ||
finally: | ||
# Shut down the server | ||
server_process.terminate() | ||
server_process.wait() | ||
|
||
return algo_output | ||
|
||
|
||
def test_compare_algo_outputs(): | ||
server_output = run_algorithm_server_side() | ||
client_output = run_algorithm_client_side() | ||
|
||
for server_data_tuple, client_data_tuple in zip(server_output, client_output): | ||
assert ( | ||
server_data_tuple[2] == client_data_tuple[2] | ||
), "Server and client output data types are different." | ||
assert np.allclose( | ||
server_data_tuple[0], client_data_tuple[0] | ||
), "Server and client algorithm outputs do not match." |