Browse Source

dns: refactor GetRecords() to allow commands other than sync

Signed-off-by: Alejandro Mery <amery@jpi.io>
pull/27/head
Alejandro Mery 1 year ago
parent
commit
b0f4be7047
  1. 12
      pkg/dns/error.go
  2. 63
      pkg/dns/manager.go
  3. 17
      pkg/dns/sync.go

12
pkg/dns/error.go

@ -0,0 +1,12 @@
package dns
import "errors"
var (
// ErrNoDNSProvider indicates a [libdns.Provider] wasn't assigned
// to the [Manager]
ErrNoDNSProvider = errors.New("dns provider not specified")
// ErrNoDomain indicates a domain wasn't specified
ErrNoDomain = errors.New("domain not specified")
)

63
pkg/dns/manager.go

@ -2,15 +2,16 @@ package dns
import ( import (
"context" "context"
"errors"
"io/fs" "io/fs"
"net/netip" "net/netip"
"strings" "strings"
"darvaza.org/core" "darvaza.org/core"
"darvaza.org/slog" "darvaza.org/slog"
"git.jpi.io/amery/jpictl/pkg/cluster" "github.com/libdns/libdns"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
"git.jpi.io/amery/jpictl/pkg/cluster"
) )
// Manager is a DNS Manager instance // Manager is a DNS Manager instance
@ -71,7 +72,7 @@ func (mgr *Manager) setDefaults() error {
} }
if mgr.domain == "" || mgr.suffix == "" { if mgr.domain == "" || mgr.suffix == "" {
return errors.New("domain not specified") return ErrNoDomain
} }
for _, opt := range opts { for _, opt := range opts {
@ -120,6 +121,62 @@ func NewManager(opts ...ManagerOption) (*Manager, error) {
return mgr, nil return mgr, nil
} }
// GetRecords pulls all the address records on DNS for our domain,
// optionally only those matching the given names.
func (mgr *Manager) GetRecords(ctx context.Context, names ...string) ([]libdns.Record, error) {
if mgr.p == nil {
return nil, ErrNoDNSProvider
}
recs, err := mgr.p.GetRecords(ctx, mgr.domain)
switch {
case err != nil:
// failed
return nil, err
case len(recs) == 0:
// empty
return []libdns.Record{}, nil
case mgr.suffix == "" && len(names) == 0:
// unfiltered
return recs, nil
default:
// filtered
recs = mgr.filterRecords(recs, names...)
return recs, nil
}
}
func (mgr *Manager) filterRecords(recs []libdns.Record, names ...string) []libdns.Record {
out := make([]libdns.Record, 0, len(recs))
for _, rr := range recs {
name, ok := mgr.matchSuffix(rr)
switch {
case !ok:
// skip, wrong subdomain
continue
case len(names) == 0:
// unfiltered, take it
case !core.SliceContains(names, name):
// skip, not one of the requested names
continue
}
out = append(out, rr)
}
return out
}
func (mgr *Manager) matchSuffix(rr libdns.Record) (string, bool) {
if mgr.suffix == "" {
// no suffix
return rr.Name, true
}
// remove suffix
return strings.CutSuffix(rr.Name, mgr.suffix)
}
// AddHost registers a host // AddHost registers a host
func (mgr *Manager) AddHost(_ context.Context, zone string, id int, func (mgr *Manager) AddHost(_ context.Context, zone string, id int,
active bool, addrs ...netip.Addr) error { active bool, addrs ...netip.Addr) error {

17
pkg/dns/sync.go

@ -2,7 +2,6 @@ package dns
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"sort" "sort"
"strings" "strings"
@ -48,18 +47,14 @@ func SortSyncAddrSlice(s []SyncAddr) []SyncAddr {
return s return s
} }
// GetRecords pulls all the address records on DNS for our domain // GetSyncRecords pulls all the address records on DNS for our domain
func (mgr *Manager) GetRecords(ctx context.Context) ([]SyncAddrRecord, error) { func (mgr *Manager) GetSyncRecords(ctx context.Context) ([]SyncAddrRecord, error) {
if mgr.p == nil { recs, err := mgr.GetRecords(ctx)
return nil, errors.New("dns provider not specified")
}
recs, err := mgr.p.GetRecords(ctx, mgr.domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return mgr.filteredRecords(recs) return mgr.filteredSyncRecords(recs)
} }
// AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr] // AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr]
@ -94,7 +89,7 @@ func (mgr *Manager) AsSyncAddr(rr libdns.Record) (SyncAddr, bool, error) {
return out, true, nil return out, true, nil
} }
func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, error) { func (mgr *Manager) filteredSyncRecords(recs []libdns.Record) ([]SyncAddrRecord, error) {
// filter and convert // filter and convert
cache := make(map[string][]SyncAddr) cache := make(map[string][]SyncAddr)
for _, rr := range recs { for _, rr := range recs {
@ -137,7 +132,7 @@ func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, err
// Sync updates all the address records on DNS for our domain // Sync updates all the address records on DNS for our domain
func (mgr *Manager) Sync(ctx context.Context) error { func (mgr *Manager) Sync(ctx context.Context) error {
current, err := mgr.GetRecords(ctx) current, err := mgr.GetSyncRecords(ctx)
if err != nil { if err != nil {
return core.Wrap(err, "GetRecords") return core.Wrap(err, "GetRecords")
} }

Loading…
Cancel
Save