From dd585b0fa289781ecc834bcdff97a3d0243df9d9 Mon Sep 17 00:00:00 2001 From: Alejandro Mery Date: Tue, 12 Sep 2023 20:28:50 +0000 Subject: [PATCH] dns: add Sync() mechanism to update A/AAAA records on the DNS provider Signed-off-by: Alejandro Mery --- pkg/dns/record.go | 1 + pkg/dns/sync.go | 347 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 348 insertions(+) create mode 100644 pkg/dns/sync.go diff --git a/pkg/dns/record.go b/pkg/dns/record.go index be97701..2b98905 100644 --- a/pkg/dns/record.go +++ b/pkg/dns/record.go @@ -167,6 +167,7 @@ func (mgr *Manager) genAllAddrRecords() []AddrRecord { out = append(out, rec) } + SortAddrRecords(out) return out } diff --git a/pkg/dns/sync.go b/pkg/dns/sync.go new file mode 100644 index 0000000..05b73cb --- /dev/null +++ b/pkg/dns/sync.go @@ -0,0 +1,347 @@ +package dns + +import ( + "context" + "errors" + "net/netip" + "sort" + "strings" + "time" + + "darvaza.org/core" + "darvaza.org/slog" + "github.com/libdns/libdns" +) + +// SyncAddrRecord is similar to AddrRecord but include libdns.Record details +// fetched from the Provider +type SyncAddrRecord struct { + Name string + Addrs []SyncAddr +} + +// SyncAddr extends netip.Addr with ID and TTL fetched from the Provider +type SyncAddr struct { + ID string + Addr netip.Addr + TTL time.Duration +} + +// Export assembles a libdns.Record +func (rec *SyncAddr) Export(name string) libdns.Record { + return libdns.Record{ + ID: rec.ID, + Name: name, + Type: core.IIf(rec.Addr.Is6(), "AAAA", "A"), + TTL: time.Second, + Value: rec.Addr.String(), + } +} + +// SortSyncAddrSlice sorts a slice of [SyncAddr] by its address +func SortSyncAddrSlice(s []SyncAddr) []SyncAddr { + sort.Slice(s, func(i, j int) bool { + a1 := s[i].Addr + a2 := s[j].Addr + return a1.Less(a2) + }) + return s +} + +// GetRecords pulls all the address records on DNS for our domain +func (mgr *Manager) GetRecords(ctx context.Context) ([]SyncAddrRecord, error) { + if mgr.p == nil { + return nil, errors.New("dns provider not specified") + } + + recs, err := mgr.p.GetRecords(ctx, mgr.domain) + if err != nil { + return nil, err + } + + return mgr.filteredRecords(recs) +} + +// AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr] +func (mgr *Manager) AsSyncAddr(rr libdns.Record) (SyncAddr, bool, error) { + var out SyncAddr + var addr netip.Addr + + // skip non-address types + if rr.Type != "A" && rr.Type != "AAAA" { + return out, false, nil + } + + // skip entries not containing our suffix + if mgr.suffix != "" { + if !strings.HasSuffix(rr.Name, mgr.suffix) { + return out, false, nil + } + } + + err := addr.UnmarshalText([]byte(rr.Value)) + if err != nil { + // invalid address on A or AAAA record + return out, false, err + } + + out = SyncAddr{ + ID: rr.ID, + TTL: rr.TTL, + Addr: addr, + } + + return out, true, nil +} + +func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, error) { + // filter and convert + cache := make(map[string][]SyncAddr) + for _, rr := range recs { + addr, ok, err := mgr.AsSyncAddr(rr) + switch { + case err != nil: + // skip invalid addresses + mgr.l.Error(). + WithField("subsystem", "dns"). + WithField(slog.ErrorFieldName, err). + WithField("name", rr.Name). + WithField("type", rr.Type). + WithField("addr", rr.Value). + Print() + case ok: + // store + cache[rr.Name] = append(cache[rr.Name], addr) + } + } + + // prepare records + out := make([]SyncAddrRecord, len(cache)) + names := make([]string, 0, len(cache)) + for name := range cache { + names = append(names, name) + } + sort.Strings(names) + + for i, name := range names { + addrs := cache[name] + + out[i] = SyncAddrRecord{ + Name: name, + Addrs: SortSyncAddrSlice(addrs), + } + } + + return out, nil +} + +// Sync updates all the address records on DNS for our domain +func (mgr *Manager) Sync(ctx context.Context) error { + current, err := mgr.GetRecords(ctx) + if err != nil { + return core.Wrap(err, "GetRecords") + } + + goal := mgr.genAllAddrRecords() + for _, p := range makeSyncMap(current, goal) { + err := mgr.doSync(ctx, p.Name, p.Before, p.After) + if err != nil { + return err + } + } + + return nil +} + +func (mgr *Manager) doSync(ctx context.Context, name string, + before []SyncAddr, after []netip.Addr) error { + // + var err error + + for _, a := range after { + before, err = mgr.doSyncUpdateOrInsert(ctx, name, a, before) + if err != nil { + return err + } + } + + for _, b := range before { + err = mgr.doSyncRemove(ctx, name, b) + if err != nil { + return err + } + } + + return nil +} + +func (mgr *Manager) doSyncUpdateOrInsert(ctx context.Context, name string, + addr netip.Addr, addrs []SyncAddr) ([]SyncAddr, error) { + // + var err error + + i, ok := findSyncAddrSorted(addr, addrs) + if ok { + rec := addrs[i] + + addrs = append(addrs[:i], addrs[i+1:]...) + err = mgr.doSyncUpdate(ctx, name, addr, rec) + } else { + err = mgr.doSyncInsert(ctx, name, addr) + } + + return addrs, err +} + +func (mgr *Manager) doSyncUpdate(ctx context.Context, name string, + addr netip.Addr, rec SyncAddr) error { + // + var log slog.Logger + var msg string + var err error + + if rec.TTL != time.Second { + // amend TTL + + // TODO: batch updates + _, err = mgr.p.SetRecords(ctx, mgr.domain, []libdns.Record{ + rec.Export(name), + }) + + if err == nil { + log = mgr.l.Info() + msg = "Updated" + } else { + log = mgr.l.Error(). + WithField(slog.ErrorFieldName, err) + msg = "Failed" + } + } else { + log = mgr.l.Info() + msg = "OK" + } + + log. + WithField("subsystem", "dns"). + WithField("name", name). + WithField("addr", addr). + Print(msg) + + return err +} + +func (mgr *Manager) doSyncInsert(ctx context.Context, name string, + addr netip.Addr) error { + // + var log slog.Logger + var msg string + + rec := libdns.Record{ + Name: name, + Type: core.IIf(addr.Is6(), "AAAA", "A"), + TTL: time.Second, + Value: addr.String(), + } + + _, err := mgr.p.AppendRecords(ctx, mgr.domain, []libdns.Record{ + rec, + }) + + if err != nil { + log = mgr.l.Error(). + WithField(slog.ErrorFieldName, err) + msg = "Failed to Add" + } else { + log = mgr.l.Info() + msg = "Added" + } + + log. + WithField("subsystem", "dns"). + WithField("name", name). + WithField("addr", addr). + Print(msg) + return err +} + +func (mgr *Manager) doSyncRemove(ctx context.Context, name string, + rec SyncAddr) error { + // + var log slog.Logger + var msg string + + // TODO: batch deletes + _, err := mgr.p.DeleteRecords(ctx, mgr.domain, []libdns.Record{ + rec.Export(name), + }) + + if err != nil { + log = mgr.l.Error(). + WithField(slog.ErrorFieldName, err) + msg = "Failed to Delete" + } else { + log = mgr.l.Warn() + msg = "Deleted" + } + + log. + WithField("subsystem", "dns"). + WithField("name", name). + WithField("addr", rec.Addr). + Print(msg) + return err +} + +func findSyncAddrSorted(target netip.Addr, addrs []SyncAddr) (int, bool) { + for i, a := range addrs { + switch target.Compare(a.Addr) { + case 0: + // match + return i, true + case -1: + // miss + return -1, false + default: + // next + } + } + + return -1, false +} + +type syncMapEntry struct { + Name string + Before []SyncAddr + After []netip.Addr +} + +func makeSyncMap(current []SyncAddrRecord, + goal []AddrRecord) map[string]syncMapEntry { + // + data := make(map[string]syncMapEntry) + + for _, cur := range current { + me, ok := data[cur.Name] + if !ok { + me = syncMapEntry{ + Name: cur.Name, + } + } + + me.Before = append(me.Before, cur.Addrs...) + data[cur.Name] = me + } + + for _, rr := range goal { + me, ok := data[rr.Name] + if !ok { + me = syncMapEntry{ + Name: rr.Name, + } + } + me.After = append(me.After, rr.Addr...) + data[rr.Name] = me + } + + return data +}