Skip to content

Commit 1dd9d49

Browse files
author
Richard Kennedy
committed
Support signing with CAI Server in sdk
1 parent a72b5f1 commit 1dd9d49

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

src/stability_sdk/api.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def generate(
122122
guidance_strength: float = 0.0,
123123
preset: Optional[str] = None,
124124
return_request: bool = False,
125+
sign_with_cai: bool = False,
125126
) -> Dict[int, List[Any]]:
126127
"""
127128
Generate an image from a set of weighted prompts.
@@ -164,7 +165,7 @@ def generate(
164165
start_schedule = 1.0 - init_strength
165166
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
166167
start_schedule, init_noise_scale, masked_area_init,
167-
guidance_preset, guidance_cuts, guidance_strength)
168+
guidance_preset, guidance_cuts, guidance_strength, sign_with_cai)
168169

169170
extras = Struct()
170171
if preset and preset.lower() != 'none':
@@ -235,7 +236,7 @@ def inpaint(
235236
start_schedule = 1.0-init_strength
236237
image_params = self._build_image_params(width, height, sampler, steps, seed, samples, cfg_scale,
237238
start_schedule, init_noise_scale, masked_area_init,
238-
guidance_preset, guidance_cuts, guidance_strength)
239+
guidance_preset, guidance_cuts, guidance_strength, sign_with_cai=False)
239240

240241
extras = Struct()
241242
if preset and preset.lower() != 'none':
@@ -538,7 +539,7 @@ def _adjust_request_for_retry(self, request: generation.Request, attempt: int):
538539

539540
def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_scale,
540541
schedule_start, init_noise_scale, masked_area_init,
541-
guidance_preset, guidance_cuts, guidance_strength):
542+
guidance_preset, guidance_cuts, guidance_strength, sign_with_cai):
542543

543544
if not seed:
544545
seed = [random.randrange(0, 4294967295)]
@@ -568,6 +569,12 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
568569
)
569570
]
570571
)
572+
# empty CAI Parameters will result in images not being signed by the CAI server
573+
caip = generation.CAIParameters()
574+
if sign_with_cai:
575+
caip = generation.CAIParameters(
576+
model_metadata=generation._CAIPARAMETERS_MODELMETADATA.values_by_name[
577+
'SIGN_WITH_ENGINE_ID'].number)
571578

572579
return generation.ImageParameters(
573580
transform=None if sampler is None else generation.TransformType(diffusion=sampler),
@@ -578,6 +585,7 @@ def _build_image_params(self, width, height, sampler, steps, seed, samples, cfg_
578585
samples=samples,
579586
masked_area_init=masked_area_init,
580587
parameters=[generation.StepParameter(**step_parameters)],
588+
cai_parameters=caip
581589
)
582590

583591
def _process_response(self, response) -> Dict[int, List[Any]]:

src/stability_sdk/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,9 @@ def process_cli(logger: logging.Logger = None,
498498
parser_generate.add_argument(
499499
"--width", "-W", type=int, default=512, help="[512] width of image"
500500
)
501+
parser_generate.add_argument(
502+
"--sign_with_cai", type=bool, default=False, help="Sign artifacts using C2PA to include providence data containing engine id"
503+
)
501504
parser_generate.add_argument(
502505
"--start_schedule",
503506
type=float,
@@ -626,11 +629,12 @@ def process_cli(logger: logging.Logger = None,
626629
"width": args.width,
627630
"start_schedule": args.start_schedule,
628631
"end_schedule": args.end_schedule,
629-
"cfg_scale": args.cfg_scale,
632+
"cfg_scale": args.cfg_scale,
630633
"seed": args.seed,
631634
"samples": args.num_samples,
632635
"init_image": args.init_image,
633636
"mask_image": args.mask_image,
637+
"sign_with_cai": args.sign_with_cai,
634638
}
635639

636640
if args.sampler:

tests/test_api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def ChainGenerate(self, chain: generation.ChainRequest, **kwargs) -> Generator[g
3232
for answer in self.Generate(stage.request):
3333
artifacts.extend(answer.artifacts)
3434
for artifact in artifacts:
35-
yield generation.Answer(artifacts=[artifact])
35+
yield generation.Answer(artifacts=[artifact])
3636

3737
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
3838
if request.HasField("image"):
@@ -99,6 +99,28 @@ def test_api_generate():
9999
assert isinstance(image, Image.Image)
100100
assert image.size == (width, height)
101101

102+
def test_api_generate_cai_signing_set():
103+
class CAIMockStub(MockStub):
104+
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
105+
assert request.image.cai_parameters.model_metadata == \
106+
generation._CAIPARAMETERS_MODELMETADATA.values_by_name['SIGN_WITH_ENGINE_ID'].number
107+
return super().Generate(request, **kwargs)
108+
api = Context(stub=CAIMockStub())
109+
width, height = 512, 768
110+
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height, sign_with_cai=True)
111+
112+
def test_api_generate_cai_signing_unset():
113+
class CAIMockStub(MockStub):
114+
def Generate(self, request: generation.Request, **kwargs) -> Generator[generation.Answer, None, None]:
115+
assert request.image.cai_parameters.model_metadata == \
116+
generation._CAIPARAMETERS_MODELMETADATA.values_by_name['METADATA_UNSPECIFIED'].number
117+
return super().Generate(request, **kwargs)
118+
api = Context(stub=CAIMockStub())
119+
width, height = 512, 768
120+
# sign_with_cai should default to false.
121+
results = api.generate(prompts=["foo bar"], weights=[1.0], width=width, height=height)
122+
123+
102124
def test_api_inpaint():
103125
api = Context(stub=MockStub())
104126
width, height = 512, 768

0 commit comments

Comments
 (0)