Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MalloryWittwer committed Jan 6, 2025
0 parents commit 03bbcca
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
FROM imaging-server-kit:3.9

COPY . .

RUN python -m pip install -r requirements.txt
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
![EPFL Center for Imaging logo](https://imaging.epfl.ch/resources/logo-for-gitlab.svg)
# Rembg API in docker

Implementation of a web API server for [rembg](https://github.com/danielgatis/rembg).

Author: EPFL Center for Imaging
Binary file added __pycache__/main.cpython-39.pyc
Binary file not shown.
97 changes: 97 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import List, Literal, Tuple, Type
from pathlib import Path
import skimage.io
import numpy as np
from pydantic import BaseModel, Field, validator
import imaging_server_kit as serverkit
import rembg


class Parameters(BaseModel):
image: str = Field(
title="Image",
description="Input image (2D grayscale, RGB)",
json_schema_extra={"widget_type": "image"},
)
rembg_model_name: Literal["silueta", "isnet", "u2net", "u2netp", "sam"] = Field(
default="silueta",
title="Model",
description="The model used for background removal.",
json_schema_extra={"widget_type": "dropdown"},
)

@validator("image", pre=False, always=True)
def decode_image_array(cls, v) -> np.ndarray:
image_array = serverkit.decode_contents(v)
if image_array.ndim not in [2, 3]:
raise ValueError("Array has the wrong dimensionality.")
return image_array


class Server(serverkit.Server):
def __init__(
self,
algorithm_name: str = "rembg",
parameters_model: Type[BaseModel] = Parameters,
):
super().__init__(algorithm_name, parameters_model)

self.sessions: dict[str, rembg.sessions.BaseSession] = {}

def run_algorithm(
self, image: np.ndarray, rembg_model_name: str = "silueta", **kwargs
) -> List[Tuple]:
"""Binary segmentation using rembg."""

session = self.sessions.setdefault(
rembg_model_name, rembg.new_session(rembg_model_name)
)

if rembg_model_name == "sam":
x0, y0, x1, y1 = 0, 0, image.shape[0], image.shape[1]

prompt = [
{
"type": "rectangle",
"data": [y0, x0, y1, x1],
"label": 2, # `label` is irrelevant for SAM in bounding boxes mode
}
]

segmentation = rembg.remove(
data=image,
session=session,
only_mask=True,
post_process_mask=True,
sam_prompt=prompt,
**kwargs,
)
segmentation = segmentation == 0 # Invert it (for some reason)

else:
segmentation = rembg.remove(
data=image,
session=session,
only_mask=True,
post_process_mask=True,
**kwargs,
)
segmentation = segmentation == 255

segmentation_params = {
"name": f"{rembg_model_name}_result",
}

return [
(segmentation, segmentation_params, "labels"),
]

def load_sample_images(self) -> List["np.ndarray"]:
"""Load one or multiple sample images."""
image_dir = Path(__file__).parent / "sample_images"
images = [skimage.io.imread(image_path) for image_path in image_dir.glob("*")]
return images


server = Server()
app = server.app
10 changes: 10 additions & 0 deletions metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package_name: "rembg"
project_url: "https://github.com/danielgatis/rembg"
serverkit_repo_url: "https://github.com/Imaging-Server-Kit/serverkit-rembg"
serverkit_author: "EPFL Center for Imaging"
project_name: "Rembg"
description: "A tool to remove images background."
used_for:
- "Segmentation"
tags:
- "Deep learning"
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
rembg
Binary file added sample_images/astronaut.tif
Binary file not shown.
Binary file not shown.
124 changes: 124 additions & 0 deletions tests/test_run_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
Test suite for the run_algorithm function, which is tested both server-side and client-side.
"""

import sys
import os
import subprocess
import time

import numpy as np
from imaging_server_kit.client import Client

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from main import Server


def run_algorithm_server_side():
# Create a Server instance
server = Server()

# Get the algorithm parameters
algo_params_schema = server.parameters_model.model_json_schema()
required_params = algo_params_schema.get("required")
algo_params = algo_params_schema.get("properties")

# Get default parameter values for all non-requried parameters (which do not have defaults)
default_params = {}
for param, param_values in algo_params.items():
if param not in required_params:
default_params[param] = param_values.get("default")

# Add values to the required parameters, for example an input image
sample_image = server.load_sample_images()[0] # First sample image
default_params["image"] = sample_image

# All parameters must have values set to run the algorithm
params_missing = set(default_params.keys()) - set(algo_params.keys())
assert (
len(params_missing) == 0
), f"Values are missing for required parameters: {params_missing}"

# Run the algorithm
algo_output = server.run_algorithm(**default_params)

# Examine the output (add relevant assert statements to test the algorithm)
for data, data_params, data_type in algo_output:
if data_type == "image":
assert isinstance(
data, np.ndarray
), "Algorithm did not output a Numpy array."
elif data_type == "labels":
assert isinstance(
data, np.ndarray
), "Algorithm did not output a Numpy array."

return algo_output


def run_algorithm_client_side():
# Start the FastAPI server using uvicorn
server_process = subprocess.Popen(
["uvicorn", "main:app", "--host", "127.0.0.1", "--port", "8000"]
)
time.sleep(2) # Wait for the server to start

try:
# Connect to the server
client = Client("http://localhost:8000")

# A single algorithm should be available
assert (
len(client.algorithms) > 0
), f"No algorithm available: {client.algorithms}"
assert (
len(client.algorithms) < 2
), f"More than one algorithm available: {client.algorithms}"

# Get the algorithm parameters
algo_params_schema = client.get_algorithm_parameters()
required_params = algo_params_schema.get("required")
algo_params = algo_params_schema.get("properties")

# Get default parameter values for all non-requried parameters (which do not have defaults)
default_params = {}
for param, param_values in algo_params.items():
if param not in required_params:
default_params[param] = param_values.get("default")

# Add values to the required parameters, for example an input image
default_params["image"] = client.get_sample_images(
first_only=True
) # First sample image

# All parameters must have values set to run the algorithm
default_params_set = set(default_params.keys())
algo_params_set = set(algo_params.keys())
params_missing = algo_params_set.difference(default_params_set)

assert (
len(params_missing) == 0
), f"Values are missing for required parameters: {params_missing}"

# Run the algorithm
algo_output = client.run_algorithm(**default_params)
finally:
# Shut down the server
server_process.terminate()
server_process.wait()

return algo_output


def test_compare_algo_outputs():
server_output = run_algorithm_server_side()
client_output = run_algorithm_client_side()

for server_data_tuple, client_data_tuple in zip(server_output, client_output):
assert (
server_data_tuple[2] == client_data_tuple[2]
), "Server and client output data types are different."
assert np.allclose(
server_data_tuple[0], client_data_tuple[0]
), "Server and client algorithm outputs do not match."

0 comments on commit 03bbcca

Please sign in to comment.