|
| 1 | +<?php |
| 2 | + |
| 3 | + declare(strict_types=1); |
| 4 | +/** |
| 5 | + * This file is part of the imactool/hyperf-stable-diffusion. |
| 6 | + * |
| 7 | + * (c) imactool <[email protected]> |
| 8 | + * This source file is subject to the MIT license that is bundled |
| 9 | + * with this source code in the file LICENSE. |
| 10 | + */ |
| 11 | +namespace Imactool\HyperfStableDiffusion; |
| 12 | + |
| 13 | + use Exception; |
| 14 | + use Hyperf\Collection\Arr; |
| 15 | + use Hyperf\Context\ApplicationContext; |
| 16 | + use Hyperf\Guzzle\ClientFactory; |
| 17 | + use Imactool\HyperfStableDiffusion\Models\StableDiffusionResult; |
| 18 | + use PDOException; |
| 19 | + use Psr\Http\Client\ClientInterface; |
| 20 | + |
| 21 | + class Replicate |
| 22 | + { |
| 23 | + private static $platform = 'replicate'; |
| 24 | + |
| 25 | + private array $inputParams = []; |
| 26 | + |
| 27 | + private string $baseUrl = ''; |
| 28 | + |
| 29 | + private string $token = ''; |
| 30 | + |
| 31 | + private string $version = ''; |
| 32 | + |
| 33 | + private function __construct( |
| 34 | + public ?Prompt $prompt = null, |
| 35 | + private int $width = 512, |
| 36 | + private int $height = 512 |
| 37 | + ) { |
| 38 | + } |
| 39 | + |
| 40 | + public function converUrl(string $url): self |
| 41 | + { |
| 42 | + $this->baseUrl = $url; |
| 43 | + return $this; |
| 44 | + } |
| 45 | + |
| 46 | + public function getBaseUrl(): string |
| 47 | + { |
| 48 | + if (empty($this->baseUrl)) { |
| 49 | + $this->baseUrl = config('stable-diffusion.url'); |
| 50 | + } |
| 51 | + return $this->baseUrl; |
| 52 | + } |
| 53 | + |
| 54 | + public function converToken(string $token): self |
| 55 | + { |
| 56 | + $this->token = $token; |
| 57 | + return $this; |
| 58 | + } |
| 59 | + |
| 60 | + public function getToken(): string |
| 61 | + { |
| 62 | + if (empty($this->token)) { |
| 63 | + $this->token = config('stable-diffusion.token'); |
| 64 | + } |
| 65 | + return $this->token; |
| 66 | + } |
| 67 | + |
| 68 | + public function converVersion(string $version): self |
| 69 | + { |
| 70 | + $this->version = $version; |
| 71 | + return $this; |
| 72 | + } |
| 73 | + |
| 74 | + public function getVersion(): string |
| 75 | + { |
| 76 | + if (empty($this->version)) { |
| 77 | + $this->version = config('stable-diffusion.version'); |
| 78 | + } |
| 79 | + return $this->version; |
| 80 | + } |
| 81 | + |
| 82 | + public static function make(): self |
| 83 | + { |
| 84 | + return new self(); |
| 85 | + } |
| 86 | + |
| 87 | + public function getV2(string $replicateId) |
| 88 | + { |
| 89 | + $result = StableDiffusionResult::query()->where('replicate_id', $replicateId)->first(); |
| 90 | + assert($result !== null, 'Unknown id'); |
| 91 | + $idleStatuses = ['starting', 'processing']; |
| 92 | + if (! in_array($result->status, $idleStatuses)) { |
| 93 | + return $result; |
| 94 | + } |
| 95 | + |
| 96 | + $response = $this->client()->get($result->url); |
| 97 | + |
| 98 | + if ($response->getStatusCode() !== 200) { |
| 99 | + throw new Exception('Failed to retrieve data.'); |
| 100 | + } |
| 101 | + |
| 102 | + $responseData = json_decode((string) $response->getBody(), true); |
| 103 | + |
| 104 | + $result->status = Arr::get($responseData, 'status', $result->status); |
| 105 | + $result->output = Arr::has($responseData, 'output') ? Arr::get($responseData, 'output') : null; |
| 106 | + $result->error = Arr::get($responseData, 'error'); |
| 107 | + $result->predict_time = Arr::get($responseData, 'metrics.predict_time'); |
| 108 | + $result->save(); |
| 109 | + |
| 110 | + return $result; |
| 111 | + } |
| 112 | + |
| 113 | + public static function get(string $replicateId) |
| 114 | + { |
| 115 | + $result = StableDiffusionResult::query()->where('replicate_id', $replicateId)->first(); |
| 116 | + assert($result !== null, 'Unknown id'); |
| 117 | + $idleStatuses = ['starting', 'processing']; |
| 118 | + if (! in_array($result->status, $idleStatuses)) { |
| 119 | + return $result; |
| 120 | + } |
| 121 | + |
| 122 | + $response = self::make() |
| 123 | + ->client() |
| 124 | + ->get($result->url); |
| 125 | + |
| 126 | + if ($response->getStatusCode() !== 200) { |
| 127 | + throw new Exception('Failed to retrieve data.'); |
| 128 | + } |
| 129 | + |
| 130 | + $responseData = json_decode((string) $response->getBody(), true); |
| 131 | + |
| 132 | + $result->status = Arr::get($responseData, 'status', $result->status); |
| 133 | + $result->output = Arr::has($responseData, 'output') ? Arr::get($responseData, 'output') : null; |
| 134 | + $result->error = Arr::get($responseData, 'error'); |
| 135 | + $result->predict_time = Arr::get($responseData, 'metrics.predict_time'); |
| 136 | + $result->save(); |
| 137 | + |
| 138 | + return $result; |
| 139 | + } |
| 140 | + |
| 141 | + public function withPrompt(Prompt $prompt) |
| 142 | + { |
| 143 | + $this->prompt = $prompt; |
| 144 | + return $this; |
| 145 | + } |
| 146 | + |
| 147 | + /** |
| 148 | + * except prompt,other API parameters. |
| 149 | + * |
| 150 | + * @param string $key 参数本身 |
| 151 | + * @param mixed $value 参数值 |
| 152 | + * |
| 153 | + * @return $this |
| 154 | + */ |
| 155 | + public function inputParams(string $key, mixed $value) |
| 156 | + { |
| 157 | + $this->inputParams[$key] = $value; |
| 158 | + return $this; |
| 159 | + } |
| 160 | + |
| 161 | + public function width(int $width) |
| 162 | + { |
| 163 | + assert($width > 0, 'Width must be greater than 0'); |
| 164 | + if ($width <= 768) { |
| 165 | + assert($width <= 768 && $this->width <= 1024, 'Width must be lower than 768 and height lower than 1024'); |
| 166 | + } else { |
| 167 | + assert($width <= 1024 && $this->width <= 768, 'Width must be lower than 1024 and height lower than 768'); |
| 168 | + } |
| 169 | + $this->width = $width; |
| 170 | + return $this; |
| 171 | + } |
| 172 | + |
| 173 | + public function height(int $height) |
| 174 | + { |
| 175 | + assert($height > 0, 'Height must be greater than 0'); |
| 176 | + if ($height <= 768) { |
| 177 | + assert($height <= 768 && $this->width <= 1024, 'Height must be lower than 768 and width lower than 1024'); |
| 178 | + } else { |
| 179 | + assert($height <= 1024 && $this->width <= 768, 'Height must be lower than 1024 and width lower than 768'); |
| 180 | + } |
| 181 | + |
| 182 | + $this->height = $height; |
| 183 | + |
| 184 | + return $this; |
| 185 | + } |
| 186 | + |
| 187 | + public function generate(int $numberOfImages) |
| 188 | + { |
| 189 | + assert($this->prompt !== null, 'You must provide a prompt'); |
| 190 | + assert($numberOfImages > 0, 'You must provide a number greater than 0'); |
| 191 | + |
| 192 | + $input = [ |
| 193 | + 'prompt' => $this->prompt->toString(), |
| 194 | + 'num_outputs' => $numberOfImages, |
| 195 | + ]; |
| 196 | + |
| 197 | + $input = array_merge($input, $this->inputParams); |
| 198 | + |
| 199 | + $response = $this->client()->post( |
| 200 | + $this->getBaseUrl(), |
| 201 | + [ |
| 202 | + 'json' => [ |
| 203 | + 'version' => $this->getVersion(), |
| 204 | + 'input' => $input, |
| 205 | + ], |
| 206 | + ] |
| 207 | + ); |
| 208 | + |
| 209 | + $result = json_decode($response->getBody()->getContents(), true); |
| 210 | + |
| 211 | + $data = [ |
| 212 | + 'replicate_id' => $result['id'], |
| 213 | + 'platform' => self::$platform, |
| 214 | + 'user_prompt' => $this->prompt->userPrompt(), |
| 215 | + 'full_prompt' => $this->prompt->toString($this->inputParams), |
| 216 | + 'url' => $result['urls']['get'], |
| 217 | + 'status' => $result['status'], |
| 218 | + 'output' => isset($result['output']) ? $result['output'] : null, |
| 219 | + 'error' => $result['error'], |
| 220 | + 'predict_time' => null, |
| 221 | + ]; |
| 222 | + |
| 223 | + try { |
| 224 | + StableDiffusionResult::create($data); |
| 225 | + } catch (Exception $exception) { |
| 226 | + $msg = $exception->getMessage(); |
| 227 | + var_dump(['data insert error' => $msg]); |
| 228 | + if ($exception instanceof PDOException) { |
| 229 | + $errorInfo = $exception->errorInfo; |
| 230 | + $code = $errorInfo[1]; |
| 231 | + // $sql_state = $errorInfo[0]; |
| 232 | + // $msg = isset($errorInfo[2]) ? $errorInfo[2] : $sql_state; |
| 233 | + } |
| 234 | + if ((int) $code !== 1062) { |
| 235 | + return $result; |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + return $result; |
| 240 | + } |
| 241 | + |
| 242 | + private function client(): ClientInterface |
| 243 | + { |
| 244 | + return ApplicationContext::getContainer()->get(ClientFactory::class)->create([ |
| 245 | + // 'base_uri' => $this->getBaseUrl(), |
| 246 | + // 'timeout' => 10, |
| 247 | + 'headers' => [ |
| 248 | + 'Authorization' => 'Token ' . $this->getToken(), |
| 249 | + 'Accept' => 'application/json', |
| 250 | + 'Content-Type' => 'application/json', |
| 251 | + ], |
| 252 | + ]); |
| 253 | + } |
| 254 | + } |
0 commit comments