Skip to content

Minor performance improvement and go test addition #644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1711,8 +1711,8 @@ where
return Ok(());
}

let client_given_name = Parse::get_name(&message)?;
let parse: Parse = (&message).try_into()?;
let client_given_name = &parse.name;

// Compute the hash of the parse statement
let hash = parse.get_hash();
Expand All @@ -1734,7 +1734,7 @@ where
);

self.prepared_statements
.insert(client_given_name, (new_parse.clone(), hash));
.insert(client_given_name.clone(), (new_parse.clone(), hash));

self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_parse(
Expand Down
56 changes: 47 additions & 9 deletions tests/go/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"sync"
"testing"

_ "github.com/lib/pq"
)

func Test(t *testing.T) {
Expand Down Expand Up @@ -36,17 +38,53 @@ func namedParameterizedPreparedStatement(t *testing.T) {
}

func unnamedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
var wg sync.WaitGroup
errCh := make(chan error, 2) // create error channel

// Have two concurrent clients executing different unnamed prepared statements
for i := 0; i < 2; i++ {
wg.Add(1)

go func(id int) {
defer wg.Done()

db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
errCh <- err // send error to channel
return
}

for j := 0; j < 100; j++ {

// Under the hood QueryContext generates an unnamed parameterized prepared statement
switch id {
case 0:
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
if err != nil {
errCh <- err // send error to channel
return
}
_ = rows.Close()

case 1:
rows, err := db.QueryContext(context.Background(), "SELECT $1, $2", 1, 2)
if err != nil {
errCh <- err // send error to channel
return
}
_ = rows.Close()
}

}
}(i)
}

for i := 0; i < 100; i++ {
// Under the hood QueryContext generates an unnamed parameterized prepared statement
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
wg.Wait()
close(errCh)

for err := range errCh {
if err != nil {
t.Fatalf("could not query: %+v", err)
t.Fatalf("received error from goroutine: %v", err)
}
_ = rows.Close()
}
}