Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai/worker: Absorb ai-worker library #3345

Merged
merged 18 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#### Broadcaster

#### Orchestrator
- [#3345](https://github.com/livepeer/go-livepeer/pull/3345) Move `ai-worker` code to a local package

#### Transcoder

Expand Down
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,17 @@ ifeq ($(BUILDOS),linux)
endif


.PHONY: livepeer livepeer_bench livepeer_cli livepeer_router docker swagger
.PHONY: ai_worker_codegen livepeer livepeer_bench livepeer_cli livepeer_router docker swagger

# Git reference to download the OpenAPI spec from, defaults to `main` branch.
# It can also be a simple git commit hash. e.g. `make ai_worker_codegen REF=c19289d`
REF ?= refs/heads/main
ai_worker_codegen:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@victorges I assume this command is excluded from all to prevent unexpected breakage? If so, perhaps we should document it in the developer docs?

go run github.com/deepmap/oapi-codegen/v2/cmd/[email protected] \
-package worker \
-generate types,client,chi-server,spec \
https://raw.githubusercontent.com/livepeer/ai-worker/$(REF)/runner/openapi.yaml \
| awk '!/WARNING/' > ai/worker/runner.gen.go

livepeer:
GO111MODULE=on CGO_ENABLED=1 CC="$(cc)" CGO_CFLAGS="$(cgo_cflags)" CGO_LDFLAGS="$(cgo_ldflags) ${CGO_LDFLAGS}" go build -o $(GO_BUILD_DIR) -tags "$(BUILD_TAGS)" -ldflags="$(ldflags)" cmd/livepeer/*.go
Expand Down
2 changes: 1 addition & 1 deletion ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"errors"
"os"

"github.com/livepeer/ai-worker/worker"
"github.com/livepeer/go-livepeer/ai/worker"
)

type FileWorker struct {
Expand All @@ -17,7 +17,7 @@
return &FileWorker{files: files}
}

func (w *FileWorker) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {

Check warning on line 20 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'ctx' seems to be unused, consider removing or renaming it as _

Check warning on line 20 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'req' seems to be unused, consider removing or renaming it as _
fname, ok := w.files["text-to-image"]
if !ok {
return nil, errors.New("text-to-image response file not found")
Expand All @@ -36,7 +36,7 @@
return &resp, nil
}

func (w *FileWorker) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {

Check warning on line 39 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'ctx' seems to be unused, consider removing or renaming it as _

Check warning on line 39 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'req' seems to be unused, consider removing or renaming it as _
fname, ok := w.files["image-to-image"]
if !ok {
return nil, errors.New("image-to-image response file not found")
Expand All @@ -55,7 +55,7 @@
return &resp, nil
}

func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {

Check warning on line 58 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'ctx' seems to be unused, consider removing or renaming it as _

Check warning on line 58 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'req' seems to be unused, consider removing or renaming it as _
fname, ok := w.files["image-to-video"]
if !ok {
return nil, errors.New("image-to-video response file not found")
Expand All @@ -74,7 +74,7 @@
return &resp, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {

Check warning on line 77 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'ctx' seems to be unused, consider removing or renaming it as _

Check warning on line 77 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'req' seems to be unused, consider removing or renaming it as _
fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
Expand All @@ -93,7 +93,7 @@
return &resp, nil
}

func (w *FileWorker) Warm(ctx context.Context, containerName, modelID string) error {

Check warning on line 96 in ai/file_worker.go

View workflow job for this annotation

GitHub Actions / Run tests defined for the project

parameter 'ctx' seems to be unused, consider removing or renaming it as _
return nil
}

Expand Down
61 changes: 61 additions & 0 deletions ai/worker/b64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package worker

import (
"bytes"
"fmt"
"image"
"image/gif"
"image/jpeg"
"image/png"
"io"
"os"

"github.com/vincent-petithory/dataurl"
)

func ReadImageB64DataUrl(url string, w io.Writer) error {
dataURL, err := dataurl.DecodeString(url)
if err != nil {
return err
}

img, _, err := image.Decode(bytes.NewReader(dataURL.Data))
if err != nil {
return err
}

switch dataURL.MediaType.ContentType() {
case "image/png":
err = png.Encode(w, img)
case "image/jpg", "image/jpeg":
err = jpeg.Encode(w, img, nil)
case "image/gif":
err = gif.Encode(w, img, nil)
// Add cases for other image formats if necessary
default:
return fmt.Errorf("unsupported image format: %s", dataURL.MediaType.ContentType())
}

return err
}

func SaveImageB64DataUrl(url, outputPath string) error {
file, err := os.Create(outputPath)
if err != nil {
return err
}
defer file.Close()

return ReadImageB64DataUrl(url, file)
}

func ReadAudioB64DataUrl(url string, w io.Writer) error {
dataURL, err := dataurl.DecodeString(url)
if err != nil {
return err
}

w.Write(dataURL.Data)

return nil
}
93 changes: 93 additions & 0 deletions ai/worker/b64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package worker

import (
"bytes"
"encoding/base64"
"image"
"image/color"
"image/png"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestReadImageB64DataUrl(t *testing.T) {
tests := []struct {
name string
dataURL string
expectError bool
}{
{
name: "Valid PNG Image",
dataURL: func() string {
img := image.NewRGBA(image.Rect(0, 0, 1, 1))
img.Set(0, 0, color.RGBA{255, 0, 0, 255}) // Set a single red pixel
var imgBuf bytes.Buffer
err := png.Encode(&imgBuf, img)
require.NoError(t, err)

return "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgBuf.Bytes())
}(),
expectError: false,
},
{
name: "Unsupported Image Format",
dataURL: "data:image/bmp;base64," + base64.StdEncoding.EncodeToString([]byte{
0x42, 0x4D, // BMP header
// ... (rest of the BMP data)
}),
expectError: true,
},
{
name: "Invalid Data URL",
dataURL: "invalid-data-url",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := ReadImageB64DataUrl(tt.dataURL, &buf)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NotEmpty(t, buf.Bytes())
}
})
}
}

func TestSaveImageB64DataUrl(t *testing.T) {
img := image.NewRGBA(image.Rect(0, 0, 1, 1))
img.Set(0, 0, color.RGBA{255, 0, 0, 255}) // Set a single red pixel
var imgBuf bytes.Buffer
err := png.Encode(&imgBuf, img)
require.NoError(t, err)
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgBuf.Bytes())

outputPath := "test_output.png"
defer os.Remove(outputPath)

err = SaveImageB64DataUrl(dataURL, outputPath)
require.NoError(t, err)

// Verify that the file was created and is not empty
fileInfo, err := os.Stat(outputPath)
require.NoError(t, err)
require.False(t, fileInfo.IsDir())
require.NotZero(t, fileInfo.Size())
}

func TestReadAudioB64DataUrl(t *testing.T) {
// Create a sample audio data and encode it as a data URL
audioData := []byte{0x00, 0x01, 0x02, 0x03, 0x04}
dataURL := "data:audio/wav;base64," + base64.StdEncoding.EncodeToString(audioData)

var buf bytes.Buffer
err := ReadAudioB64DataUrl(dataURL, &buf)
require.NoError(t, err)
require.Equal(t, audioData, buf.Bytes())
}
121 changes: 121 additions & 0 deletions ai/worker/container.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package worker

import (
"context"
"errors"
"log/slog"
"time"

"github.com/deepmap/oapi-codegen/v2/pkg/securityprovider"
)

type RunnerContainerType int

const (
Managed RunnerContainerType = iota
External
)

type RunnerContainer struct {
RunnerContainerConfig
Name string
Client *ClientWithResponses
Hardware *HardwareInformation
}

type RunnerEndpoint struct {
URL string
Token string
}

type RunnerContainerConfig struct {
Type RunnerContainerType
Pipeline string
ModelID string
Endpoint RunnerEndpoint
ContainerImageID string

// For managed containers only
ID string
GPU string
KeepWarm bool
containerTimeout time.Duration
}

// Create global references to functions to allow for mocking in tests.
var runnerWaitUntilReadyFunc = runnerWaitUntilReady

func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name string) (*RunnerContainer, error) {
// Ensure that timeout is set to a non-zero value.
timeout := cfg.containerTimeout
if timeout == 0 {
timeout = containerTimeout
}

var opts []ClientOption
if cfg.Endpoint.Token != "" {
bearerTokenProvider, err := securityprovider.NewSecurityProviderBearerToken(cfg.Endpoint.Token)
if err != nil {
return nil, err
}

opts = append(opts, WithRequestEditorFn(bearerTokenProvider.Intercept))
}

client, err := NewClientWithResponses(cfg.Endpoint.URL, opts...)
if err != nil {
return nil, err
}

cctx, cancel := context.WithTimeout(ctx, cfg.containerTimeout)
defer cancel()
if err := runnerWaitUntilReadyFunc(cctx, client, pollingInterval); err != nil {
return nil, err
}

var hardware *HardwareInformation
hctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
hdw, err := getRunnerHardware(hctx, client)
if err != nil {
hardware = &HardwareInformation{Pipeline: cfg.Pipeline, ModelId: cfg.ModelID, GpuInfo: nil}
} else {
hardware = hdw
}

return &RunnerContainer{
RunnerContainerConfig: cfg,
Name: name,
Client: client,
Hardware: hardware,
}, nil
}

func runnerWaitUntilReady(ctx context.Context, client *ClientWithResponses, pollingInterval time.Duration) error {
ticker := time.NewTicker(pollingInterval)
defer ticker.Stop()

tickerLoop:
for range ticker.C {
select {
case <-ctx.Done():
return errors.New("timed out waiting for runner")
default:
if _, err := client.HealthWithResponse(ctx); err == nil {
break tickerLoop
}
}
}

return nil
}

func getRunnerHardware(ctx context.Context, client *ClientWithResponses) (*HardwareInformation, error) {
resp, err := client.HardwareInfoWithResponse(ctx)
if err != nil {
slog.Error("Error getting hardware info for runner", slog.String("error", err.Error()))
return nil, err
}

return resp.JSON200, nil
}
20 changes: 20 additions & 0 deletions ai/worker/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
Package `worker` hosts the main AI worker logic for managing or using runner
containers for processing inference requests on the Livepeer AI subnet. The
package allows interacting with the [AI runner containers], and it includes:

- Golang API Bindings (./runner.gen.go):

Generated from the AI runner's OpenAPI spec. To re-generate them run: `make ai_worker_codegen`

- Worker (./worker.go):

Listens for inference requests from the Livepeer AI subnet and routes them to the AI runner.

- Docker Manager (./docker.go):

Manages AI runner containers.

[AI runner containers]: https://github.com/livepeer/ai-runner
*/
package worker
Loading
Loading