diff --git a/cmd/jpictl/dns.go b/cmd/jpictl/dns.go index 6520bbf..4431801 100644 --- a/cmd/jpictl/dns.go +++ b/cmd/jpictl/dns.go @@ -3,6 +3,7 @@ package main import ( "context" "os" + "time" "github.com/spf13/cobra" @@ -10,9 +11,13 @@ import ( "git.jpi.io/amery/jpictl/pkg/dns" ) -func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) { - ctx := context.TODO() +const ( + // DNSSyncTimeout specifies how long are we willing to wait for a DNS + // synchronization + DNSSyncTimeout = 10 * time.Second +) +func newDNSManager(m *cluster.Cluster, provider dns.Provider) (*dns.Manager, error) { domain := m.Domain if m.Name != "" { domain = m.Name + "." + domain @@ -23,6 +28,26 @@ func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) { return nil, err } + if provider != nil { + // set provider only if specified + err = dns.WithProvider(provider)(mgr) + if err != nil { + return nil, err + } + } + + if err := populateDNSManager(mgr, m); err != nil { + return nil, err + } + + return mgr, nil +} + +func populateDNSManager(mgr *dns.Manager, m *cluster.Cluster) error { + var err error + + ctx := context.TODO() + m.ForEachZone(func(z *cluster.Zone) bool { z.ForEachMachine(func(p *cluster.Machine) bool { err = mgr.AddHost(ctx, z.Name, p.ID, true, p.PublicAddresses...) @@ -32,7 +57,7 @@ func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) { return err != nil }) if err != nil { - return nil, err + return err } m.ForEachRegion(func(r *cluster.Region) bool { @@ -43,11 +68,8 @@ func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) { return err != nil }) - if err != nil { - return nil, err - } - return mgr, nil + return err } // Command @@ -65,7 +87,7 @@ var dnsWriteCmd = &cobra.Command{ return err } - mgr, err := newDNSManager(m) + mgr, err := newDNSManager(m, nil) if err != nil { return err } @@ -75,8 +97,36 @@ var dnsWriteCmd = &cobra.Command{ }, } +var dnsSyncCmd = &cobra.Command{ + Use: "sync", + Short: "dns sync updates public DNS records", + PreRun: setVerbosity, + RunE: func(_ *cobra.Command, _ []string) error { + cred, err := dns.DefaultDNSProvider() + if err != nil { + return err + } + + m, err := cfg.LoadZones(true) + if err != nil { + return err + } + + mgr, err := newDNSManager(m, cred) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), DNSSyncTimeout) + defer cancel() + + return mgr.Sync(ctx) + }, +} + func init() { rootCmd.AddCommand(dnsCmd) dnsCmd.AddCommand(dnsWriteCmd) + dnsCmd.AddCommand(dnsSyncCmd) } diff --git a/pkg/dns/provider.go b/pkg/dns/provider.go index 6a53a24..ff87ddd 100644 --- a/pkg/dns/provider.go +++ b/pkg/dns/provider.go @@ -18,6 +18,8 @@ const ( type Provider interface { libdns.RecordGetter libdns.RecordDeleter + libdns.RecordSetter + libdns.RecordAppender } // DefaultDNSProvider returns a cloudflare DNS provider diff --git a/pkg/dns/record.go b/pkg/dns/record.go index be97701..2b98905 100644 --- a/pkg/dns/record.go +++ b/pkg/dns/record.go @@ -167,6 +167,7 @@ func (mgr *Manager) genAllAddrRecords() []AddrRecord { out = append(out, rec) } + SortAddrRecords(out) return out } diff --git a/pkg/dns/sync.go b/pkg/dns/sync.go new file mode 100644 index 0000000..05b73cb --- /dev/null +++ b/pkg/dns/sync.go @@ -0,0 +1,347 @@ +package dns + +import ( + "context" + "errors" + "net/netip" + "sort" + "strings" + "time" + + "darvaza.org/core" + "darvaza.org/slog" + "github.com/libdns/libdns" +) + +// SyncAddrRecord is similar to AddrRecord but include libdns.Record details +// fetched from the Provider +type SyncAddrRecord struct { + Name string + Addrs []SyncAddr +} + +// SyncAddr extends netip.Addr with ID and TTL fetched from the Provider +type SyncAddr struct { + ID string + Addr netip.Addr + TTL time.Duration +} + +// Export assembles a libdns.Record +func (rec *SyncAddr) Export(name string) libdns.Record { + return libdns.Record{ + ID: rec.ID, + Name: name, + Type: core.IIf(rec.Addr.Is6(), "AAAA", "A"), + TTL: time.Second, + Value: rec.Addr.String(), + } +} + +// SortSyncAddrSlice sorts a slice of [SyncAddr] by its address +func SortSyncAddrSlice(s []SyncAddr) []SyncAddr { + sort.Slice(s, func(i, j int) bool { + a1 := s[i].Addr + a2 := s[j].Addr + return a1.Less(a2) + }) + return s +} + +// GetRecords pulls all the address records on DNS for our domain +func (mgr *Manager) GetRecords(ctx context.Context) ([]SyncAddrRecord, error) { + if mgr.p == nil { + return nil, errors.New("dns provider not specified") + } + + recs, err := mgr.p.GetRecords(ctx, mgr.domain) + if err != nil { + return nil, err + } + + return mgr.filteredRecords(recs) +} + +// AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr] +func (mgr *Manager) AsSyncAddr(rr libdns.Record) (SyncAddr, bool, error) { + var out SyncAddr + var addr netip.Addr + + // skip non-address types + if rr.Type != "A" && rr.Type != "AAAA" { + return out, false, nil + } + + // skip entries not containing our suffix + if mgr.suffix != "" { + if !strings.HasSuffix(rr.Name, mgr.suffix) { + return out, false, nil + } + } + + err := addr.UnmarshalText([]byte(rr.Value)) + if err != nil { + // invalid address on A or AAAA record + return out, false, err + } + + out = SyncAddr{ + ID: rr.ID, + TTL: rr.TTL, + Addr: addr, + } + + return out, true, nil +} + +func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, error) { + // filter and convert + cache := make(map[string][]SyncAddr) + for _, rr := range recs { + addr, ok, err := mgr.AsSyncAddr(rr) + switch { + case err != nil: + // skip invalid addresses + mgr.l.Error(). + WithField("subsystem", "dns"). + WithField(slog.ErrorFieldName, err). + WithField("name", rr.Name). + WithField("type", rr.Type). + WithField("addr", rr.Value). + Print() + case ok: + // store + cache[rr.Name] = append(cache[rr.Name], addr) + } + } + + // prepare records + out := make([]SyncAddrRecord, len(cache)) + names := make([]string, 0, len(cache)) + for name := range cache { + names = append(names, name) + } + sort.Strings(names) + + for i, name := range names { + addrs := cache[name] + + out[i] = SyncAddrRecord{ + Name: name, + Addrs: SortSyncAddrSlice(addrs), + } + } + + return out, nil +} + +// Sync updates all the address records on DNS for our domain +func (mgr *Manager) Sync(ctx context.Context) error { + current, err := mgr.GetRecords(ctx) + if err != nil { + return core.Wrap(err, "GetRecords") + } + + goal := mgr.genAllAddrRecords() + for _, p := range makeSyncMap(current, goal) { + err := mgr.doSync(ctx, p.Name, p.Before, p.After) + if err != nil { + return err + } + } + + return nil +} + +func (mgr *Manager) doSync(ctx context.Context, name string, + before []SyncAddr, after []netip.Addr) error { + // + var err error + + for _, a := range after { + before, err = mgr.doSyncUpdateOrInsert(ctx, name, a, before) + if err != nil { + return err + } + } + + for _, b := range before { + err = mgr.doSyncRemove(ctx, name, b) + if err != nil { + return err + } + } + + return nil +} + +func (mgr *Manager) doSyncUpdateOrInsert(ctx context.Context, name string, + addr netip.Addr, addrs []SyncAddr) ([]SyncAddr, error) { + // + var err error + + i, ok := findSyncAddrSorted(addr, addrs) + if ok { + rec := addrs[i] + + addrs = append(addrs[:i], addrs[i+1:]...) + err = mgr.doSyncUpdate(ctx, name, addr, rec) + } else { + err = mgr.doSyncInsert(ctx, name, addr) + } + + return addrs, err +} + +func (mgr *Manager) doSyncUpdate(ctx context.Context, name string, + addr netip.Addr, rec SyncAddr) error { + // + var log slog.Logger + var msg string + var err error + + if rec.TTL != time.Second { + // amend TTL + + // TODO: batch updates + _, err = mgr.p.SetRecords(ctx, mgr.domain, []libdns.Record{ + rec.Export(name), + }) + + if err == nil { + log = mgr.l.Info() + msg = "Updated" + } else { + log = mgr.l.Error(). + WithField(slog.ErrorFieldName, err) + msg = "Failed" + } + } else { + log = mgr.l.Info() + msg = "OK" + } + + log. + WithField("subsystem", "dns"). + WithField("name", name). + WithField("addr", addr). + Print(msg) + + return err +} + +func (mgr *Manager) doSyncInsert(ctx context.Context, name string, + addr netip.Addr) error { + // + var log slog.Logger + var msg string + + rec := libdns.Record{ + Name: name, + Type: core.IIf(addr.Is6(), "AAAA", "A"), + TTL: time.Second, + Value: addr.String(), + } + + _, err := mgr.p.AppendRecords(ctx, mgr.domain, []libdns.Record{ + rec, + }) + + if err != nil { + log = mgr.l.Error(). + WithField(slog.ErrorFieldName, err) + msg = "Failed to Add" + } else { + log = mgr.l.Info() + msg = "Added" + } + + log. + WithField("subsystem", "dns"). + WithField("name", name). + WithField("addr", addr). + Print(msg) + return err +} + +func (mgr *Manager) doSyncRemove(ctx context.Context, name string, + rec SyncAddr) error { + // + var log slog.Logger + var msg string + + // TODO: batch deletes + _, err := mgr.p.DeleteRecords(ctx, mgr.domain, []libdns.Record{ + rec.Export(name), + }) + + if err != nil { + log = mgr.l.Error(). + WithField(slog.ErrorFieldName, err) + msg = "Failed to Delete" + } else { + log = mgr.l.Warn() + msg = "Deleted" + } + + log. + WithField("subsystem", "dns"). + WithField("name", name). + WithField("addr", rec.Addr). + Print(msg) + return err +} + +func findSyncAddrSorted(target netip.Addr, addrs []SyncAddr) (int, bool) { + for i, a := range addrs { + switch target.Compare(a.Addr) { + case 0: + // match + return i, true + case -1: + // miss + return -1, false + default: + // next + } + } + + return -1, false +} + +type syncMapEntry struct { + Name string + Before []SyncAddr + After []netip.Addr +} + +func makeSyncMap(current []SyncAddrRecord, + goal []AddrRecord) map[string]syncMapEntry { + // + data := make(map[string]syncMapEntry) + + for _, cur := range current { + me, ok := data[cur.Name] + if !ok { + me = syncMapEntry{ + Name: cur.Name, + } + } + + me.Before = append(me.Before, cur.Addrs...) + data[cur.Name] = me + } + + for _, rr := range goal { + me, ok := data[rr.Name] + if !ok { + me = syncMapEntry{ + Name: rr.Name, + } + } + me.After = append(me.After, rr.Addr...) + data[rr.Name] = me + } + + return data +}