diff --git a/core/orchestrator.go b/core/orchestrator.go index 185c343f82..f6f3166255 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -113,7 +113,7 @@ func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToIm return orch.node.imageToImage(ctx, req) } -func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) ([]*TranscodeResult, error) { +func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { return orch.node.imageToVideo(ctx, req) } @@ -892,7 +892,7 @@ func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImage return n.AIWorker.ImageToImage(ctx, req) } -func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) ([]*TranscodeResult, error) { +func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { // We might support generating more than one video in the future (i.e. multiple input images/prompts) numVideos := 1 @@ -911,8 +911,9 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideo clog.V(common.DEBUG).Infof(ctx, "Generating frames took=%v", took) sessionID := string(RandomManifestID()) + // HACK: Re-use worker.ImageResponse to return results // Transcode frames into segments. - results := make([]*TranscodeResult, len(resp.Frames)) + videos := make([]worker.Media, len(resp.Frames)) for i, batch := range resp.Frames { // Create slice of frame urls for a batch urls := make([]string, len(batch)) @@ -926,10 +927,25 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideo return nil, res.Err } - results[i] = res + // Assume only single rendition right now + seg := res.TranscodeData.Segments[0] + name := fmt.Sprintf("%v.mp4", RandomManifestID()) + segData := bytes.NewReader(seg.Data) + uri, err := res.OS.SaveData(ctx, name, segData, nil, 0) + if err != nil { + return nil, err + } + + videos[i] = worker.Media{ + Url: uri, + } + + if len(batch) > 0 { + videos[i].Seed = batch[0].Seed + } } - return results, nil + return &worker.ImageResponse{Images: videos}, nil } func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) { diff --git a/go.mod b/go.mod index 6bb76ba70b..f4a1cbaaca 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/golang/protobuf v1.5.3 github.com/jaypipes/ghw v0.10.0 github.com/jaypipes/pcidb v1.0.0 - github.com/livepeer/ai-worker v0.0.0-20240205185039-5c4895915580 + github.com/livepeer/ai-worker v0.0.0-20240208153040-7c92507e2a40 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18 github.com/livepeer/lpms v0.0.0-20240120150405-de94555cdc69 diff --git a/go.sum b/go.sum index 6b077971ba..d40375f0f9 100644 --- a/go.sum +++ b/go.sum @@ -541,6 +541,10 @@ github.com/livepeer/ai-worker v0.0.0-20240202211855-823caeaa265f h1:8owDNiBfN0j6 github.com/livepeer/ai-worker v0.0.0-20240202211855-823caeaa265f/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4= github.com/livepeer/ai-worker v0.0.0-20240205185039-5c4895915580 h1:7ACCHUpeJsoWADgST/nWfGD0LVRSXFcYG6FTGvzUGn4= github.com/livepeer/ai-worker v0.0.0-20240205185039-5c4895915580/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4= +github.com/livepeer/ai-worker v0.0.0-20240207221157-87e4f48ec353 h1:Ee1+i+q1EpP9D3AOufAnMSyEP06zaRhcyMRfSk6GJF8= +github.com/livepeer/ai-worker v0.0.0-20240207221157-87e4f48ec353/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4= +github.com/livepeer/ai-worker v0.0.0-20240208153040-7c92507e2a40 h1:vVbuu5wqrzq6M6Rlutk0eZv6qZ/kO2OrqQv5n6yt57s= +github.com/livepeer/ai-worker v0.0.0-20240208153040-7c92507e2a40/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4= github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE= github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw= github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA= diff --git a/server/ai_http.go b/server/ai_http.go index 57c59bb366..c6343a454f 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -1,20 +1,15 @@ package server import ( - "bytes" "context" "encoding/json" - "fmt" "net/http" "time" "github.com/getkin/kin-openapi/openapi3filter" - "github.com/golang/protobuf/proto" "github.com/livepeer/ai-worker/worker" "github.com/livepeer/go-livepeer/clog" "github.com/livepeer/go-livepeer/common" - "github.com/livepeer/go-livepeer/core" - "github.com/livepeer/go-livepeer/net" middleware "github.com/oapi-codegen/nethttp-middleware" "github.com/oapi-codegen/runtime" ) @@ -56,7 +51,7 @@ func (h *lphttp) TextToImage() http.Handler { return } - clog.V(common.VERBOSE).Infof(r.Context(), "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId) + clog.V(common.VERBOSE).Infof(ctx, "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId) start := time.Now() resp, err := h.orchestrator.TextToImage(r.Context(), req) @@ -129,61 +124,17 @@ func (h *lphttp) ImageToVideo() http.Handler { clog.V(common.VERBOSE).Infof(ctx, "Received ImageToVideo request imageSize=%v model_id=%v", req.Image.FileSize(), *req.ModelId) start := time.Now() - results, err := h.orchestrator.ImageToVideo(ctx, req) + resp, err := h.orchestrator.ImageToVideo(ctx, req) if err != nil { respondWithError(w, err.Error(), http.StatusInternalServerError) return } - // TODO: Handle more than one video - if len(results) != 1 { - respondWithError(w, "failed to return results", http.StatusInternalServerError) - return - } - took := time.Since(start) clog.Infof(ctx, "Processed ImageToVideo request imageSize=%v model_id=%v took=%v", req.Image.FileSize(), *req.ModelId, took) - res := results[0] - - // Assume only single rendition right now - seg := res.TranscodeData.Segments[0] - name := fmt.Sprintf("%v.mp4", core.RandomManifestID()) - segData := bytes.NewReader(seg.Data) - uri, err := res.OS.SaveData(ctx, name, segData, nil, 0) - if err != nil { - clog.Errorf(ctx, "Could not upload segment err=%q", err) - } - - var result net.TranscodeResult - if err != nil { - clog.Errorf(ctx, "Could not transcode err=%q", err) - result = net.TranscodeResult{Result: &net.TranscodeResult_Error{Error: err.Error()}} - } else { - result = net.TranscodeResult{ - Result: &net.TranscodeResult_Data{ - Data: &net.TranscodeData{ - Segments: []*net.TranscodedSegmentData{ - {Url: uri, Pixels: seg.Pixels}, - }, - Sig: res.Sig, - }, - }, - } - } - - tr := &net.TranscodeResult{ - Result: result.Result, - // TODO: Add other fields - } - - buf, err := proto.Marshal(tr) - if err != nil { - respondWithError(w, err.Error(), http.StatusInternalServerError) - return - } - + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(buf) + _ = json.NewEncoder(w).Encode(resp) }) } diff --git a/server/ai_process.go b/server/ai_process.go index 14fbc46ef9..d43ed193b8 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -4,20 +4,18 @@ import ( "bufio" "bytes" "context" + "encoding/json" "errors" "io" - "net/http" "path/filepath" "strings" "time" "github.com/cenkalti/backoff" - "github.com/golang/protobuf/proto" "github.com/livepeer/ai-worker/worker" "github.com/livepeer/go-livepeer/clog" "github.com/livepeer/go-livepeer/common" "github.com/livepeer/go-livepeer/core" - "github.com/livepeer/go-livepeer/net" "github.com/livepeer/go-tools/drivers" ) @@ -84,7 +82,7 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker. return nil, err } - newMedia[i] = worker.Media{Url: newUrl} + newMedia[i] = worker.Media{Url: newUrl, Seed: media.Seed} } resp.Images = newMedia @@ -155,7 +153,7 @@ func processImageToImage(ctx context.Context, params aiRequestParams, req worker return nil, err } - newMedia[i] = worker.Media{Url: newUrl} + newMedia[i] = worker.Media{Url: newUrl, Seed: media.Seed} } resp.Images = newMedia @@ -204,10 +202,10 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker orchUrl := orchInfos[0].Transcoder - var urls []string + var resp *worker.ImageResponse op := func() error { var err error - urls, err = submitImageToVideo(ctx, orchUrl, req) + resp, err = submitImageToVideo(ctx, orchUrl, req) return err } notify := func(err error, dur time.Duration) { @@ -220,42 +218,43 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker } // HACK: Re-use worker.ImageResponse to return results - videos := make([]worker.Media, len(urls)) - for i, url := range urls { - data, err := downloadSeg(ctx, url) + videos := make([]worker.Media, len(resp.Images)) + for i, media := range resp.Images { + data, err := downloadSeg(ctx, media.Url) if err != nil { return nil, err } - name := filepath.Base(url) + name := filepath.Base(media.Url) newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(data), nil, 0) if err != nil { return nil, err } videos[i] = worker.Media{ - Url: newUrl, + Url: newUrl, + Seed: media.Seed, } } - resp := &worker.ImageResponse{Images: videos} + resp.Images = videos + return resp, nil } -func submitImageToVideo(ctx context.Context, url string, req worker.ImageToVideoMultipartRequestBody) ([]string, error) { +func submitImageToVideo(ctx context.Context, url string, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { var buf bytes.Buffer mw, err := worker.NewImageToVideoMultipartWriter(&buf, req) if err != nil { return nil, err } - r, err := http.NewRequestWithContext(ctx, "POST", url+"/image-to-video", &buf) + client, err := worker.NewClientWithResponses(url, worker.WithHTTPClient(httpClient)) if err != nil { return nil, err } - r.Header.Set("Content-Type", mw.FormDataContentType()) - resp, err := sendReqWithTimeout(r, imageToVideoTimeout) + resp, err := client.ImageToVideoWithBody(ctx, mw.FormDataContentType(), &buf) if err != nil { return nil, err } @@ -270,25 +269,10 @@ func submitImageToVideo(ctx context.Context, url string, req worker.ImageToVideo return nil, errors.New(string(data)) } - var tr net.TranscodeResult - if err := proto.Unmarshal(data, &tr); err != nil { + var res worker.ImageResponse + if err := json.Unmarshal(data, &res); err != nil { return nil, err } - var tdata *net.TranscodeData - switch res := tr.Result.(type) { - case *net.TranscodeResult_Error: - return nil, errors.New(res.Error) - case *net.TranscodeResult_Data: - tdata = res.Data - default: - return nil, errors.New("UnknownResponse") - } - - urls := make([]string, len(tdata.Segments)) - for i, seg := range tdata.Segments { - urls[i] = seg.Url - } - - return urls, nil + return &res, nil } diff --git a/server/rpc.go b/server/rpc.go index dbe1c0fadc..4151c43286 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -63,7 +63,7 @@ type Orchestrator interface { AuthToken(sessionID string, expiration int64) *net.AuthToken TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) - ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) ([]*core.TranscodeResult, error) + ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) } // Balance describes methods for a session's balance maintenance