Browse Source

Merge pull request 'dns: introduce `jpictl dns sync` to update public DNS records' (#25)

A and AAAA only

Reviewed-on: #25
pull/26/head v0.6.7
Alejandro Mery 1 year ago
parent
commit
76b40e63c7
  1. 66
      cmd/jpictl/dns.go
  2. 2
      pkg/dns/provider.go
  3. 1
      pkg/dns/record.go
  4. 347
      pkg/dns/sync.go

66
cmd/jpictl/dns.go

@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"os" "os"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -10,9 +11,13 @@ import (
"git.jpi.io/amery/jpictl/pkg/dns" "git.jpi.io/amery/jpictl/pkg/dns"
) )
func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) { const (
ctx := context.TODO() // DNSSyncTimeout specifies how long are we willing to wait for a DNS
// synchronization
DNSSyncTimeout = 10 * time.Second
)
func newDNSManager(m *cluster.Cluster, provider dns.Provider) (*dns.Manager, error) {
domain := m.Domain domain := m.Domain
if m.Name != "" { if m.Name != "" {
domain = m.Name + "." + domain domain = m.Name + "." + domain
@ -23,6 +28,26 @@ func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) {
return nil, err return nil, err
} }
if provider != nil {
// set provider only if specified
err = dns.WithProvider(provider)(mgr)
if err != nil {
return nil, err
}
}
if err := populateDNSManager(mgr, m); err != nil {
return nil, err
}
return mgr, nil
}
func populateDNSManager(mgr *dns.Manager, m *cluster.Cluster) error {
var err error
ctx := context.TODO()
m.ForEachZone(func(z *cluster.Zone) bool { m.ForEachZone(func(z *cluster.Zone) bool {
z.ForEachMachine(func(p *cluster.Machine) bool { z.ForEachMachine(func(p *cluster.Machine) bool {
err = mgr.AddHost(ctx, z.Name, p.ID, true, p.PublicAddresses...) err = mgr.AddHost(ctx, z.Name, p.ID, true, p.PublicAddresses...)
@ -32,7 +57,7 @@ func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) {
return err != nil return err != nil
}) })
if err != nil { if err != nil {
return nil, err return err
} }
m.ForEachRegion(func(r *cluster.Region) bool { m.ForEachRegion(func(r *cluster.Region) bool {
@ -43,11 +68,8 @@ func newDNSManager(m *cluster.Cluster) (*dns.Manager, error) {
return err != nil return err != nil
}) })
if err != nil {
return nil, err
}
return mgr, nil return err
} }
// Command // Command
@ -65,7 +87,7 @@ var dnsWriteCmd = &cobra.Command{
return err return err
} }
mgr, err := newDNSManager(m) mgr, err := newDNSManager(m, nil)
if err != nil { if err != nil {
return err return err
} }
@ -75,8 +97,36 @@ var dnsWriteCmd = &cobra.Command{
}, },
} }
var dnsSyncCmd = &cobra.Command{
Use: "sync",
Short: "dns sync updates public DNS records",
PreRun: setVerbosity,
RunE: func(_ *cobra.Command, _ []string) error {
cred, err := dns.DefaultDNSProvider()
if err != nil {
return err
}
m, err := cfg.LoadZones(true)
if err != nil {
return err
}
mgr, err := newDNSManager(m, cred)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), DNSSyncTimeout)
defer cancel()
return mgr.Sync(ctx)
},
}
func init() { func init() {
rootCmd.AddCommand(dnsCmd) rootCmd.AddCommand(dnsCmd)
dnsCmd.AddCommand(dnsWriteCmd) dnsCmd.AddCommand(dnsWriteCmd)
dnsCmd.AddCommand(dnsSyncCmd)
} }

2
pkg/dns/provider.go

@ -18,6 +18,8 @@ const (
type Provider interface { type Provider interface {
libdns.RecordGetter libdns.RecordGetter
libdns.RecordDeleter libdns.RecordDeleter
libdns.RecordSetter
libdns.RecordAppender
} }
// DefaultDNSProvider returns a cloudflare DNS provider // DefaultDNSProvider returns a cloudflare DNS provider

1
pkg/dns/record.go

@ -167,6 +167,7 @@ func (mgr *Manager) genAllAddrRecords() []AddrRecord {
out = append(out, rec) out = append(out, rec)
} }
SortAddrRecords(out)
return out return out
} }

347
pkg/dns/sync.go

@ -0,0 +1,347 @@
package dns
import (
"context"
"errors"
"net/netip"
"sort"
"strings"
"time"
"darvaza.org/core"
"darvaza.org/slog"
"github.com/libdns/libdns"
)
// SyncAddrRecord is similar to AddrRecord but include libdns.Record details
// fetched from the Provider
type SyncAddrRecord struct {
Name string
Addrs []SyncAddr
}
// SyncAddr extends netip.Addr with ID and TTL fetched from the Provider
type SyncAddr struct {
ID string
Addr netip.Addr
TTL time.Duration
}
// Export assembles a libdns.Record
func (rec *SyncAddr) Export(name string) libdns.Record {
return libdns.Record{
ID: rec.ID,
Name: name,
Type: core.IIf(rec.Addr.Is6(), "AAAA", "A"),
TTL: time.Second,
Value: rec.Addr.String(),
}
}
// SortSyncAddrSlice sorts a slice of [SyncAddr] by its address
func SortSyncAddrSlice(s []SyncAddr) []SyncAddr {
sort.Slice(s, func(i, j int) bool {
a1 := s[i].Addr
a2 := s[j].Addr
return a1.Less(a2)
})
return s
}
// GetRecords pulls all the address records on DNS for our domain
func (mgr *Manager) GetRecords(ctx context.Context) ([]SyncAddrRecord, error) {
if mgr.p == nil {
return nil, errors.New("dns provider not specified")
}
recs, err := mgr.p.GetRecords(ctx, mgr.domain)
if err != nil {
return nil, err
}
return mgr.filteredRecords(recs)
}
// AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr]
func (mgr *Manager) AsSyncAddr(rr libdns.Record) (SyncAddr, bool, error) {
var out SyncAddr
var addr netip.Addr
// skip non-address types
if rr.Type != "A" && rr.Type != "AAAA" {
return out, false, nil
}
// skip entries not containing our suffix
if mgr.suffix != "" {
if !strings.HasSuffix(rr.Name, mgr.suffix) {
return out, false, nil
}
}
err := addr.UnmarshalText([]byte(rr.Value))
if err != nil {
// invalid address on A or AAAA record
return out, false, err
}
out = SyncAddr{
ID: rr.ID,
TTL: rr.TTL,
Addr: addr,
}
return out, true, nil
}
func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, error) {
// filter and convert
cache := make(map[string][]SyncAddr)
for _, rr := range recs {
addr, ok, err := mgr.AsSyncAddr(rr)
switch {
case err != nil:
// skip invalid addresses
mgr.l.Error().
WithField("subsystem", "dns").
WithField(slog.ErrorFieldName, err).
WithField("name", rr.Name).
WithField("type", rr.Type).
WithField("addr", rr.Value).
Print()
case ok:
// store
cache[rr.Name] = append(cache[rr.Name], addr)
}
}
// prepare records
out := make([]SyncAddrRecord, len(cache))
names := make([]string, 0, len(cache))
for name := range cache {
names = append(names, name)
}
sort.Strings(names)
for i, name := range names {
addrs := cache[name]
out[i] = SyncAddrRecord{
Name: name,
Addrs: SortSyncAddrSlice(addrs),
}
}
return out, nil
}
// Sync updates all the address records on DNS for our domain
func (mgr *Manager) Sync(ctx context.Context) error {
current, err := mgr.GetRecords(ctx)
if err != nil {
return core.Wrap(err, "GetRecords")
}
goal := mgr.genAllAddrRecords()
for _, p := range makeSyncMap(current, goal) {
err := mgr.doSync(ctx, p.Name, p.Before, p.After)
if err != nil {
return err
}
}
return nil
}
func (mgr *Manager) doSync(ctx context.Context, name string,
before []SyncAddr, after []netip.Addr) error {
//
var err error
for _, a := range after {
before, err = mgr.doSyncUpdateOrInsert(ctx, name, a, before)
if err != nil {
return err
}
}
for _, b := range before {
err = mgr.doSyncRemove(ctx, name, b)
if err != nil {
return err
}
}
return nil
}
func (mgr *Manager) doSyncUpdateOrInsert(ctx context.Context, name string,
addr netip.Addr, addrs []SyncAddr) ([]SyncAddr, error) {
//
var err error
i, ok := findSyncAddrSorted(addr, addrs)
if ok {
rec := addrs[i]
addrs = append(addrs[:i], addrs[i+1:]...)
err = mgr.doSyncUpdate(ctx, name, addr, rec)
} else {
err = mgr.doSyncInsert(ctx, name, addr)
}
return addrs, err
}
func (mgr *Manager) doSyncUpdate(ctx context.Context, name string,
addr netip.Addr, rec SyncAddr) error {
//
var log slog.Logger
var msg string
var err error
if rec.TTL != time.Second {
// amend TTL
// TODO: batch updates
_, err = mgr.p.SetRecords(ctx, mgr.domain, []libdns.Record{
rec.Export(name),
})
if err == nil {
log = mgr.l.Info()
msg = "Updated"
} else {
log = mgr.l.Error().
WithField(slog.ErrorFieldName, err)
msg = "Failed"
}
} else {
log = mgr.l.Info()
msg = "OK"
}
log.
WithField("subsystem", "dns").
WithField("name", name).
WithField("addr", addr).
Print(msg)
return err
}
func (mgr *Manager) doSyncInsert(ctx context.Context, name string,
addr netip.Addr) error {
//
var log slog.Logger
var msg string
rec := libdns.Record{
Name: name,
Type: core.IIf(addr.Is6(), "AAAA", "A"),
TTL: time.Second,
Value: addr.String(),
}
_, err := mgr.p.AppendRecords(ctx, mgr.domain, []libdns.Record{
rec,
})
if err != nil {
log = mgr.l.Error().
WithField(slog.ErrorFieldName, err)
msg = "Failed to Add"
} else {
log = mgr.l.Info()
msg = "Added"
}
log.
WithField("subsystem", "dns").
WithField("name", name).
WithField("addr", addr).
Print(msg)
return err
}
func (mgr *Manager) doSyncRemove(ctx context.Context, name string,
rec SyncAddr) error {
//
var log slog.Logger
var msg string
// TODO: batch deletes
_, err := mgr.p.DeleteRecords(ctx, mgr.domain, []libdns.Record{
rec.Export(name),
})
if err != nil {
log = mgr.l.Error().
WithField(slog.ErrorFieldName, err)
msg = "Failed to Delete"
} else {
log = mgr.l.Warn()
msg = "Deleted"
}
log.
WithField("subsystem", "dns").
WithField("name", name).
WithField("addr", rec.Addr).
Print(msg)
return err
}
func findSyncAddrSorted(target netip.Addr, addrs []SyncAddr) (int, bool) {
for i, a := range addrs {
switch target.Compare(a.Addr) {
case 0:
// match
return i, true
case -1:
// miss
return -1, false
default:
// next
}
}
return -1, false
}
type syncMapEntry struct {
Name string
Before []SyncAddr
After []netip.Addr
}
func makeSyncMap(current []SyncAddrRecord,
goal []AddrRecord) map[string]syncMapEntry {
//
data := make(map[string]syncMapEntry)
for _, cur := range current {
me, ok := data[cur.Name]
if !ok {
me = syncMapEntry{
Name: cur.Name,
}
}
me.Before = append(me.Before, cur.Addrs...)
data[cur.Name] = me
}
for _, rr := range goal {
me, ok := data[rr.Name]
if !ok {
me = syncMapEntry{
Name: rr.Name,
}
}
me.After = append(me.After, rr.Addr...)
data[rr.Name] = me
}
return data
}
Loading…
Cancel
Save