Skip to content

Commit 87730c8

Browse files
authored
Merge pull request #12 from codingpot/fix-paper-list-params
fix: update PaperListParams to match Python
2 parents 70ac12b + 2b45ea8 commit 87730c8

File tree

2 files changed

+82
-14
lines changed

2 files changed

+82
-14
lines changed

Diff for: paper_list.go

+28-13
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package paperswithcode_go
33
import (
44
"encoding/json"
55
"fmt"
6+
"strings"
7+
68
"github.com/codingpot/paperswithcode-go/v2/models"
7-
"net/url"
89
)
910

1011
// PaperList returns multiple papers.
1112
func (c *Client) PaperList(params PaperListParams) (*models.PaperList, error) {
12-
papersListURL := c.baseURL + "/papers?" + params.build()
13+
papersListURL := c.baseURL + "/papers?" + params.Build()
1314

1415
response, err := c.httpClient.Get(papersListURL)
1516
if err != nil {
@@ -28,27 +29,41 @@ func (c *Client) PaperList(params PaperListParams) (*models.PaperList, error) {
2829

2930
// PaperListParams is the parameter for PaperList method.
3031
type PaperListParams struct {
31-
// Query to search papers (default: "")
32+
// Q to search papers (default: "")
3233
// If empty, it returns all papers.
33-
Query string
34+
Q string
35+
ArxivID string
36+
Title string
37+
Abstract string
3438
// Page is the number of page to search (default: 1)
3539
Page int
36-
// Limit returns how many papers are returned in a single response.
37-
Limit int
40+
// ItemsPerPage returns how many papers are returned in a single response.
41+
ItemsPerPage int
42+
}
43+
44+
func (p PaperListParams) Build() string {
45+
var b strings.Builder
46+
b.WriteString(fmt.Sprintf("page=%d&items_per_page=%d", p.Page, p.ItemsPerPage))
47+
48+
addParamsIfValid(&b, "q", p.Q)
49+
addParamsIfValid(&b, "arxiv_id", p.ArxivID)
50+
addParamsIfValid(&b, "title", p.Title)
51+
addParamsIfValid(&b, "abstract", p.Abstract)
52+
53+
return b.String()
3854
}
3955

40-
func (p PaperListParams) build() string {
41-
if p.Query == "" {
42-
return fmt.Sprintf("items_per_page=%d&page=%d", p.Limit, p.Page)
56+
func addParamsIfValid(b *strings.Builder, key string, value string) {
57+
if value != "" {
58+
b.WriteString(fmt.Sprintf("&%s=%s", key, value))
4359
}
44-
return fmt.Sprintf("q=%s&items_per_page=%d&page=%d", url.QueryEscape(p.Query), p.Limit, p.Page)
4560
}
4661

4762
// PaperListParamsDefault returns the default PaperListParams.
4863
func PaperListParamsDefault() PaperListParams {
4964
return PaperListParams{
50-
Query: "",
51-
Page: 1,
52-
Limit: 50,
65+
Q: "",
66+
Page: 1,
67+
ItemsPerPage: 50,
5368
}
5469
}

Diff for: paper_list_test.go

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,65 @@
11
package paperswithcode_go
22

33
import (
4-
"github.com/stretchr/testify/assert"
54
"testing"
5+
6+
"github.com/stretchr/testify/assert"
67
)
78

89
func TestClient_PaperList(t *testing.T) {
910
client := NewClient(WithAPIToken(apiToken))
1011
_, err := client.PaperList(PaperListParamsDefault())
1112
assert.NoError(t, err)
1213
}
14+
15+
func TestPaperListParams_Build(t *testing.T) {
16+
type fields struct {
17+
Q string
18+
ArxivID string
19+
Title string
20+
Abstract string
21+
Page int
22+
ItemsPerPage int
23+
}
24+
tests := []struct {
25+
name string
26+
fields fields
27+
want string
28+
}{
29+
{
30+
name: "Q is given, it passes Q",
31+
fields: fields{
32+
Q: "wow",
33+
Page: 1,
34+
ItemsPerPage: 50,
35+
},
36+
want: "page=1&items_per_page=50&q=wow",
37+
},
38+
{
39+
name: "Q is not given, it shouldn't add Q param",
40+
fields: fields{
41+
Page: 1,
42+
ItemsPerPage: 50,
43+
},
44+
want: "page=1&items_per_page=50",
45+
},
46+
{
47+
name: "Default Param is valid",
48+
fields: fields(PaperListParamsDefault()),
49+
want: "page=1&items_per_page=50",
50+
},
51+
}
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
p := PaperListParams{
55+
Q: tt.fields.Q,
56+
ArxivID: tt.fields.ArxivID,
57+
Title: tt.fields.Title,
58+
Abstract: tt.fields.Abstract,
59+
Page: tt.fields.Page,
60+
ItemsPerPage: tt.fields.ItemsPerPage,
61+
}
62+
assert.Equal(t, tt.want, p.Build())
63+
})
64+
}
65+
}

0 commit comments

Comments
 (0)