From b0f4be70477ca0399ee6c869ae960d70f7098a85 Mon Sep 17 00:00:00 2001 From: Alejandro Mery Date: Mon, 23 Oct 2023 22:41:17 +0000 Subject: [PATCH] dns: refactor GetRecords() to allow commands other than sync Signed-off-by: Alejandro Mery --- pkg/dns/error.go | 12 +++++++++ pkg/dns/manager.go | 63 +++++++++++++++++++++++++++++++++++++++++++--- pkg/dns/sync.go | 17 +++++-------- 3 files changed, 78 insertions(+), 14 deletions(-) create mode 100644 pkg/dns/error.go diff --git a/pkg/dns/error.go b/pkg/dns/error.go new file mode 100644 index 0000000..abe9126 --- /dev/null +++ b/pkg/dns/error.go @@ -0,0 +1,12 @@ +package dns + +import "errors" + +var ( + // ErrNoDNSProvider indicates a [libdns.Provider] wasn't assigned + // to the [Manager] + ErrNoDNSProvider = errors.New("dns provider not specified") + + // ErrNoDomain indicates a domain wasn't specified + ErrNoDomain = errors.New("domain not specified") +) diff --git a/pkg/dns/manager.go b/pkg/dns/manager.go index 8e29252..24dd62f 100644 --- a/pkg/dns/manager.go +++ b/pkg/dns/manager.go @@ -2,15 +2,16 @@ package dns import ( "context" - "errors" "io/fs" "net/netip" "strings" "darvaza.org/core" "darvaza.org/slog" - "git.jpi.io/amery/jpictl/pkg/cluster" + "github.com/libdns/libdns" "golang.org/x/net/publicsuffix" + + "git.jpi.io/amery/jpictl/pkg/cluster" ) // Manager is a DNS Manager instance @@ -71,7 +72,7 @@ func (mgr *Manager) setDefaults() error { } if mgr.domain == "" || mgr.suffix == "" { - return errors.New("domain not specified") + return ErrNoDomain } for _, opt := range opts { @@ -120,6 +121,62 @@ func NewManager(opts ...ManagerOption) (*Manager, error) { return mgr, nil } +// GetRecords pulls all the address records on DNS for our domain, +// optionally only those matching the given names. +func (mgr *Manager) GetRecords(ctx context.Context, names ...string) ([]libdns.Record, error) { + if mgr.p == nil { + return nil, ErrNoDNSProvider + } + + recs, err := mgr.p.GetRecords(ctx, mgr.domain) + switch { + case err != nil: + // failed + return nil, err + case len(recs) == 0: + // empty + return []libdns.Record{}, nil + case mgr.suffix == "" && len(names) == 0: + // unfiltered + return recs, nil + default: + // filtered + recs = mgr.filterRecords(recs, names...) + return recs, nil + } +} + +func (mgr *Manager) filterRecords(recs []libdns.Record, names ...string) []libdns.Record { + out := make([]libdns.Record, 0, len(recs)) + for _, rr := range recs { + name, ok := mgr.matchSuffix(rr) + switch { + case !ok: + // skip, wrong subdomain + continue + case len(names) == 0: + // unfiltered, take it + case !core.SliceContains(names, name): + // skip, not one of the requested names + continue + } + + out = append(out, rr) + } + + return out +} + +func (mgr *Manager) matchSuffix(rr libdns.Record) (string, bool) { + if mgr.suffix == "" { + // no suffix + return rr.Name, true + } + + // remove suffix + return strings.CutSuffix(rr.Name, mgr.suffix) +} + // AddHost registers a host func (mgr *Manager) AddHost(_ context.Context, zone string, id int, active bool, addrs ...netip.Addr) error { diff --git a/pkg/dns/sync.go b/pkg/dns/sync.go index 05b73cb..712022c 100644 --- a/pkg/dns/sync.go +++ b/pkg/dns/sync.go @@ -2,7 +2,6 @@ package dns import ( "context" - "errors" "net/netip" "sort" "strings" @@ -48,18 +47,14 @@ func SortSyncAddrSlice(s []SyncAddr) []SyncAddr { return s } -// GetRecords pulls all the address records on DNS for our domain -func (mgr *Manager) GetRecords(ctx context.Context) ([]SyncAddrRecord, error) { - if mgr.p == nil { - return nil, errors.New("dns provider not specified") - } - - recs, err := mgr.p.GetRecords(ctx, mgr.domain) +// GetSyncRecords pulls all the address records on DNS for our domain +func (mgr *Manager) GetSyncRecords(ctx context.Context) ([]SyncAddrRecord, error) { + recs, err := mgr.GetRecords(ctx) if err != nil { return nil, err } - return mgr.filteredRecords(recs) + return mgr.filteredSyncRecords(recs) } // AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr] @@ -94,7 +89,7 @@ func (mgr *Manager) AsSyncAddr(rr libdns.Record) (SyncAddr, bool, error) { return out, true, nil } -func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, error) { +func (mgr *Manager) filteredSyncRecords(recs []libdns.Record) ([]SyncAddrRecord, error) { // filter and convert cache := make(map[string][]SyncAddr) for _, rr := range recs { @@ -137,7 +132,7 @@ func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, err // Sync updates all the address records on DNS for our domain func (mgr *Manager) Sync(ctx context.Context) error { - current, err := mgr.GetRecords(ctx) + current, err := mgr.GetSyncRecords(ctx) if err != nil { return core.Wrap(err, "GetRecords") }