Skip to content

Commit 7778a1c

Browse files
authored
Add search asynchronously with context (#440)
* feat: add search with channels inspired by #319 * refactor: fix to check proper test results #319 * refactor: fix to use unpackAttributes() for Attributes #319 * refactor: returns receive-only channel to prevent closing it from the caller #319 * refactor: pass channelSize to be able to controll buffered channel by the caller #319 * fix: recover an asynchronouse closing timing issue #319 * fix: consume all entries from the channel to prevent blocking by the connection #319 * feat: add initial search async function with channel #341 * feat: provide search async function and drop search with channels #319 #341 * refactor: lock when to call GetLastError since it might be in communication
1 parent cdb0754 commit 7778a1c

12 files changed

+632
-0
lines changed

client.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"time"
67
)
@@ -32,6 +33,7 @@ type Client interface {
3233
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
3334

3435
Search(*SearchRequest) (*SearchResult, error)
36+
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
3537
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
3638
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
3739
}

conn.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 {
327327
// GetLastError returns the last recorded error from goroutines like processMessages and reader.
328328
// Only the last recorded error will be returned.
329329
func (l *Conn) GetLastError() error {
330+
l.messageMutex.Lock()
331+
defer l.messageMutex.Unlock()
330332
return l.err
331333
}
332334

examples_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
56
"crypto/x509"
67
"fmt"
@@ -50,6 +51,35 @@ func ExampleConn_Search() {
5051
}
5152
}
5253

54+
// This example demonstrates how to search asynchronously
55+
func ExampleConn_SearchAsync() {
56+
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
57+
if err != nil {
58+
log.Fatal(err)
59+
}
60+
defer l.Close()
61+
62+
searchRequest := NewSearchRequest(
63+
"dc=example,dc=com", // The base dn to search
64+
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
65+
"(&(objectClass=organizationalPerson))", // The filter to apply
66+
[]string{"dn", "cn"}, // A list attributes to retrieve
67+
nil,
68+
)
69+
70+
ctx, cancel := context.WithCancel(context.Background())
71+
defer cancel()
72+
73+
r := l.SearchAsync(ctx, searchRequest, 64)
74+
for r.Next() {
75+
entry := r.Entry()
76+
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
77+
}
78+
if err := r.Err(); err != nil {
79+
log.Fatal(err)
80+
}
81+
}
82+
5383
// This example demonstrates how to start a TLS connection
5484
func ExampleConn_StartTLS() {
5585
l, err := DialURL("ldap://ldap.example.com:389")

ldap_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package ldap
22

33
import (
4+
"context"
45
"crypto/tls"
6+
"log"
57
"testing"
68

79
ber "github.com/go-asn1-ber/asn1-ber"
@@ -344,3 +346,67 @@ func TestEscapeDN(t *testing.T) {
344346
})
345347
}
346348
}
349+
350+
func TestSearchAsync(t *testing.T) {
351+
l, err := DialURL(ldapServer)
352+
if err != nil {
353+
t.Fatal(err)
354+
}
355+
defer l.Close()
356+
357+
searchRequest := NewSearchRequest(
358+
baseDN,
359+
ScopeWholeSubtree, DerefAlways, 0, 0, false,
360+
filter[2],
361+
attributes,
362+
nil)
363+
364+
srs := make([]*Entry, 0)
365+
ctx := context.Background()
366+
r := l.SearchAsync(ctx, searchRequest, 64)
367+
for r.Next() {
368+
srs = append(srs, r.Entry())
369+
}
370+
if err := r.Err(); err != nil {
371+
log.Fatal(err)
372+
}
373+
374+
t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs))
375+
}
376+
377+
func TestSearchAsyncAndCancel(t *testing.T) {
378+
l, err := DialURL(ldapServer)
379+
if err != nil {
380+
t.Fatal(err)
381+
}
382+
defer l.Close()
383+
384+
searchRequest := NewSearchRequest(
385+
baseDN,
386+
ScopeWholeSubtree, DerefAlways, 0, 0, false,
387+
filter[2],
388+
attributes,
389+
nil)
390+
391+
cancelNum := 10
392+
srs := make([]*Entry, 0)
393+
ctx, cancel := context.WithCancel(context.Background())
394+
defer cancel()
395+
r := l.SearchAsync(ctx, searchRequest, 0)
396+
for r.Next() {
397+
srs = append(srs, r.Entry())
398+
if len(srs) == cancelNum {
399+
cancel()
400+
}
401+
}
402+
if err := r.Err(); err != nil {
403+
log.Fatal(err)
404+
}
405+
406+
if len(srs) > cancelNum+3 {
407+
// the cancellation process is asynchronous,
408+
// so it might get some entries after calling cancel()
409+
t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3)
410+
}
411+
t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
412+
}

response.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package ldap
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
8+
ber "github.com/go-asn1-ber/asn1-ber"
9+
)
10+
11+
// Response defines an interface to get data from an LDAP server
12+
type Response interface {
13+
Entry() *Entry
14+
Referral() string
15+
Controls() []Control
16+
Err() error
17+
Next() bool
18+
}
19+
20+
type searchResponse struct {
21+
conn *Conn
22+
ch chan *SearchSingleResult
23+
24+
entry *Entry
25+
referral string
26+
controls []Control
27+
err error
28+
}
29+
30+
// Entry returns an entry from the given search request
31+
func (r *searchResponse) Entry() *Entry {
32+
return r.entry
33+
}
34+
35+
// Referral returns a referral from the given search request
36+
func (r *searchResponse) Referral() string {
37+
return r.referral
38+
}
39+
40+
// Controls returns controls from the given search request
41+
func (r *searchResponse) Controls() []Control {
42+
return r.controls
43+
}
44+
45+
// Err returns an error when the given search request was failed
46+
func (r *searchResponse) Err() error {
47+
return r.err
48+
}
49+
50+
// Next returns whether next data exist or not
51+
func (r *searchResponse) Next() bool {
52+
res, ok := <-r.ch
53+
if !ok {
54+
return false
55+
}
56+
if res == nil {
57+
return false
58+
}
59+
r.err = res.Error
60+
if r.err != nil {
61+
return false
62+
}
63+
r.err = r.conn.GetLastError()
64+
if r.err != nil {
65+
return false
66+
}
67+
r.entry = res.Entry
68+
r.referral = res.Referral
69+
r.controls = res.Controls
70+
return true
71+
}
72+
73+
func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) {
74+
go func() {
75+
defer func() {
76+
close(r.ch)
77+
if err := recover(); err != nil {
78+
r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err)
79+
}
80+
}()
81+
82+
if r.conn.IsClosing() {
83+
return
84+
}
85+
86+
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
87+
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
88+
// encode search request
89+
err := searchRequest.appendTo(packet)
90+
if err != nil {
91+
r.ch <- &SearchSingleResult{Error: err}
92+
return
93+
}
94+
r.conn.Debug.PrintPacket(packet)
95+
96+
msgCtx, err := r.conn.sendMessage(packet)
97+
if err != nil {
98+
r.ch <- &SearchSingleResult{Error: err}
99+
return
100+
}
101+
defer r.conn.finishMessage(msgCtx)
102+
103+
foundSearchSingleResultDone := false
104+
for !foundSearchSingleResultDone {
105+
select {
106+
case <-ctx.Done():
107+
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
108+
return
109+
default:
110+
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
111+
packetResponse, ok := <-msgCtx.responses
112+
if !ok {
113+
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
114+
r.ch <- &SearchSingleResult{Error: err}
115+
return
116+
}
117+
packet, err = packetResponse.ReadPacket()
118+
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
119+
if err != nil {
120+
r.ch <- &SearchSingleResult{Error: err}
121+
return
122+
}
123+
124+
if r.conn.Debug {
125+
if err := addLDAPDescriptions(packet); err != nil {
126+
r.ch <- &SearchSingleResult{Error: err}
127+
return
128+
}
129+
ber.PrintPacket(packet)
130+
}
131+
132+
switch packet.Children[1].Tag {
133+
case ApplicationSearchResultEntry:
134+
r.ch <- &SearchSingleResult{
135+
Entry: &Entry{
136+
DN: packet.Children[1].Children[0].Value.(string),
137+
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
138+
},
139+
}
140+
141+
case ApplicationSearchResultDone:
142+
if err := GetLDAPError(packet); err != nil {
143+
r.ch <- &SearchSingleResult{Error: err}
144+
return
145+
}
146+
if len(packet.Children) == 3 {
147+
result := &SearchSingleResult{}
148+
for _, child := range packet.Children[2].Children {
149+
decodedChild, err := DecodeControl(child)
150+
if err != nil {
151+
werr := fmt.Errorf("failed to decode child control: %w", err)
152+
r.ch <- &SearchSingleResult{Error: werr}
153+
return
154+
}
155+
result.Controls = append(result.Controls, decodedChild)
156+
}
157+
r.ch <- result
158+
}
159+
foundSearchSingleResultDone = true
160+
161+
case ApplicationSearchResultReference:
162+
ref := packet.Children[1].Children[0].Value.(string)
163+
r.ch <- &SearchSingleResult{Referral: ref}
164+
}
165+
}
166+
}
167+
r.conn.Debug.Printf("%d: returning", msgCtx.id)
168+
}()
169+
}
170+
171+
func newSearchResponse(conn *Conn, bufferSize int) *searchResponse {
172+
var ch chan *SearchSingleResult
173+
if bufferSize > 0 {
174+
ch = make(chan *SearchSingleResult, bufferSize)
175+
} else {
176+
ch = make(chan *SearchSingleResult)
177+
}
178+
return &searchResponse{
179+
conn: conn,
180+
ch: ch,
181+
}
182+
}

search.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ldap
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"reflect"
@@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) {
375376
r.Controls = append(r.Controls, s.Controls...)
376377
}
377378

379+
// SearchSingleResult holds the server's single response to a search request
380+
type SearchSingleResult struct {
381+
// Entry is the returned entry
382+
Entry *Entry
383+
// Referral is the returned referral
384+
Referral string
385+
// Controls are the returned controls
386+
Controls []Control
387+
// Error is set when the search request was failed
388+
Error error
389+
}
390+
391+
// Print outputs a human-readable description
392+
func (s *SearchSingleResult) Print() {
393+
s.Entry.Print()
394+
}
395+
396+
// PrettyPrint outputs a human-readable description with indenting
397+
func (s *SearchSingleResult) PrettyPrint(indent int) {
398+
s.Entry.PrettyPrint(indent)
399+
}
400+
378401
// SearchRequest represents a search request to send to the server
379402
type SearchRequest struct {
380403
BaseDN string
@@ -559,6 +582,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
559582
}
560583
}
561584

585+
// SearchAsync performs a search request and returns all search results asynchronously.
586+
// This means you get all results until an error happens (or the search successfully finished),
587+
// e.g. for size / time limited requests all are recieved until the limit is reached.
588+
// To stop the search, call cancel function returned context.
589+
func (l *Conn) SearchAsync(
590+
ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response {
591+
r := newSearchResponse(l, bufferSize)
592+
r.start(ctx, searchRequest)
593+
return r
594+
}
595+
562596
// unpackAttributes will extract all given LDAP attributes and it's values
563597
// from the ber.Packet
564598
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {

0 commit comments

Comments
 (0)