@@ -162,11 +162,20 @@ func (c *Connection) Execute(inputs []*connectorPB.DataPayload) ([]*connectorPB.
162162 Steps : uint32 (dataPayload .GetMetadata ().GetFields ()["steps" ].GetNumberValue ()),
163163 StylePreset : dataPayload .GetMetadata ().GetFields ()["style_preset" ].GetStringValue (),
164164 Height : uint32 (dataPayload .GetMetadata ().GetFields ()["height" ].GetNumberValue ()),
165- Width : uint32 (dataPayload .GetMetadata ().GetFields ()["weight" ].GetNumberValue ()),
165+ Width : uint32 (dataPayload .GetMetadata ().GetFields ()["width" ].GetNumberValue ()),
166+ }
167+ weights := dataPayload .GetMetadata ().GetFields ()["weights" ].GetListValue ().GetValues ()
168+ //if no weights are given
169+ if weights == nil {
170+ weights = []* structpb.Value {}
166171 }
167172 req .TextPrompts = make ([]TextPrompt , 0 , len (dataPayload .Texts ))
168- for _ , t := range dataPayload .Texts {
169- req .TextPrompts = append (req .TextPrompts , TextPrompt {Text : t })
173+ var w float32
174+ for index , t := range dataPayload .Texts {
175+ if len (weights ) > index {
176+ w = float32 (weights [index ].GetNumberValue ())
177+ }
178+ req .TextPrompts = append (req .TextPrompts , TextPrompt {Text : t , Weight : w })
170179 }
171180 images , err := client .GenerateImageFromText (req , engine )
172181 if err != nil {
@@ -205,9 +214,18 @@ func (c *Connection) Execute(inputs []*connectorPB.DataPayload) ([]*connectorPB.
205214 InitImageMode : dataPayload .GetMetadata ().GetFields ()["init_image_mode" ].GetStringValue (),
206215 ImageStrength : dataPayload .GetMetadata ().GetFields ()["image_strength" ].GetNumberValue (),
207216 }
217+ weights := dataPayload .GetMetadata ().GetFields ()["weights" ].GetListValue ().GetValues ()
218+ //if no weights are given
219+ if weights == nil {
220+ weights = []* structpb.Value {}
221+ }
208222 req .TextPrompts = make ([]TextPrompt , 0 , len (dataPayload .Texts ))
209- for _ , t := range dataPayload .Texts {
210- req .TextPrompts = append (req .TextPrompts , TextPrompt {Text : t })
223+ var w float32
224+ for index , t := range dataPayload .Texts {
225+ if len (weights ) > index {
226+ w = float32 (weights [index ].GetNumberValue ())
227+ }
228+ req .TextPrompts = append (req .TextPrompts , TextPrompt {Text : t , Weight : w })
211229 }
212230 images , err := client .GenerateImageFromImage (req , engine )
213231 if err != nil {
0 commit comments