diff --git a/internal/classifier/classifier_test.go b/internal/classifier/classifier_test.go index 8675d575..c5e580cc 100644 --- a/internal/classifier/classifier_test.go +++ b/internal/classifier/classifier_test.go @@ -3,6 +3,8 @@ package classifier import ( "context" "fmt" + "testing" + "github.com/bitmagnet-io/bitmagnet/internal/classifier/classification" classifier_mocks "github.com/bitmagnet-io/bitmagnet/internal/classifier/mocks" "github.com/bitmagnet-io/bitmagnet/internal/model" @@ -10,7 +12,6 @@ import ( tmdb_mocks "github.com/bitmagnet-io/bitmagnet/internal/tmdb/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "testing" ) func TestClassifier(t *testing.T) { @@ -195,7 +196,7 @@ func TestClassifier(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("torrent: %s", tc.torrent.Name), func(t *testing.T) { mocks := newTestClassifierMocks(t) - source, sourceErr := yamlSourceProvider{rawSourceProvider: coreSourceProvider{}}.source() + source, sourceErr := coreSourceProvider{}.provider().source() if sourceErr != nil { t.Fatal(sourceErr) return diff --git a/internal/classifier/json_schema_test.go b/internal/classifier/json_schema_test.go index 92d09fdf..668ecc55 100644 --- a/internal/classifier/json_schema_test.go +++ b/internal/classifier/json_schema_test.go @@ -3,9 +3,10 @@ package classifier import ( _ "embed" "encoding/json" + "testing" + "github.com/stretchr/testify/assert" "github.com/xeipuuv/gojsonschema" - "testing" ) //go:embed json-schema.draft-07.json @@ -24,7 +25,7 @@ func TestJsonSchema(t *testing.T) { assert.NoError(t, err) assert.True(t, metaResult.Valid()) - coreClassifier, err := yamlSourceProvider{rawSourceProvider: coreSourceProvider{}}.source() + coreClassifier, err := coreSourceProvider{}.provider().source() assert.NoError(t, err) coreClassifierJson, err := json.Marshal(coreClassifier) assert.NoError(t, err) diff --git a/internal/classifier/source_provider.go b/internal/classifier/source_provider.go index 7a8a3bd7..6382a5df 100644 --- a/internal/classifier/source_provider.go +++ b/internal/classifier/source_provider.go @@ -1,24 +1,24 @@ package classifier import ( + "os" + "path/filepath" + "github.com/adrg/xdg" "github.com/bitmagnet-io/bitmagnet/internal/tmdb" "gopkg.in/yaml.v3" - "os" ) func newSourceProvider(config Config, tmdbConfig tmdb.Config) sourceProvider { - return mergeSourceProvider{ - providers: []sourceProvider{ - yamlSourceProvider{rawSourceProvider: coreSourceProvider{}}, - yamlSourceProvider{rawSourceProvider: xdgSourceProvider{}}, - yamlSourceProvider{rawSourceProvider: cwdSourceProvider{}}, - configSourceProvider{ - config: config, - tmdbEnabled: tmdbConfig.Enabled, - }, - }, - } + var providers []sourceProvider + providers = append(providers, coreSourceProvider{}.provider()) + providers = append(providers, xdgSourceProvider{}.providers()...) + providers = append(providers, cwdSourceProvider{}.providers()...) + providers = append(providers, configSourceProvider{ + config: config, + tmdbEnabled: tmdbConfig.Enabled, + }) + return mergeSourceProvider{providers: providers} } type sourceProvider interface { @@ -45,21 +45,17 @@ func (m mergeSourceProvider) source() (Source, error) { return source, nil } -type rawSourceProvider interface { - source() ([]byte, error) -} - type yamlSourceProvider struct { - rawSourceProvider + raw []byte + err error } func (y yamlSourceProvider) source() (Source, error) { - raw, err := y.rawSourceProvider.source() - if err != nil { - return Source{}, err + if y.err != nil { + return Source{}, y.err } rawWorkflow := make(map[string]interface{}) - parseErr := yaml.Unmarshal(raw, &rawWorkflow) + parseErr := yaml.Unmarshal(y.raw, &rawWorkflow) if parseErr != nil { return Source{}, parseErr } @@ -76,32 +72,41 @@ func (y yamlSourceProvider) source() (Source, error) { type coreSourceProvider struct{} -func (c coreSourceProvider) source() ([]byte, error) { - return classifierCoreYaml, nil +func (c coreSourceProvider) provider() sourceProvider { + return yamlSourceProvider{raw: classifierCoreYaml} } type xdgSourceProvider struct{} -func (_ xdgSourceProvider) source() ([]byte, error) { - if path, pathErr := xdg.ConfigFile("bitmagnet/classifier.yml"); pathErr == nil { - if bytes, readErr := os.ReadFile(path); readErr == nil { - return bytes, nil - } else if !os.IsNotExist(readErr) { - return nil, readErr - } +func (yamlSourceProvider) providers(path string) []sourceProvider { + dir, fname := filepath.Split(path) + glob := dir + "classifier*" + filepath.Ext(fname) + paths, err := filepath.Glob(glob) + if err != nil { + return []sourceProvider{yamlSourceProvider{err: err}} } - return []byte{'{', '}'}, nil + providers := make([]sourceProvider, len(paths)) + for i, path := range paths { + bytes, readErr := os.ReadFile(path) + providers[i] = yamlSourceProvider{raw: bytes, err: readErr} + } + + return providers + +} + +func (xdgSourceProvider) providers() []sourceProvider { + path, err := xdg.ConfigFile("bitmagnet/classifier.yml") + if err != nil { + return []sourceProvider{yamlSourceProvider{err: err}} + } + return yamlSourceProvider{}.providers(path) } type cwdSourceProvider struct{} -func (_ cwdSourceProvider) source() ([]byte, error) { - if bytes, readErr := os.ReadFile("./classifier.yml"); readErr == nil { - return bytes, nil - } else if !os.IsNotExist(readErr) { - return nil, readErr - } - return []byte{'{', '}'}, nil +func (cwdSourceProvider) providers() []sourceProvider { + return yamlSourceProvider{}.providers("./classifier.yml") } type configSourceProvider struct {