Skip to content

Commit

Permalink
add stablediffusionapi support.
Browse files Browse the repository at this point in the history
  • Loading branch information
iMactool committed May 19, 2023
1 parent f94a738 commit c9f806f
Show file tree
Hide file tree
Showing 7 changed files with 520 additions and 167 deletions.
42 changes: 35 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
> 我对它进行了一些改造,大部分功能保持了相同。在这里感谢一下 RuliLG ,实现了如此强大好用的 stable-diffusion 组件。
基于 Replicate API 的 Stable Diffusion 实现。
基于 Replicate API 和 stablediffusionapi 的 Stable Diffusion 实现。
- 🎨 Built-in prompt helper to create better images
- 🚀 Store the results in your database
- 🎇 Generate multiple images in the same API call
Expand All @@ -19,6 +19,7 @@
composer require imactool/hyperf-stable-diffusion
```
注意 表新增了`platform`平台字段。-- 后续会使用迁移增加,目前不考虑 :)

## 发布配置(包含配置文件和迁移文件)

Expand All @@ -44,9 +45,9 @@ return [
### 文字生成图片(Text to Image)
```php
use Imactool\HyperfStableDiffusion\Prompt;
use Imactool\HyperfStableDiffusion\StableDiffusion;
use Imactool\HyperfStableDiffusion\Replicate;

$result = StableDiffusion::make()->withPrompt(
$result = Replicate::make()->withPrompt(
Prompt::make()
->with('a panda sitting on the streets of New York after a long day of walking')
->photograph()
Expand All @@ -61,14 +62,14 @@ use Imactool\HyperfStableDiffusion\StableDiffusion;
### 图片生成图片(Image to Image)
```php
use Imactool\HyperfStableDiffusion\Prompt;
use Imactool\HyperfStableDiffusion\StableDiffusion;
use Imactool\HyperfStableDiffusion\Replicate;
use Intervention\Image\ImageManager;

//这里使用了 intervention/image 扩展来处理图片文件,你也可以更换为其他的
$sourceImg = (string) (new ImageManager(['driver' => 'imagick']))->make('path/image/source.png')->encode('data-url');

$prompt = 'Petite 21-year-old Caucasian female gamer streaming from her bedroom with pastel pink pigtails and gaming gear. Dynamic and engaging image inspired by colorful LED lights and the energy of Twitch culture, in 1920x1080 resolution.';
$result = StableDiffusion::make()
$result = Replicate::make()
->converVersion('a991dcab77024471af6a89ef758d98d1a54c5a25fc52a06ccfd7754b7ad04b35')
->withPrompt(
Prompt::make()
Expand All @@ -88,8 +89,8 @@ $result = StableDiffusion::make()
### 查询结果

```php
use Imactool\HyperfStableDiffusion\StableDiffusion;
$freshResults = StableDiffusion::get($replicate_id);
use Imactool\HyperfStableDiffusion\Replicate;
$freshResults = Replicate::get($replicate_id);

```

Expand Down Expand Up @@ -157,6 +158,33 @@ Additionally, you can add custom styles with the following methods:

To learn more on how to build prompts for Stable Diffusion, please [enter this link](https://beta.dreamstudio.ai/prompt-guide).

## 基于 stablediffusionapi 平台 [https://stablediffusionapi.com/docs/](https://stablediffusionapi.com/docs/)
或者 [Postman Collection](https://documenter.getpostman.com/view/18679074/2s83zdwReZ)

### 文字生成图片(Text to Image)
```php
use Imactool\HyperfStableDiffusion\StableDiffusion;

$res = StableDiffusion::make()
->useDreamboothApiV4()
->withPayload('key', '')
->withPayload('model_id', 'anything-v4')
->withPayload('prompt', 'ultra realistic close up portrait ((beautiful pale cyberpunk female with heavy black eyeliner)), blue eyes, shaved side haircut, hyper detail, cinematic lighting, magic neon, dark red city, Canon EOS R3, nikon, f/1.4, ISO 200, 1/160s, 8K, RAW, unedited, symmetrical balance, in-frame, 8K')
->withPayload('negative_prompt', 'painting, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs, anime')
->withPayload('width', '512')
->withPayload('height', '512')
->withPayload('samples', '1')
->withPayload('num_inference_steps', '30')
->withPayload('seed', null)
->withPayload('guidance_scale', '7.5')
->withPayload('webhook', null)
->withPayload('track_id', null)
->text2img();

var_dump( $res);

```


## License

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public function up(): void
Schema::create('stable_diffusion_results', function (Blueprint $table) {
$table->bigIncrements('id');
$table->string('replicate_id')->unique();
$table->string('platform')->comment('with platform');
$table->text('user_prompt');
$table->mediumText('full_prompt');
$table->string('url');
Expand Down
3 changes: 2 additions & 1 deletion src/Models/StableDiffusionResult.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
/**
* @property int $id
* @property string $replicate_id
* @property string $platform
* @property string $user_prompt
* @property string $full_prompt
* @property string $url
Expand All @@ -36,7 +37,7 @@ class StableDiffusionResult extends Model
/**
* The attributes that are mass assignable.
*/
protected array $fillable = ['replicate_id','user_prompt','full_prompt','url','status', 'output', 'error', 'predict_time'];
protected array $fillable = ['replicate_id', 'platform', 'user_prompt', 'full_prompt', 'url', 'status', 'output', 'error', 'predict_time'];

/**
* The attributes that should be cast to native types.
Expand Down
254 changes: 254 additions & 0 deletions src/Replicate.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
<?php

declare(strict_types=1);
/**
* This file is part of the imactool/hyperf-stable-diffusion.
*
* (c) imactool <[email protected]>
* This source file is subject to the MIT license that is bundled
* with this source code in the file LICENSE.
*/
namespace Imactool\HyperfStableDiffusion;

use Exception;
use Hyperf\Collection\Arr;
use Hyperf\Context\ApplicationContext;
use Hyperf\Guzzle\ClientFactory;
use Imactool\HyperfStableDiffusion\Models\StableDiffusionResult;
use PDOException;
use Psr\Http\Client\ClientInterface;

class Replicate
{
private static $platform = 'replicate';

private array $inputParams = [];

private string $baseUrl = '';

private string $token = '';

private string $version = '';

private function __construct(
public ?Prompt $prompt = null,
private int $width = 512,
private int $height = 512
) {
}

public function converUrl(string $url): self
{
$this->baseUrl = $url;
return $this;
}

public function getBaseUrl(): string
{
if (empty($this->baseUrl)) {
$this->baseUrl = config('stable-diffusion.url');
}
return $this->baseUrl;
}

public function converToken(string $token): self
{
$this->token = $token;
return $this;
}

public function getToken(): string
{
if (empty($this->token)) {
$this->token = config('stable-diffusion.token');
}
return $this->token;
}

public function converVersion(string $version): self
{
$this->version = $version;
return $this;
}

public function getVersion(): string
{
if (empty($this->version)) {
$this->version = config('stable-diffusion.version');
}
return $this->version;
}

public static function make(): self
{
return new self();
}

public function getV2(string $replicateId)
{
$result = StableDiffusionResult::query()->where('replicate_id', $replicateId)->first();
assert($result !== null, 'Unknown id');
$idleStatuses = ['starting', 'processing'];
if (! in_array($result->status, $idleStatuses)) {
return $result;
}

$response = $this->client()->get($result->url);

if ($response->getStatusCode() !== 200) {
throw new Exception('Failed to retrieve data.');
}

$responseData = json_decode((string) $response->getBody(), true);

$result->status = Arr::get($responseData, 'status', $result->status);
$result->output = Arr::has($responseData, 'output') ? Arr::get($responseData, 'output') : null;
$result->error = Arr::get($responseData, 'error');
$result->predict_time = Arr::get($responseData, 'metrics.predict_time');
$result->save();

return $result;
}

public static function get(string $replicateId)
{
$result = StableDiffusionResult::query()->where('replicate_id', $replicateId)->first();
assert($result !== null, 'Unknown id');
$idleStatuses = ['starting', 'processing'];
if (! in_array($result->status, $idleStatuses)) {
return $result;
}

$response = self::make()
->client()
->get($result->url);

if ($response->getStatusCode() !== 200) {
throw new Exception('Failed to retrieve data.');
}

$responseData = json_decode((string) $response->getBody(), true);

$result->status = Arr::get($responseData, 'status', $result->status);
$result->output = Arr::has($responseData, 'output') ? Arr::get($responseData, 'output') : null;
$result->error = Arr::get($responseData, 'error');
$result->predict_time = Arr::get($responseData, 'metrics.predict_time');
$result->save();

return $result;
}

public function withPrompt(Prompt $prompt)
{
$this->prompt = $prompt;
return $this;
}

/**
* except prompt,other API parameters.
*
* @param string $key 参数本身
* @param mixed $value 参数值
*
* @return $this
*/
public function inputParams(string $key, mixed $value)
{
$this->inputParams[$key] = $value;
return $this;
}

public function width(int $width)
{
assert($width > 0, 'Width must be greater than 0');
if ($width <= 768) {
assert($width <= 768 && $this->width <= 1024, 'Width must be lower than 768 and height lower than 1024');
} else {
assert($width <= 1024 && $this->width <= 768, 'Width must be lower than 1024 and height lower than 768');
}
$this->width = $width;
return $this;
}

public function height(int $height)
{
assert($height > 0, 'Height must be greater than 0');
if ($height <= 768) {
assert($height <= 768 && $this->width <= 1024, 'Height must be lower than 768 and width lower than 1024');
} else {
assert($height <= 1024 && $this->width <= 768, 'Height must be lower than 1024 and width lower than 768');
}

$this->height = $height;

return $this;
}

public function generate(int $numberOfImages)
{
assert($this->prompt !== null, 'You must provide a prompt');
assert($numberOfImages > 0, 'You must provide a number greater than 0');

$input = [
'prompt' => $this->prompt->toString(),
'num_outputs' => $numberOfImages,
];

$input = array_merge($input, $this->inputParams);

$response = $this->client()->post(
$this->getBaseUrl(),
[
'json' => [
'version' => $this->getVersion(),
'input' => $input,
],
]
);

$result = json_decode($response->getBody()->getContents(), true);

$data = [
'replicate_id' => $result['id'],
'platform' => self::$platform,
'user_prompt' => $this->prompt->userPrompt(),
'full_prompt' => $this->prompt->toString($this->inputParams),
'url' => $result['urls']['get'],
'status' => $result['status'],
'output' => isset($result['output']) ? $result['output'] : null,
'error' => $result['error'],
'predict_time' => null,
];

try {
StableDiffusionResult::create($data);
} catch (Exception $exception) {
$msg = $exception->getMessage();
var_dump(['data insert error' => $msg]);
if ($exception instanceof PDOException) {
$errorInfo = $exception->errorInfo;
$code = $errorInfo[1];
// $sql_state = $errorInfo[0];
// $msg = isset($errorInfo[2]) ? $errorInfo[2] : $sql_state;
}
if ((int) $code !== 1062) {
return $result;
}
}

return $result;
}

private function client(): ClientInterface
{
return ApplicationContext::getContainer()->get(ClientFactory::class)->create([
// 'base_uri' => $this->getBaseUrl(),
// 'timeout' => 10,
'headers' => [
'Authorization' => 'Token ' . $this->getToken(),
'Accept' => 'application/json',
'Content-Type' => 'application/json',
],
]);
}
}
Loading

0 comments on commit c9f806f

Please sign in to comment.