diff --git a/client.go b/client.go index f58dc17..d2738c7 100644 --- a/client.go +++ b/client.go @@ -37,9 +37,9 @@ type Client struct { } type RDRCStore interface { - LoadRDRC(transportName string, qName string) (rejected bool) - SaveRDRC(transportName string, qName string) error - SaveRDRCAsync(transportName string, qName string, logger logger.Logger) + LoadRDRC(transportName string, qName string, qType uint16) (rejected bool) + SaveRDRC(transportName string, qName string, qType uint16) error + SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger) } type transportCacheKey struct { @@ -143,7 +143,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } ctx = contextWithTransportName(ctx, transport.Name()) if responseChecker != nil && c.rdrc != nil { - rejected := c.rdrc.LoadRDRC(transport.Name(), question.Name) + rejected := c.rdrc.LoadRDRC(transport.Name(), question.Name, question.Qtype) if rejected { return nil, ErrResponseRejectedCached } @@ -154,7 +154,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } if responseChecker != nil && !responseChecker(response) { if c.rdrc != nil { - c.rdrc.SaveRDRCAsync(transport.Name(), question.Name, c.logger) + c.rdrc.SaveRDRCAsync(transport.Name(), question.Name, question.Qtype, c.logger) } return response, ErrResponseRejected } @@ -259,7 +259,13 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } } if responseChecker != nil && c.rdrc != nil { - rejected := c.rdrc.LoadRDRC(transport.Name(), dnsName) + var rejected bool + if strategy != DomainStrategyUseIPv6 { + rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeA) + } + if !rejected && strategy != DomainStrategyUseIPv4 { + rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeAAAA) + } if rejected { return nil, ErrResponseRejectedCached } @@ -271,7 +277,16 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } if responseChecker != nil && !responseChecker(response) { if c.rdrc != nil { - c.rdrc.SaveRDRCAsync(transport.Name(), dnsName, c.logger) + if common.Any(response, func(addr netip.Addr) bool { + return addr.Is4() + }) { + c.rdrc.SaveRDRCAsync(transport.Name(), dnsName, dns.TypeA, c.logger) + } + if common.Any(response, func(addr netip.Addr) bool { + return addr.Is6() + }) { + c.rdrc.SaveRDRCAsync(transport.Name(), dnsName, dns.TypeAAAA, c.logger) + } } return response, ErrResponseRejected }