Skip to content

Commit ba01224

Browse files
committed
Merge branch '86-clones-data-races' into 'master'
fix: fix a data race on clones map (#86) See merge request postgres-ai/database-lab!68
2 parents c914b54 + 5a824ab commit ba01224

File tree

2 files changed

+156
-23
lines changed

2 files changed

+156
-23
lines changed

pkg/services/cloning/mode_base.go

+50-23
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ const idleCheckDuration = 5 * time.Minute
2727
type baseCloning struct {
2828
cloning
2929

30-
// TODO(akartasov): Fix data race.
3130
cloneMutex sync.RWMutex
3231
clones map[string]*CloneWrapper
3332
instanceStatus *models.InstanceStatus
@@ -91,10 +90,10 @@ func (c *baseCloning) CreateClone(clone *models.Clone) error {
9190

9291
w := NewCloneWrapper(clone)
9392

94-
clone.Status = &models.Status{
93+
c.updateStatus(w.clone, models.Status{
9594
Code: models.StatusCreating,
9695
Message: models.CloneMessageCreating,
97-
}
96+
})
9897

9998
w.timeCreatedAt = time.Now()
10099
clone.CreatedAt = util.FormatTime(w.timeCreatedAt)
@@ -119,10 +118,10 @@ func (c *baseCloning) CreateClone(clone *models.Clone) error {
119118
session, err := c.provision.StartSession(w.username, w.password, snapshotID)
120119
if err != nil {
121120
// TODO(anatoly): Empty room case.
122-
clone.Status = &models.Status{
121+
c.updateStatus(w.clone, models.Status{
123122
Code: models.StatusFatal,
124123
Message: models.CloneMessageFatal,
125-
}
124+
})
126125

127126
log.Errf("Failed to start session: %+v.", err)
128127

@@ -133,10 +132,10 @@ func (c *baseCloning) CreateClone(clone *models.Clone) error {
133132

134133
w.timeStartedAt = time.Now()
135134

136-
clone.Status = &models.Status{
135+
c.updateStatus(w.clone, models.Status{
137136
Code: models.StatusOK,
138137
Message: models.CloneMessageOK,
139-
}
138+
})
140139

141140
clone.DB.Port = strconv.FormatUint(uint64(session.Port), 10)
142141

@@ -169,30 +168,28 @@ func (c *baseCloning) DestroyClone(id string) error {
169168
return errors.New("clone is protected")
170169
}
171170

172-
w.clone.Status = &models.Status{
171+
c.updateStatus(w.clone, models.Status{
173172
Code: models.StatusDeleting,
174173
Message: models.CloneMessageDeleting,
175-
}
174+
})
176175

177176
if w.session == nil {
178177
return errors.New("clone is not started yet")
179178
}
180179

181180
go func() {
182181
if err := c.provision.StopSession(w.session); err != nil {
183-
w.clone.Status = &models.Status{
182+
c.updateStatus(w.clone, models.Status{
184183
Code: models.StatusFatal,
185184
Message: models.CloneMessageFatal,
186-
}
185+
})
187186

188187
log.Errf("Failed to delete clone: %+v.", err)
189188

190189
return
191190
}
192191

193-
c.cloneMutex.Lock()
194-
delete(c.clones, w.clone.ID)
195-
c.cloneMutex.Unlock()
192+
c.deleteClone(w.clone.ID)
196193
}()
197194

198195
return nil
@@ -274,10 +271,10 @@ func (c *baseCloning) ResetClone(id string) error {
274271
return errors.New("clone not found")
275272
}
276273

277-
w.clone.Status = &models.Status{
274+
c.updateStatus(w.clone, models.Status{
278275
Code: models.StatusResetting,
279276
Message: models.CloneMessageResetting,
280-
}
277+
})
281278

282279
if w.session == nil {
283280
return errors.New("clone is not started yet")
@@ -291,20 +288,20 @@ func (c *baseCloning) ResetClone(id string) error {
291288

292289
err := c.provision.ResetSession(w.session, snapshotID)
293290
if err != nil {
294-
w.clone.Status = &models.Status{
291+
c.updateStatus(w.clone, models.Status{
295292
Code: models.StatusFatal,
296293
Message: models.CloneMessageFatal,
297-
}
294+
})
298295

299296
log.Errf("Failed to reset session: %+v.", err)
300297

301298
return
302299
}
303300

304-
w.clone.Status = &models.Status{
301+
c.updateStatus(w.clone, models.Status{
305302
Code: models.StatusOK,
306303
Message: models.CloneMessageOK,
307-
}
304+
})
308305
}()
309306

310307
return nil
@@ -337,10 +334,13 @@ func (c *baseCloning) GetSnapshots() ([]*models.Snapshot, error) {
337334

338335
// GetClones returns all clones.
339336
func (c *baseCloning) GetClones() []*models.Clone {
340-
clones := make([]*models.Clone, 0, len(c.clones))
337+
clones := make([]*models.Clone, 0, c.lenClones())
338+
339+
c.cloneMutex.RLock()
341340
for _, clone := range c.clones {
342341
clones = append(clones, clone.clone)
343342
}
343+
c.cloneMutex.RUnlock()
344344

345345
return clones
346346
}
@@ -361,20 +361,47 @@ func (c *baseCloning) setWrapper(id string, wrapper *CloneWrapper) {
361361
c.cloneMutex.Unlock()
362362
}
363363

364+
// updateStatus updates the clone status.
365+
func (c *baseCloning) updateStatus(clone *models.Clone, status models.Status) {
366+
c.cloneMutex.Lock()
367+
clone.Status = &status
368+
c.cloneMutex.Unlock()
369+
}
370+
371+
// deleteClone removes the clone by ID.
372+
func (c *baseCloning) deleteClone(cloneID string) {
373+
c.cloneMutex.Lock()
374+
delete(c.clones, cloneID)
375+
c.cloneMutex.Unlock()
376+
}
377+
378+
// lenClones returns the number of clones.
379+
func (c *baseCloning) lenClones() int {
380+
c.cloneMutex.RLock()
381+
lenClones := len(c.clones)
382+
c.cloneMutex.RUnlock()
383+
384+
return lenClones
385+
}
386+
364387
func (c *baseCloning) getExpectedCloningTime() float64 {
365-
if len(c.clones) == 0 {
388+
lenClones := c.lenClones()
389+
390+
if lenClones == 0 {
366391
return 0
367392
}
368393

369394
sum := 0.0
370395

396+
c.cloneMutex.RLock()
371397
for _, cloneWrapper := range c.clones {
372398
if cloneWrapper.clone.Metadata != nil {
373399
sum += cloneWrapper.clone.Metadata.CloningTime
374400
}
375401
}
402+
c.cloneMutex.RUnlock()
376403

377-
return sum / float64(len(c.clones))
404+
return sum / float64(lenClones)
378405
}
379406

380407
func (c *baseCloning) fetchSnapshots() error {
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package cloning
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
"github.com/stretchr/testify/suite"
9+
10+
"gitlab.com/postgres-ai/database-lab/pkg/models"
11+
)
12+
13+
func TestBaseCloningSuite(t *testing.T) {
14+
suite.Run(t, new(BaseCloningSuite))
15+
}
16+
17+
type BaseCloningSuite struct {
18+
cloning *baseCloning
19+
20+
suite.Suite
21+
}
22+
23+
func (s *BaseCloningSuite) SetupSuite() {
24+
cloning := &baseCloning{
25+
clones: make(map[string]*CloneWrapper),
26+
}
27+
28+
s.cloning = cloning
29+
}
30+
31+
func (s *BaseCloningSuite) TearDownTest() {
32+
s.cloning.clones = make(map[string]*CloneWrapper)
33+
}
34+
35+
func (s *BaseCloningSuite) TestFindWrapper() {
36+
wrapper, ok := s.cloning.findWrapper("testCloneID")
37+
assert.False(s.T(), ok)
38+
assert.Nil(s.T(), wrapper)
39+
40+
s.cloning.setWrapper("testCloneID", &CloneWrapper{clone: &models.Clone{ID: "testCloneID"}})
41+
42+
wrapper, ok = s.cloning.findWrapper("testCloneID")
43+
assert.True(s.T(), ok)
44+
assert.NotNil(s.T(), wrapper)
45+
assert.Equal(s.T(), CloneWrapper{clone: &models.Clone{ID: "testCloneID"}}, *wrapper)
46+
}
47+
48+
func (s *BaseCloningSuite) TestUpdateStatus() {
49+
s.cloning.setWrapper("testCloneID", &CloneWrapper{clone: &models.Clone{Status: &models.Status{
50+
Code: models.StatusCreating,
51+
Message: models.CloneMessageCreating,
52+
}}})
53+
54+
wrapper, ok := s.cloning.findWrapper("testCloneID")
55+
require.True(s.T(), ok)
56+
require.NotNil(s.T(), wrapper)
57+
58+
s.cloning.updateStatus(wrapper.clone, models.Status{
59+
Code: models.StatusOK,
60+
Message: models.CloneMessageOK,
61+
})
62+
63+
wrapper, ok = s.cloning.findWrapper("testCloneID")
64+
require.True(s.T(), ok)
65+
require.NotNil(s.T(), wrapper)
66+
67+
assert.Equal(s.T(), models.Status{
68+
Code: models.StatusOK,
69+
Message: models.CloneMessageOK,
70+
}, *wrapper.clone.Status)
71+
}
72+
73+
func (s *BaseCloningSuite) TestDeleteClone() {
74+
wrapper, ok := s.cloning.findWrapper("testCloneID")
75+
assert.False(s.T(), ok)
76+
assert.Nil(s.T(), wrapper)
77+
78+
s.cloning.setWrapper("testCloneID", &CloneWrapper{})
79+
80+
wrapper, ok = s.cloning.findWrapper("testCloneID")
81+
require.True(s.T(), ok)
82+
require.NotNil(s.T(), wrapper)
83+
assert.Equal(s.T(), CloneWrapper{}, *wrapper)
84+
85+
s.cloning.deleteClone("testCloneID")
86+
87+
wrapper, ok = s.cloning.findWrapper("testCloneID")
88+
assert.False(s.T(), ok)
89+
assert.Nil(s.T(), wrapper)
90+
}
91+
92+
func (s *BaseCloningSuite) TestLenClones() {
93+
lenClones := s.cloning.lenClones()
94+
assert.Equal(s.T(), 0, lenClones)
95+
96+
s.cloning.setWrapper("testCloneID1", &CloneWrapper{})
97+
s.cloning.setWrapper("testCloneID2", &CloneWrapper{})
98+
99+
lenClones = s.cloning.lenClones()
100+
assert.Equal(s.T(), 2, lenClones)
101+
102+
s.cloning.deleteClone("testCloneID1")
103+
104+
lenClones = s.cloning.lenClones()
105+
assert.Equal(s.T(), 1, lenClones)
106+
}

0 commit comments

Comments
 (0)