diff --git a/README.md b/README.md
index 37ae3bb..6bf28ad 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,10 @@ waifu2x.go is a clone of waifu2x-js.
waifu2x-js: https://github.com/takuyaa/waifu2x-js
+Changes
+---
+* 2022-02-09: Imported changes from [go-waifu2x](https://github.com/puhitaku/go-waifu2x), a fork of this repository. This is an excellent job done by @puhitaku and @orisano. It is 14x faster than the original in the non-GPU case.
+
Install
---
@@ -15,9 +19,30 @@ Install
go install github.com/ikawaha/waifu2x.go@latest
```
-Changes
+Usage
---
-* 2022-02-09: Imported changes from [go-waifu2x](https://github.com/puhitaku/go-waifu2x), a fork of this repository. This is an excellent job done by @puhitaku and @orisano. It is 14x faster than the original in the non-GPU case.
+
+```shell
+$ waifu2x.go --help
+Usage of waifu2x:
+ -i string
+ input file (default stdin)
+ -m string
+ waifu2x mode, choose from 'anime' and 'photo' (default "anime")
+ -n int
+ noise reduction level 0 <= n <= 3
+ -o string
+ output file (default stdout)
+ -p int
+ concurrency (default 8)
+ -s float
+ scale multiplier >= 1.0 (default 2)
+ -v verbose
+```
+
+
+
+The Go gopher was designed by [Renée French](https://reneefrench.blogspot.com/).
Note
---
diff --git a/cmd/cmd.go b/cmd/cmd.go
index da7fab6..54aeea5 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -1,6 +1,7 @@
package cmd
import (
+ "bytes"
"context"
"flag"
"fmt"
@@ -10,50 +11,45 @@ import (
"io"
"os"
"runtime"
- "strings"
"github.com/ikawaha/waifu2x.go/engine"
)
-const (
- commandName = "waifu2x"
- usageMessage = "%s (-i|--input) [-o|--output ] [-s|--scale ] [-j|--jobs ] [-n|--noise ] [-m|--mode (anime|photo)]\n"
-)
-
const (
modeAnime = "anime"
modePhoto = "photo"
)
type option struct {
- input string
- output string
- scale float64
- jobs int
- noiseReduction int
- mode string
- flagSet *flag.FlagSet
+ // flagSet args
+ input string
+ output string
+ scale float64
+ noise int
+ parallel int
+ modeStr string
+ verbose bool
+
+ // option values
+ mode engine.Mode
+ flagSet *flag.FlagSet
}
+const commandName = `waifu2x`
+
func newOption(w io.Writer, eh flag.ErrorHandling) (o *option) {
o = &option{
flagSet: flag.NewFlagSet(commandName, eh),
}
// option settings
o.flagSet.SetOutput(w)
- o.flagSet.StringVar(&o.input, "i", "", "input file (short)")
- o.flagSet.StringVar(&o.input, "input", "", "input file")
- o.flagSet.StringVar(&o.output, "o", "", "output file (short) (default stdout)")
- o.flagSet.StringVar(&o.output, "output", "", "output file (default stdout)")
- o.flagSet.Float64Var(&o.scale, "s", 2.0, "scale multiplier >= 1.0 (short)")
- o.flagSet.Float64Var(&o.scale, "scale", 2.0, "scale multiplier >= 1.0")
- o.flagSet.IntVar(&o.jobs, "j", runtime.NumCPU(), "# of goroutines (short)")
- o.flagSet.IntVar(&o.jobs, "jobs", runtime.NumCPU(), "# of goroutines")
- o.flagSet.IntVar(&o.noiseReduction, "n", 0, "noise reduction level 0 <= n <= 3 (short)")
- o.flagSet.IntVar(&o.noiseReduction, "noise", 0, "noise reduction level 0 <= n <= 3")
- o.flagSet.StringVar(&o.mode, "m", modeAnime, "waifu2x mode, choose from 'anime' and 'photo' (short) (default anime)")
- o.flagSet.StringVar(&o.mode, "mode", modeAnime, "waifu2x mode, choose from 'anime' and 'photo' (default anime)")
-
+ o.flagSet.StringVar(&o.input, "i", "", "input file (default stdin)")
+ o.flagSet.StringVar(&o.output, "o", "", "output file (default stdout)")
+ o.flagSet.Float64Var(&o.scale, "s", 2.0, "scale multiplier >= 1.0")
+ o.flagSet.IntVar(&o.noise, "n", 0, "noise reduction level 0 <= n <= 3")
+ o.flagSet.IntVar(&o.parallel, "p", runtime.GOMAXPROCS(runtime.NumCPU()), "concurrency")
+ o.flagSet.StringVar(&o.modeStr, "m", modeAnime, "waifu2x mode, choose from 'anime' and 'photo'")
+ o.flagSet.BoolVar(&o.verbose, "v", false, "verbose")
return
}
@@ -65,72 +61,79 @@ func (o *option) parse(args []string) error {
if nonFlag := o.flagSet.Args(); len(nonFlag) != 0 {
return fmt.Errorf("invalid argument: %v", nonFlag)
}
- if o.input == "" {
- return fmt.Errorf("input file is empty")
- }
if o.scale < 1.0 {
return fmt.Errorf("invalid scale, %v > 1", o.scale)
}
- if o.jobs < 1 {
- return fmt.Errorf("invalid number of jobs, %v < 1", o.jobs)
+ if o.noise < 0 || o.noise > 3 {
+ return fmt.Errorf("invalid number of noise reduction level, it must be [0,3]")
}
- if o.noiseReduction < 0 || o.noiseReduction > 3 {
- return fmt.Errorf("invalid number of noise reduction level, it must be 0 - 3")
+ if o.parallel < 1 {
+ return fmt.Errorf("invalid number of parallel, it must be >= 1")
}
- if o.mode != modeAnime && o.mode != modePhoto {
+ switch o.modeStr {
+ case modeAnime:
+ o.mode = engine.Anime
+ case modePhoto:
+ o.mode = engine.Photo
+ default:
return fmt.Errorf("invalid mode, choose from 'anime' or 'photo'")
}
return nil
}
-// Usage shows a usage message.
-func Usage() {
- fmt.Printf(usageMessage, commandName)
- opt := newOption(os.Stdout, flag.ContinueOnError)
- opt.flagSet.PrintDefaults()
+func parseInputImage(file string) (image.Image, error) {
+ var b []byte
+ in := os.Stdin
+ if file != "" {
+ var err error
+ b, err = os.ReadFile(file)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ var err error
+ b, err = io.ReadAll(in)
+ if err != nil {
+ return nil, err
+ }
+ }
+ _, format, err := image.DecodeConfig(bytes.NewReader(b))
+ if err != nil {
+ return nil, err
+ }
+ var decoder func(io.Reader) (image.Image, error)
+ switch format {
+ case "jpeg":
+ decoder = jpeg.Decode
+ case "png":
+ decoder = png.Decode
+ default:
+ return nil, fmt.Errorf("unsupported image type: %s", format)
+ }
+ return decoder(bytes.NewReader(b))
}
// Run executes the waifu2x command.
func Run(args []string) error {
- opt := newOption(os.Stderr, flag.ContinueOnError)
+ opt := newOption(os.Stderr, flag.ExitOnError)
if err := opt.parse(args); err != nil {
return err
}
-
- fp, err := os.Open(opt.input)
+ img, err := parseInputImage(opt.input)
if err != nil {
- return fmt.Errorf("input file %v, %w", opt.input, err)
+ return fmt.Errorf("input error: %w", err)
}
- defer fp.Close()
- var img image.Image
- if strings.HasSuffix(fp.Name(), "jpg") || strings.HasSuffix(fp.Name(), "jpeg") {
- img, err = jpeg.Decode(fp)
- if err != nil {
- return fmt.Errorf("load file %v, %w", opt.input, err)
- }
- } else if strings.HasSuffix(fp.Name(), "png") {
- img, err = png.Decode(fp)
- if err != nil {
- return fmt.Errorf("load file %v, %w", opt.input, err)
- }
- }
-
- mode := engine.Anime
- switch opt.mode {
- case "anime":
- mode = engine.Anime
- case "photo":
- mode = engine.Photo
- }
- w2x, err := engine.NewWaifu2x(mode, opt.noiseReduction, engine.Parallel(8), engine.Verbose())
+ w2x, err := engine.NewWaifu2x(opt.mode, opt.noise, []engine.Option{
+ engine.Verbose(opt.verbose),
+ engine.Parallel(opt.parallel),
+ }...)
if err != nil {
return err
}
-
rgba, err := w2x.ScaleUp(context.TODO(), img, opt.scale)
if err != nil {
- return fmt.Errorf("calc error: %w", err)
+ return err
}
var w io.Writer = os.Stdout
@@ -143,7 +146,7 @@ func Run(args []string) error {
w = fp
}
if err := png.Encode(w, &rgba); err != nil {
- panic(err)
+ return fmt.Errorf("output error: %w", err)
}
return nil
}
diff --git a/engine/channel_image.go b/engine/channel_image.go
index f95e06f..8a51937 100644
--- a/engine/channel_image.go
+++ b/engine/channel_image.go
@@ -18,7 +18,7 @@ func NewChannelImageWidthHeight(width, height int) ChannelImage {
return ChannelImage{
Width: width,
Height: height,
- Buffer: make([]uint8, width*height), // XXX 0以下を0, 255以上を255 として登録する必要あり
+ Buffer: make([]uint8, width*height), // note. it is necessary to register all values less than 0 as 0 and greater than 255 as 255
}
}
diff --git a/engine/model.go b/engine/model.go
index 68d4eec..24325d9 100644
--- a/engine/model.go
+++ b/engine/model.go
@@ -16,6 +16,7 @@ type Param struct {
Weight [][][][]float64 `json:"weight"` // 重み
NInputPlane int `json:"nInputPlane"` // 入力平面数
NOutputPlane int `json:"nOutputPlane"` // 出力平面数
+ WeightVec []float64
}
// Model represents a trained model.
@@ -38,6 +39,7 @@ func LoadModel(r io.Reader) (Model, error) {
if err := dec.Decode(&m); err != nil {
return nil, err
}
+ m.setWeightVec()
return m, nil
}
@@ -122,3 +124,29 @@ func NewAssetModelSet(t Mode, noiseLevel int) (*ModelSet, error) {
NoiseModel: noise,
}, nil
}
+
+func (m Model) setWeightVec() {
+ for l := range m {
+ param := m[l]
+ // [nOutputPlane][nInputPlane][3][3]
+ const square = 9
+ vec := make([]float64, param.NInputPlane*param.NOutputPlane*9)
+ for i := 0; i < param.NInputPlane; i++ {
+ for o := 0; o < param.NOutputPlane; o++ {
+ offset := i*param.NOutputPlane*square + o*square
+ vec[offset+0] = param.Weight[o][i][0][0]
+ vec[offset+1] = param.Weight[o][i][0][1]
+ vec[offset+2] = param.Weight[o][i][0][2]
+
+ vec[offset+3] = param.Weight[o][i][1][0]
+ vec[offset+4] = param.Weight[o][i][1][1]
+ vec[offset+5] = param.Weight[o][i][1][2]
+
+ vec[offset+6] = param.Weight[o][i][2][0]
+ vec[offset+7] = param.Weight[o][i][2][1]
+ vec[offset+8] = param.Weight[o][i][2][2]
+ }
+ }
+ m[l].WeightVec = vec
+ }
+}
diff --git a/engine/model_test.go b/engine/model_test.go
index 589fb0e..e50b90f 100644
--- a/engine/model_test.go
+++ b/engine/model_test.go
@@ -26,3 +26,39 @@ func TestLoadModel(t *testing.T) {
}
}
}
+
+func Test_setWeightVec(t *testing.T) {
+ model, err := LoadModelFile("./model/anime_style_art/scale2.0x_model.json")
+ if err != nil {
+ t.Fatalf("unexpected error, %v", err)
+ }
+ matrix := typeW(model)
+ model.setWeightVec()
+ for i, param := range model {
+ for j, v := range param.WeightVec {
+ if matrix[i][j] != v {
+ t.Fatalf("[%d, %d]=%v <> %v", i, j, matrix[i][j], v)
+ }
+ }
+ }
+}
+
+// W[][O*I*9]
+func typeW(model Model) [][]float64 {
+ var W [][]float64
+ for l := range model {
+ // initialize weight matrix
+ param := model[l]
+ var vec []float64
+ // [nOutputPlane][nInputPlane][3][3]
+ for i := 0; i < param.NInputPlane; i++ {
+ for o := 0; o < param.NOutputPlane; o++ {
+ vec = append(vec, param.Weight[o][i][0]...)
+ vec = append(vec, param.Weight[o][i][1]...)
+ vec = append(vec, param.Weight[o][i][2]...)
+ }
+ }
+ W = append(W, vec)
+ }
+ return W
+}
diff --git a/engine/waifu2x.go b/engine/waifu2x.go
index 24b7c42..c57d7d6 100644
--- a/engine/waifu2x.go
+++ b/engine/waifu2x.go
@@ -7,34 +7,36 @@ import (
"io"
"math"
"os"
+ "runtime"
"sync"
-
- "golang.org/x/sync/semaphore"
)
// Option represents an option of waifu2x.
type Option func(w *Waifu2x) error
-// Parallel is the option that specifies the number of concurrency.
+// Parallel sets the option that specifies the limit number of concurrency.
func Parallel(p int) Option {
return func(w *Waifu2x) error {
+ if p < 0 {
+ return fmt.Errorf("an integer value less than 1")
+ }
w.parallel = p
return nil
}
}
-// Verbose is the verbose option.
-func Verbose() Option {
+// Verbose sets the verbose option.
+func Verbose(v bool) Option {
return func(w *Waifu2x) error {
- w.verbose = true
+ w.verbose = v
return nil
}
}
-// Output is the option that sets the output destination.
-func Output(w io.Writer) Option {
+// LogOutput sets the log output destination.
+func LogOutput(w io.Writer) Option {
return func(w2x *Waifu2x) error {
- w2x.output = w
+ w2x.logOutput = w
return nil
}
}
@@ -45,7 +47,7 @@ type Waifu2x struct {
noiseModel Model
parallel int
verbose bool
- output io.Writer
+ logOutput io.Writer
}
// NewWaifu2x creates a Waifu2x structure.
@@ -57,8 +59,8 @@ func NewWaifu2x(mode Mode, noise int, opts ...Option) (*Waifu2x, error) {
ret := &Waifu2x{
scaleModel: m.Scale2xModel,
noiseModel: m.NoiseModel,
- parallel: 1,
- output: os.Stderr,
+ logOutput: os.Stderr,
+ parallel: runtime.GOMAXPROCS(runtime.NumCPU()),
verbose: false,
}
for _, opt := range opts {
@@ -71,32 +73,47 @@ func NewWaifu2x(mode Mode, noise int, opts ...Option) (*Waifu2x, error) {
func (w Waifu2x) printf(format string, a ...interface{}) {
if w.verbose {
- fmt.Fprintf(w.output, format, a...)
+ fmt.Fprintf(w.logOutput, format, a...)
}
}
func (w Waifu2x) println(a ...interface{}) {
if w.verbose {
- fmt.Fprintln(w.output, a...)
+ fmt.Fprintln(w.logOutput, a...)
}
}
// ScaleUp scales up the image.
func (w Waifu2x) ScaleUp(ctx context.Context, img image.Image, scale float64) (image.RGBA, error) {
- ci, opaque, err := NewChannelImage(img)
+ ci, _, err := NewChannelImage(img)
if err != nil {
return image.RGBA{}, err
}
- ci, err = w.convertChannelImage(ctx, ci, opaque, scale)
+ for {
+ if scale < 2.0 {
+ ci, err = w.convertChannelImage(ctx, ci, scale)
+ if err != nil {
+ return image.RGBA{}, err
+ }
+ break
+ }
+ ci, err = w.convertChannelImage(ctx, ci, 2)
+ if err != nil {
+ return image.RGBA{}, err
+ }
+ scale = scale / 2.0
+ }
return ci.ImageRGBA(), err
}
-func (w Waifu2x) convertChannelImage(ctx context.Context, img ChannelImage, opaque bool, scale float64) (ChannelImage, error) {
+func (w Waifu2x) convertChannelImage(ctx context.Context, img ChannelImage, scale float64) (ChannelImage, error) {
if (w.scaleModel == nil && w.noiseModel == nil) || scale <= 1 {
return img, nil
}
- w.printf("# of goroutines: %d\n", w.parallel)
+ if w.parallel > 0 {
+ w.printf("# of goroutines: %d\n", w.parallel)
+ }
// decompose
w.println("decomposing channels ...")
@@ -106,7 +123,7 @@ func (w Waifu2x) convertChannelImage(ctx context.Context, img ChannelImage, opaq
if w.noiseModel != nil {
w.println("de-noising ...")
var err error
- r, g, b, err = w.convertRGB(ctx, r, g, b, w.noiseModel, 1, w.parallel)
+ r, g, b, err = w.convertRGB(ctx, r, g, b, w.noiseModel, 1)
if err != nil {
return ChannelImage{}, err
}
@@ -116,23 +133,26 @@ func (w Waifu2x) convertChannelImage(ctx context.Context, img ChannelImage, opaq
if w.scaleModel != nil {
w.println("scaling ...")
var err error
- r, g, b, err = w.convertRGB(ctx, r, g, b, w.scaleModel, scale, w.parallel)
+ r, g, b, err = w.convertRGB(ctx, r, g, b, w.scaleModel, scale)
if err != nil {
return ChannelImage{}, err
}
}
// alpha channel
- if !opaque {
- a = a.Resize(scale) // Resize simply
- } else if w.scaleModel != nil { // upscale the alpha channel
- w.println("scaling alpha ...")
- var err error
- a, _, _, err = w.convertRGB(ctx, a, a, a, w.scaleModel, scale, w.parallel)
- if err != nil {
- return ChannelImage{}, err
+ a = a.Resize(scale)
+ /*
+ if !opaque {
+ a = a.Resize(scale) // Resize simply
+ } else if w.scaleModel != nil { // upscale the alpha channel
+ w.println("scaling alpha ...")
+ var err error
+ a, _, _, err = w.convertRGB(ctx, a, a, a, w.scaleModel, scale)
+ if err != nil {
+ return ChannelImage{}, err
+ }
}
- }
+ */
if len(a.Buffer) != len(r.Buffer) {
return ChannelImage{}, fmt.Errorf("channel image size must be same, A=%d, R=%d", len(a.Buffer), len(r.Buffer))
@@ -143,7 +163,7 @@ func (w Waifu2x) convertChannelImage(ctx context.Context, img ChannelImage, opaq
return ChannelCompose(r, g, b, a), nil
}
-func (w Waifu2x) convertRGB(ctx context.Context, imageR, imageG, imageB ChannelImage, model Model, scale float64, jobs int) (r, g, b ChannelImage, err error) {
+func (w Waifu2x) convertRGB(_ context.Context, imageR, imageG, imageB ChannelImage, model Model, scale float64) (r, g, b ChannelImage, err error) {
var inputPlanes [3]ImagePlane
for i, img := range []ChannelImage{imageR, imageG, imageB} {
imgResized := img.Resize(scale)
@@ -159,55 +179,37 @@ func (w Waifu2x) convertRGB(ctx context.Context, imageR, imageG, imageB ChannelI
inputBlocks, blocksW, blocksH := Blocking(inputPlanes)
// init W
- W := typeW(model)
-
- inputLock := &sync.Mutex{}
- outputLock := &sync.Mutex{}
- sem := semaphore.NewWeighted(int64(jobs))
- wg := sync.WaitGroup{}
outputBlocks := make([][]ImagePlane, len(inputBlocks))
digits := int(math.Log10(float64(len(inputBlocks)))) + 2
fmtStr := fmt.Sprintf("%%%dd/%%%dd", digits, digits) + " (%.1f%%)"
-
w.printf(fmtStr, 0, len(inputBlocks), 0.0)
- for b := range inputBlocks {
- err := sem.Acquire(ctx, 1)
- if err != nil {
- panic(fmt.Sprintf("failed to acquire the semaphore: %s", err))
- }
+ limit := make(chan struct{}, w.parallel)
+ wg := sync.WaitGroup{}
+ for i := range inputBlocks {
wg.Add(1)
-
- go func(cb int) {
- if cb >= 10 {
- w.printf("\x1b[2K\r"+fmtStr, cb+1, len(inputBlocks), float32(cb+1)/float32(len(inputBlocks))*100)
+ go func(i int) {
+ limit <- struct{}{}
+ defer wg.Done()
+ if i >= 10 {
+ w.printf("\x1b[2K\r"+fmtStr, i+1, len(inputBlocks), float32(i+1)/float32(len(inputBlocks))*100)
}
-
- inputBlock := inputBlocks[cb]
+ inputBlock := inputBlocks[i]
var outputBlock []ImagePlane
- for l := 0; l < len(model); l++ {
+ for l := range model {
nOutputPlane := model[l].NOutputPlane
// convolution
- if model == nil {
- panic("xxx model nil")
- }
- outputBlock = convolution(inputBlock, W[l], nOutputPlane, model[l].Bias)
+ outputBlock = convolution(inputBlock, model[l].WeightVec, nOutputPlane, model[l].Bias)
inputBlock = outputBlock // propagate output plane to next layer input
-
- inputLock.Lock()
- inputBlocks[cb] = nil
- inputLock.Unlock()
+ inputBlocks[i] = nil
}
- outputLock.Lock()
- outputBlocks[cb] = outputBlock
- outputLock.Unlock()
- sem.Release(1)
- wg.Done()
- }(b)
+ outputBlocks[i] = outputBlock
+ <-limit
+ }(i)
}
-
wg.Wait()
+
w.println()
inputBlocks = nil
@@ -219,26 +221,6 @@ func (w Waifu2x) convertRGB(ctx context.Context, imageR, imageG, imageB ChannelI
return R, G, B, nil
}
-// W[][O*I*9]
-func typeW(model Model) [][]float64 {
- var W [][]float64
- for l := 0; l < len(model); l++ {
- // initialize weight matrix
- param := model[l]
- var vec []float64
- // [nOutputPlane][nInputPlane][3][3]
- for i := 0; i < param.NInputPlane; i++ {
- for o := 0; o < param.NOutputPlane; o++ {
- vec = append(vec, param.Weight[o][i][0]...)
- vec = append(vec, param.Weight[o][i][1]...)
- vec = append(vec, param.Weight[o][i][2]...)
- }
- }
- W = append(W, vec)
- }
- return W
-}
-
func convolution(inputPlanes []ImagePlane, W []float64, nOutputPlane int, bias []float64) []ImagePlane {
if len(inputPlanes) == 0 {
return nil
@@ -256,18 +238,20 @@ func convolution(inputPlanes []ImagePlane, W []float64, nOutputPlane int, bias [
}
for y := 1; y < height-1; y++ {
for x := 1; x < width-1; x++ {
- for i := 0; i < len(biasValues); i++ {
+ for i := range biasValues {
sumValues[i] = biasValues[i]
}
+ const square = 9
wi := 0
for i := range inputPlanes {
- i00, i10, i20, i01, i11, i21, i02, i12, i22 := inputPlanes[i].SegmentAt(x, y)
+ a0, a1, a2, b0, b1, b2, c0, c1, c2 := inputPlanes[i].SegmentAt(x, y)
for o := 0; o < nOutputPlane; o++ {
- ws := W[wi : wi+9]
- sumValues[o] += ws[0]*i00 + ws[1]*i10 + ws[2]*i20 +
- ws[3]*i01 + ws[4]*i11 + ws[5]*i21 +
- ws[6]*i02 + ws[7]*i12 + ws[8]*i22
- wi += 9
+ ws := W[wi : wi+square] // 3x3 square
+ sumValues[o] = sumValues[o] +
+ ws[0]*a0 + ws[1]*a1 + ws[2]*a2 +
+ ws[3]*b0 + ws[4]*b1 + ws[5]*b2 +
+ ws[6]*c0 + ws[7]*c1 + ws[8]*c2
+ wi += square
}
}
for o := 0; o < nOutputPlane; o++ {
diff --git a/engine/waifu2x_test.go b/engine/waifu2x_test.go
index c571a78..234a45a 100644
--- a/engine/waifu2x_test.go
+++ b/engine/waifu2x_test.go
@@ -5,35 +5,70 @@ import (
"fmt"
"image"
"image/png"
+ "math"
"os"
"runtime"
"testing"
)
-const (
- modeAnime = "anime"
- modePhoto = "photo"
-)
+func TestWaifu2x_ScaleUp(t *testing.T) {
+ w2x, err := NewWaifu2x(Anime, 0)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ fp, err := os.Open("../testdata/neko_small.png")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ img, err := png.Decode(fp)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ testdata := []struct {
+ name string
+ scale float64
+ }{
+ {name: "scale up x1.0", scale: 1.0},
+ {name: "scale up x1.7", scale: 1.7},
+ {name: "scale up x2.0", scale: 2.0},
+ {name: "scale up x3.3", scale: 3.3},
+ {name: "scale up x4.0", scale: 4.0},
+ }
+ for _, tt := range testdata {
+ t.Run(tt.name, func(t *testing.T) {
+ imgX, err := w2x.ScaleUp(context.TODO(), img, tt.scale)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if want, got := int(math.Round(float64(img.Bounds().Max.X)*tt.scale)), imgX.Bounds().Max.X; want != got {
+ t.Errorf("want %d, got %d", want, got)
+ }
+ if want, got := int(math.Round(float64(img.Bounds().Max.Y)*tt.scale)), imgX.Bounds().Max.Y; want != got {
+ t.Errorf("want %d, got %d", want, got)
+ }
+ })
+ }
+}
func BenchmarkWaifu(b *testing.B) {
tests := []struct {
name string
pic string
- mode string
+ mode Mode
noise int
alpha bool
}{
{
name: "Neko",
pic: "../testdata/neko_small.png",
- mode: modeAnime,
+ mode: Anime,
noise: 0,
alpha: false,
},
{
name: "Neko-alpha",
pic: "../testdata/neko_alpha.png",
- mode: modeAnime,
+ mode: Anime,
noise: 0,
alpha: true,
},
@@ -59,9 +94,9 @@ func BenchmarkWaifu(b *testing.B) {
var noiseFn string
switch tt.mode {
- case modeAnime:
+ case Anime:
modelDir = "anime_style_art_rgb"
- case modePhoto:
+ case Photo:
modelDir = "photo"
}
@@ -96,7 +131,7 @@ func BenchmarkWaifu(b *testing.B) {
Height: rgba.Bounds().Max.Y,
Buffer: rgba.Pix,
}
- if _, err := w2x.convertChannelImage(context.TODO(), img, tt.alpha, 2); err != nil {
+ if _, err := w2x.convertChannelImage(context.TODO(), img, 2); err != nil {
b.Errorf("unexpected error: %v", err)
}
})
@@ -106,47 +141,47 @@ func BenchmarkWaifu(b *testing.B) {
func TestAllCombinations(t *testing.T) {
tests := []struct {
name string
- mode string
+ mode Mode
noise int
}{
{
name: "Anime, noiseModel reduction level 0",
- mode: modeAnime,
+ mode: Anime,
noise: 0,
},
{
name: "Anime, noiseModel reduction level 1",
- mode: modeAnime,
+ mode: Anime,
noise: 1,
},
{
name: "Anime, noiseModel reduction level 2",
- mode: modeAnime,
+ mode: Anime,
noise: 2,
},
{
name: "Anime, noiseModel reduction level 3",
- mode: modeAnime,
+ mode: Anime,
noise: 3,
},
{
name: "Photo, noiseModel reduction level 0",
- mode: modePhoto,
+ mode: Photo,
noise: 0,
},
{
name: "Photo, noiseModel reduction level 1",
- mode: modePhoto,
+ mode: Photo,
noise: 1,
},
{
name: "Photo, noiseModel reduction level 2",
- mode: modePhoto,
+ mode: Photo,
noise: 2,
},
{
name: "Photo, noiseModel reduction level 3",
- mode: modePhoto,
+ mode: Photo,
noise: 3,
},
}
@@ -172,9 +207,9 @@ func TestAllCombinations(t *testing.T) {
var noiseFn string
switch tt.mode {
- case modeAnime:
+ case Anime:
modelDir = "anime_style_art_rgb"
- case modePhoto:
+ case Photo:
modelDir = "photo"
}
@@ -206,7 +241,7 @@ func TestAllCombinations(t *testing.T) {
Width: rgba.Bounds().Max.X,
Height: rgba.Bounds().Max.Y,
}
- if _, err := w2x.convertChannelImage(context.TODO(), img, true, 2); err != nil {
+ if _, err := w2x.convertChannelImage(context.TODO(), img, 2); err != nil {
t.Errorf("unexpected error: %v", err)
}
})
diff --git a/go.mod b/go.mod
index a03a241..0cedb64 100644
--- a/go.mod
+++ b/go.mod
@@ -1,5 +1,3 @@
module github.com/ikawaha/waifu2x.go
go 1.17
-
-require golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
diff --git a/go.sum b/go.sum
index 5c00efd..e69de29 100644
--- a/go.sum
+++ b/go.sum
@@ -1,2 +0,0 @@
-golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
-golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/main.go b/main.go
index c551e7a..060051a 100644
--- a/main.go
+++ b/main.go
@@ -10,7 +10,6 @@ import (
func main() {
if err := cmd.Run(os.Args[1:]); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
- cmd.Usage()
os.Exit(1)
}
os.Exit(0)