Skip to content

Commit 1a1f8a3

Browse files
authored
Merge pull request wk8#25 from xiegeo/master
allow special character as keys
2 parents 4d71287 + 5294187 commit 1a1f8a3

File tree

3 files changed

+68
-41
lines changed

3 files changed

+68
-41
lines changed

json.go

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -116,41 +116,37 @@ func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error {
116116
var key K
117117
var value V
118118

119-
if typedKeyPointer, ok := any(&key).(encoding.TextUnmarshaler); ok {
120-
// pointer receiver
121-
if err := typedKeyPointer.UnmarshalText(keyData); err != nil {
119+
switch tkp := any(&key).(type) {
120+
case *string:
121+
*tkp = string(keyData)
122+
case encoding.TextUnmarshaler:
123+
if err := tkp.UnmarshalText(keyData); err != nil {
122124
return err
123125
}
124-
} else {
125-
keyAlreadyUnmarshalled := false
126-
switch typedKey := any(key).(type) {
127-
case string:
128-
keyData = quoteString(keyData)
129-
case encoding.TextUnmarshaler:
130-
// not a pointer receiver
131-
if err := typedKey.UnmarshalText(keyData); err != nil {
132-
return err
133-
}
134-
keyAlreadyUnmarshalled = true
135-
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
136-
default:
137-
138-
// this switch takes care of wrapper types around primitive types, such as
139-
// type myType string
140-
switch reflect.TypeOf(key).Kind() {
141-
case reflect.String:
142-
keyData = quoteString(keyData)
143-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
144-
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
145-
default:
146-
return fmt.Errorf("unsupported key type: %T", typedKey)
147-
}
126+
case *encoding.TextUnmarshaler:
127+
// This is to preserve compatibility with original implementation
128+
// that handled none pointer receivers, but I (xiegeo) believes this is unused.
129+
if err := (*tkp).UnmarshalText(keyData); err != nil {
130+
return err
148131
}
149-
150-
if !keyAlreadyUnmarshalled {
132+
case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64:
133+
if err := json.Unmarshal(keyData, tkp); err != nil {
134+
return err
135+
}
136+
default:
137+
// this switch takes care of wrapper types around primitive types, such as
138+
// type myType string
139+
switch reflect.TypeOf(key).Kind() {
140+
case reflect.String:
141+
convertedkeyData := reflect.ValueOf(keyData).Convert(reflect.TypeOf(key))
142+
reflect.ValueOf(&key).Elem().Set(convertedkeyData)
143+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
144+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
151145
if err := json.Unmarshal(keyData, &key); err != nil {
152146
return err
153147
}
148+
default:
149+
return fmt.Errorf("unsupported key type: %T", key)
154150
}
155151
}
156152

@@ -162,11 +158,3 @@ func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error {
162158
return nil
163159
})
164160
}
165-
166-
func quoteString(data []byte) []byte {
167-
withQuotes := make([]byte, len(data)+2) //nolint:gomnd
168-
copy(withQuotes[1:], data)
169-
withQuotes[0] = '"'
170-
withQuotes[len(data)+1] = '"'
171-
return withQuotes
172-
}

json_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,45 @@ func TestUnmarshallJSON(t *testing.T) {
184184
})
185185
}
186186

187+
const specialCharacters = "\\\\/\"\b\f\n\r\t\x00\uffff\ufffd世界\u007f\u00ff\U0010FFFF"
188+
189+
func TestJSONSpecialCharacters(t *testing.T) {
190+
baselineMap := map[string]any{specialCharacters: specialCharacters}
191+
baselineData, err := json.Marshal(baselineMap)
192+
require.NoError(t, err) // baseline proves this key is supported by official json library
193+
t.Logf("specialCharacters: %#v as []rune:%v", specialCharacters, []rune(specialCharacters))
194+
t.Logf("baseline json data: %s", baselineData)
195+
196+
t.Run("marshal "+specialCharacters, func(t *testing.T) {
197+
om := New[string, any]()
198+
om.Set(specialCharacters, specialCharacters)
199+
b, err := json.Marshal(om)
200+
require.NoError(t, err)
201+
require.Equal(t, baselineData, b)
202+
203+
type myString string
204+
om2 := New[myString, myString]()
205+
om2.Set(specialCharacters, specialCharacters)
206+
b, err = json.Marshal(om2)
207+
require.NoError(t, err)
208+
require.Equal(t, baselineData, b)
209+
})
210+
t.Run("unmarshall "+specialCharacters, func(t *testing.T) {
211+
om := New[string, any]()
212+
require.NoError(t, json.Unmarshal([]byte(baselineData), &om))
213+
assertOrderedPairsEqual(t, om,
214+
[]string{specialCharacters},
215+
[]any{specialCharacters})
216+
217+
type myString string
218+
om2 := New[myString, myString]()
219+
require.NoError(t, json.Unmarshal([]byte(baselineData), &om2))
220+
assertOrderedPairsEqual(t, om2,
221+
[]myString{specialCharacters},
222+
[]myString{specialCharacters})
223+
})
224+
}
225+
187226
// to test structs that have nested map fields
188227
type nestedMaps struct {
189228
X int `json:"x"`

test_utils.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ func assertOrderedPairsEqualFromNewest[K comparable, V any](
2626
if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) {
2727
i := orderedMap.Len() - 1
2828
for pair := orderedMap.Newest(); pair != nil; pair = pair.Prev() {
29-
assert.Equal(t, expectedKeys[i], pair.Key)
30-
assert.Equal(t, expectedValues[i], pair.Value)
29+
assert.Equal(t, expectedKeys[i], pair.Key, "from newest index=%d on key", i)
30+
assert.Equal(t, expectedValues[i], pair.Value, "from newest index=%d on value", i)
3131
i--
3232
}
3333
}
@@ -41,8 +41,8 @@ func assertOrderedPairsEqualFromOldest[K comparable, V any](
4141
if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) {
4242
i := 0
4343
for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() {
44-
assert.Equal(t, expectedKeys[i], pair.Key)
45-
assert.Equal(t, expectedValues[i], pair.Value)
44+
assert.Equal(t, expectedKeys[i], pair.Key, "from oldest index=%d on key", i)
45+
assert.Equal(t, expectedValues[i], pair.Value, "from oldest index=%d on value", i)
4646
i++
4747
}
4848
}

0 commit comments

Comments
 (0)