Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

Commit 1cab5fa

Browse files
authored
fix: add support for weights in stability AI (#18)
Because - current implementation doesn't have support to add weights in the prompt This commit - adds support to add weights in the prompt from metadata
1 parent 2969470 commit 1cab5fa

1 file changed

Lines changed: 23 additions & 5 deletions

File tree

pkg/stabilityai/main.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,20 @@ func (c *Connection) Execute(inputs []*connectorPB.DataPayload) ([]*connectorPB.
162162
Steps: uint32(dataPayload.GetMetadata().GetFields()["steps"].GetNumberValue()),
163163
StylePreset: dataPayload.GetMetadata().GetFields()["style_preset"].GetStringValue(),
164164
Height: uint32(dataPayload.GetMetadata().GetFields()["height"].GetNumberValue()),
165-
Width: uint32(dataPayload.GetMetadata().GetFields()["weight"].GetNumberValue()),
165+
Width: uint32(dataPayload.GetMetadata().GetFields()["width"].GetNumberValue()),
166+
}
167+
weights := dataPayload.GetMetadata().GetFields()["weights"].GetListValue().GetValues()
168+
//if no weights are given
169+
if weights == nil {
170+
weights = []*structpb.Value{}
166171
}
167172
req.TextPrompts = make([]TextPrompt, 0, len(dataPayload.Texts))
168-
for _, t := range dataPayload.Texts {
169-
req.TextPrompts = append(req.TextPrompts, TextPrompt{Text: t})
173+
var w float32
174+
for index, t := range dataPayload.Texts {
175+
if len(weights) > index {
176+
w = float32(weights[index].GetNumberValue())
177+
}
178+
req.TextPrompts = append(req.TextPrompts, TextPrompt{Text: t, Weight: w})
170179
}
171180
images, err := client.GenerateImageFromText(req, engine)
172181
if err != nil {
@@ -205,9 +214,18 @@ func (c *Connection) Execute(inputs []*connectorPB.DataPayload) ([]*connectorPB.
205214
InitImageMode: dataPayload.GetMetadata().GetFields()["init_image_mode"].GetStringValue(),
206215
ImageStrength: dataPayload.GetMetadata().GetFields()["image_strength"].GetNumberValue(),
207216
}
217+
weights := dataPayload.GetMetadata().GetFields()["weights"].GetListValue().GetValues()
218+
//if no weights are given
219+
if weights == nil {
220+
weights = []*structpb.Value{}
221+
}
208222
req.TextPrompts = make([]TextPrompt, 0, len(dataPayload.Texts))
209-
for _, t := range dataPayload.Texts {
210-
req.TextPrompts = append(req.TextPrompts, TextPrompt{Text: t})
223+
var w float32
224+
for index, t := range dataPayload.Texts {
225+
if len(weights) > index {
226+
w = float32(weights[index].GetNumberValue())
227+
}
228+
req.TextPrompts = append(req.TextPrompts, TextPrompt{Text: t, Weight: w})
211229
}
212230
images, err := client.GenerateImageFromImage(req, engine)
213231
if err != nil {

0 commit comments

Comments
 (0)