Skip to content

Commit

Permalink
Support GIF images (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
ikawaha authored Mar 3, 2022
1 parent 1122062 commit 8599988
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 30 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/ikawaha/waifu2x.go.svg)](https://pkg.go.dev/github.com/ikawaha/waifu2x.go)
waifu2x.go
===

Expand Down
70 changes: 48 additions & 22 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"flag"
"fmt"
"image"
"image/gif"
"image/jpeg"
"image/png"
"io"
Expand Down Expand Up @@ -81,26 +82,25 @@ func (o *option) parse(args []string) error {
return nil
}

func parseInputImage(file string) (image.Image, error) {
var b []byte
in := os.Stdin
func parseInputImage(file string) ([]byte, string, error) {
r := os.Stdin
if file != "" {
var err error
b, err = os.ReadFile(file)
r, err = os.Open(file)
if err != nil {
return nil, err
}
} else {
var err error
b, err = io.ReadAll(in)
if err != nil {
return nil, err
return nil, "", err
}
defer r.Close()
}
_, format, err := image.DecodeConfig(bytes.NewReader(b))
b, err := io.ReadAll(r)
if err != nil {
return nil, err
return nil, "", err
}
_, format, err := image.DecodeConfig(bytes.NewReader(b))
return b, format, err
}

func decodeImage(b []byte, format string) (image.Image, error) {
var decoder func(io.Reader) (image.Image, error)
switch format {
case "jpeg":
Expand All @@ -113,13 +113,36 @@ func parseInputImage(file string) (image.Image, error) {
return decoder(bytes.NewReader(b))
}

func scaleUp(ctx context.Context, w2x *engine.Waifu2x, img image.Image, scale float64, w io.Writer) error {
ci, err := w2x.ScaleUp(ctx, img, scale)
if err != nil {
return err
}
rgba := ci.ImageRGBA()
if err := png.Encode(w, &rgba); err != nil {
return fmt.Errorf("output error: %w", err)
}
return nil
}

func scaleUpGIF(ctx context.Context, w2x *engine.Waifu2x, img *gif.GIF, scale float64, w io.Writer) error {
g, err := w2x.ScaleUpGIF(ctx, img, scale)
if err != nil {
return err
}
if err := gif.EncodeAll(w, g); err != nil {
return fmt.Errorf("output error: %w", err)
}
return nil
}

// Run executes the waifu2x command.
func Run(args []string) error {
opt := newOption(os.Stderr, flag.ExitOnError)
if err := opt.parse(args); err != nil {
return err
}
img, err := parseInputImage(opt.input)
b, format, err := parseInputImage(opt.input)
if err != nil {
return fmt.Errorf("input error: %w", err)
}
Expand All @@ -131,11 +154,6 @@ func Run(args []string) error {
if err != nil {
return err
}
rgba, err := w2x.ScaleUp(context.TODO(), img, opt.scale)
if err != nil {
return err
}

var w io.Writer = os.Stdout
if opt.output != "" {
fp, err := os.Create(opt.output)
Expand All @@ -145,8 +163,16 @@ func Run(args []string) error {
defer fp.Close()
w = fp
}
if err := png.Encode(w, &rgba); err != nil {
return fmt.Errorf("output error: %w", err)
if format != "gif" {
img, err := decodeImage(b, format)
if err != nil {
return err
}
return scaleUp(context.TODO(), w2x, img, opt.scale, w)
}
return nil
img, err := gif.DecodeAll(bytes.NewReader(b))
if err != nil {
return err
}
return scaleUpGIF(context.TODO(), w2x, img, opt.scale, w)
}
19 changes: 19 additions & 0 deletions engine/channel_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package engine
import (
"fmt"
"image"
"image/color"
"image/draw"
"math"
)

Expand Down Expand Up @@ -44,6 +46,15 @@ func NewChannelImage(img image.Image) (ChannelImage, bool, error) {
}
}
opaque = t.Opaque()
case *image.Paletted:
r := t.Rect
for y := 0; y < r.Dy(); y++ {
for x := 0; x < r.Dx(); x++ {
R, G, B, A := t.At(x, y).RGBA()
b = append(b, uint8(R>>8), uint8(G>>8), uint8(B>>8), uint8(A>>8))
}
}
opaque = t.Opaque()
default:
return ChannelImage{}, false, fmt.Errorf("unknown image format: %T", t)
}
Expand Down Expand Up @@ -79,6 +90,14 @@ func (c ChannelImage) ImageRGBA() image.RGBA {
}
}

// ImagePaletted converts the chanel image to an image.Paletted and return it.
func (c ChannelImage) ImagePaletted(p color.Palette) *image.Paletted {
rgba := c.ImageRGBA()
ret := image.NewPaletted(rgba.Bounds(), p)
draw.DrawMask(ret, image.Rect(0, 0, ret.Bounds().Max.X, ret.Bounds().Max.Y), &rgba, image.Point{}, nil, image.Point{}, draw.Src)
return ret
}

// ChannelDecompose decomposes a channel image to R, G, B and Alpha channels.
func ChannelDecompose(img ChannelImage) (r, g, b, a ChannelImage) {
r = NewChannelImageWidthHeight(img.Width, img.Height)
Expand Down
29 changes: 24 additions & 5 deletions engine/waifu2x.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"image"
"image/gif"
"io"
"math"
"os"
Expand Down Expand Up @@ -83,27 +84,45 @@ func (w Waifu2x) println(a ...interface{}) {
}
}

// ScaleUpGIF scales up the GIF image.
func (w Waifu2x) ScaleUpGIF(ctx context.Context, img *gif.GIF, scale float64) (*gif.GIF, error) {
frames := make([]*image.Paletted, 0, len(img.Image))
for _, v := range img.Image {
p := v.Palette
ci, err := w.ScaleUp(ctx, v, scale)
if err != nil {
return nil, err
}
ip := ci.ImagePaletted(p)
frames = append(frames, ip)
}
img.Image = frames
img.Config.Width = int(float64(img.Config.Width) * scale)
img.Config.Height = int(float64(img.Config.Height) * scale)
return img, nil
}

// ScaleUp scales up the image.
func (w Waifu2x) ScaleUp(ctx context.Context, img image.Image, scale float64) (image.RGBA, error) {
func (w Waifu2x) ScaleUp(ctx context.Context, img image.Image, scale float64) (ChannelImage, error) {
ci, _, err := NewChannelImage(img)
if err != nil {
return image.RGBA{}, err
return ChannelImage{}, err
}
for {
if scale < 2.0 {
ci, err = w.convertChannelImage(ctx, ci, scale)
if err != nil {
return image.RGBA{}, err
return ChannelImage{}, err
}
break
}
ci, err = w.convertChannelImage(ctx, ci, 2)
if err != nil {
return image.RGBA{}, err
return ChannelImage{}, err
}
scale = scale / 2.0
}
return ci.ImageRGBA(), err
return ci, err
}

func (w Waifu2x) convertChannelImage(ctx context.Context, img ChannelImage, scale float64) (ChannelImage, error) {
Expand Down
6 changes: 3 additions & 3 deletions engine/waifu2x_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ func TestWaifu2x_ScaleUp(t *testing.T) {
}
for _, tt := range testdata {
t.Run(tt.name, func(t *testing.T) {
imgX, err := w2x.ScaleUp(context.TODO(), img, tt.scale)
got, err := w2x.ScaleUp(context.TODO(), img, tt.scale)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if want, got := int(math.Round(float64(img.Bounds().Max.X)*tt.scale)), imgX.Bounds().Max.X; want != got {
if want, got := int(math.Round(float64(img.Bounds().Max.X)*tt.scale)), got.Width; want != got {
t.Errorf("want %d, got %d", want, got)
}
if want, got := int(math.Round(float64(img.Bounds().Max.Y)*tt.scale)), imgX.Bounds().Max.Y; want != got {
if want, got := int(math.Round(float64(img.Bounds().Max.Y)*tt.scale)), got.Height; want != got {
t.Errorf("want %d, got %d", want, got)
}
})
Expand Down

0 comments on commit 8599988

Please sign in to comment.