You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
347 lines
6.5 KiB
347 lines
6.5 KiB
package dns |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"net/netip" |
|
"sort" |
|
"strings" |
|
"time" |
|
|
|
"darvaza.org/core" |
|
"darvaza.org/slog" |
|
"github.com/libdns/libdns" |
|
) |
|
|
|
type (
	// SyncAddrRecord is similar to AddrRecord but includes the libdns.Record
	// details (ID and TTL) fetched from the Provider for every address.
	SyncAddrRecord struct {
		Name  string     // record name, as fetched from the Provider
		Addrs []SyncAddr // addresses published under Name
	}

	// SyncAddr extends netip.Addr with the ID and TTL fetched from the
	// Provider.
	SyncAddr struct {
		ID   string        // ID of the record as fetched from the Provider
		Addr netip.Addr    // address held in the record's value
		TTL  time.Duration // TTL the record currently has
	}
)
|
|
|
// Export assembles a libdns.Record |
|
func (rec *SyncAddr) Export(name string) libdns.Record { |
|
return libdns.Record{ |
|
ID: rec.ID, |
|
Name: name, |
|
Type: core.IIf(rec.Addr.Is6(), "AAAA", "A"), |
|
TTL: time.Second, |
|
Value: rec.Addr.String(), |
|
} |
|
} |
|
|
|
// SortSyncAddrSlice sorts a slice of [SyncAddr] by its address |
|
func SortSyncAddrSlice(s []SyncAddr) []SyncAddr { |
|
sort.Slice(s, func(i, j int) bool { |
|
a1 := s[i].Addr |
|
a2 := s[j].Addr |
|
return a1.Less(a2) |
|
}) |
|
return s |
|
} |
|
|
|
// GetRecords pulls all the address records on DNS for our domain |
|
func (mgr *Manager) GetRecords(ctx context.Context) ([]SyncAddrRecord, error) { |
|
if mgr.p == nil { |
|
return nil, errors.New("dns provider not specified") |
|
} |
|
|
|
recs, err := mgr.p.GetRecords(ctx, mgr.domain) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return mgr.filteredRecords(recs) |
|
} |
|
|
|
// AsSyncAddr converts a A or AAAA [libdns.Record] into a [SyncAddr] |
|
func (mgr *Manager) AsSyncAddr(rr libdns.Record) (SyncAddr, bool, error) { |
|
var out SyncAddr |
|
var addr netip.Addr |
|
|
|
// skip non-address types |
|
if rr.Type != "A" && rr.Type != "AAAA" { |
|
return out, false, nil |
|
} |
|
|
|
// skip entries not containing our suffix |
|
if mgr.suffix != "" { |
|
if !strings.HasSuffix(rr.Name, mgr.suffix) { |
|
return out, false, nil |
|
} |
|
} |
|
|
|
err := addr.UnmarshalText([]byte(rr.Value)) |
|
if err != nil { |
|
// invalid address on A or AAAA record |
|
return out, false, err |
|
} |
|
|
|
out = SyncAddr{ |
|
ID: rr.ID, |
|
TTL: rr.TTL, |
|
Addr: addr, |
|
} |
|
|
|
return out, true, nil |
|
} |
|
|
|
func (mgr *Manager) filteredRecords(recs []libdns.Record) ([]SyncAddrRecord, error) { |
|
// filter and convert |
|
cache := make(map[string][]SyncAddr) |
|
for _, rr := range recs { |
|
addr, ok, err := mgr.AsSyncAddr(rr) |
|
switch { |
|
case err != nil: |
|
// skip invalid addresses |
|
mgr.l.Error(). |
|
WithField("subsystem", "dns"). |
|
WithField(slog.ErrorFieldName, err). |
|
WithField("name", rr.Name). |
|
WithField("type", rr.Type). |
|
WithField("addr", rr.Value). |
|
Print() |
|
case ok: |
|
// store |
|
cache[rr.Name] = append(cache[rr.Name], addr) |
|
} |
|
} |
|
|
|
// prepare records |
|
out := make([]SyncAddrRecord, len(cache)) |
|
names := make([]string, 0, len(cache)) |
|
for name := range cache { |
|
names = append(names, name) |
|
} |
|
sort.Strings(names) |
|
|
|
for i, name := range names { |
|
addrs := cache[name] |
|
|
|
out[i] = SyncAddrRecord{ |
|
Name: name, |
|
Addrs: SortSyncAddrSlice(addrs), |
|
} |
|
} |
|
|
|
return out, nil |
|
} |
|
|
|
// Sync updates all the address records on DNS for our domain |
|
func (mgr *Manager) Sync(ctx context.Context) error { |
|
current, err := mgr.GetRecords(ctx) |
|
if err != nil { |
|
return core.Wrap(err, "GetRecords") |
|
} |
|
|
|
goal := mgr.genAllAddrRecords() |
|
for _, p := range makeSyncMap(current, goal) { |
|
err := mgr.doSync(ctx, p.Name, p.Before, p.After) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (mgr *Manager) doSync(ctx context.Context, name string, |
|
before []SyncAddr, after []netip.Addr) error { |
|
// |
|
var err error |
|
|
|
for _, a := range after { |
|
before, err = mgr.doSyncUpdateOrInsert(ctx, name, a, before) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
for _, b := range before { |
|
err = mgr.doSyncRemove(ctx, name, b) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (mgr *Manager) doSyncUpdateOrInsert(ctx context.Context, name string, |
|
addr netip.Addr, addrs []SyncAddr) ([]SyncAddr, error) { |
|
// |
|
var err error |
|
|
|
i, ok := findSyncAddrSorted(addr, addrs) |
|
if ok { |
|
rec := addrs[i] |
|
|
|
addrs = append(addrs[:i], addrs[i+1:]...) |
|
err = mgr.doSyncUpdate(ctx, name, addr, rec) |
|
} else { |
|
err = mgr.doSyncInsert(ctx, name, addr) |
|
} |
|
|
|
return addrs, err |
|
} |
|
|
|
func (mgr *Manager) doSyncUpdate(ctx context.Context, name string, |
|
addr netip.Addr, rec SyncAddr) error { |
|
// |
|
var log slog.Logger |
|
var msg string |
|
var err error |
|
|
|
if rec.TTL != time.Second { |
|
// amend TTL |
|
|
|
// TODO: batch updates |
|
_, err = mgr.p.SetRecords(ctx, mgr.domain, []libdns.Record{ |
|
rec.Export(name), |
|
}) |
|
|
|
if err == nil { |
|
log = mgr.l.Info() |
|
msg = "Updated" |
|
} else { |
|
log = mgr.l.Error(). |
|
WithField(slog.ErrorFieldName, err) |
|
msg = "Failed" |
|
} |
|
} else { |
|
log = mgr.l.Info() |
|
msg = "OK" |
|
} |
|
|
|
log. |
|
WithField("subsystem", "dns"). |
|
WithField("name", name). |
|
WithField("addr", addr). |
|
Print(msg) |
|
|
|
return err |
|
} |
|
|
|
func (mgr *Manager) doSyncInsert(ctx context.Context, name string, |
|
addr netip.Addr) error { |
|
// |
|
var log slog.Logger |
|
var msg string |
|
|
|
rec := libdns.Record{ |
|
Name: name, |
|
Type: core.IIf(addr.Is6(), "AAAA", "A"), |
|
TTL: time.Second, |
|
Value: addr.String(), |
|
} |
|
|
|
_, err := mgr.p.AppendRecords(ctx, mgr.domain, []libdns.Record{ |
|
rec, |
|
}) |
|
|
|
if err != nil { |
|
log = mgr.l.Error(). |
|
WithField(slog.ErrorFieldName, err) |
|
msg = "Failed to Add" |
|
} else { |
|
log = mgr.l.Info() |
|
msg = "Added" |
|
} |
|
|
|
log. |
|
WithField("subsystem", "dns"). |
|
WithField("name", name). |
|
WithField("addr", addr). |
|
Print(msg) |
|
return err |
|
} |
|
|
|
func (mgr *Manager) doSyncRemove(ctx context.Context, name string, |
|
rec SyncAddr) error { |
|
// |
|
var log slog.Logger |
|
var msg string |
|
|
|
// TODO: batch deletes |
|
_, err := mgr.p.DeleteRecords(ctx, mgr.domain, []libdns.Record{ |
|
rec.Export(name), |
|
}) |
|
|
|
if err != nil { |
|
log = mgr.l.Error(). |
|
WithField(slog.ErrorFieldName, err) |
|
msg = "Failed to Delete" |
|
} else { |
|
log = mgr.l.Warn() |
|
msg = "Deleted" |
|
} |
|
|
|
log. |
|
WithField("subsystem", "dns"). |
|
WithField("name", name). |
|
WithField("addr", rec.Addr). |
|
Print(msg) |
|
return err |
|
} |
|
|
|
func findSyncAddrSorted(target netip.Addr, addrs []SyncAddr) (int, bool) { |
|
for i, a := range addrs { |
|
switch target.Compare(a.Addr) { |
|
case 0: |
|
// match |
|
return i, true |
|
case -1: |
|
// miss |
|
return -1, false |
|
default: |
|
// next |
|
} |
|
} |
|
|
|
return -1, false |
|
} |
|
|
|
type syncMapEntry struct { |
|
Name string |
|
Before []SyncAddr |
|
After []netip.Addr |
|
} |
|
|
|
func makeSyncMap(current []SyncAddrRecord, |
|
goal []AddrRecord) map[string]syncMapEntry { |
|
// |
|
data := make(map[string]syncMapEntry) |
|
|
|
for _, cur := range current { |
|
me, ok := data[cur.Name] |
|
if !ok { |
|
me = syncMapEntry{ |
|
Name: cur.Name, |
|
} |
|
} |
|
|
|
me.Before = append(me.Before, cur.Addrs...) |
|
data[cur.Name] = me |
|
} |
|
|
|
for _, rr := range goal { |
|
me, ok := data[rr.Name] |
|
if !ok { |
|
me = syncMapEntry{ |
|
Name: rr.Name, |
|
} |
|
} |
|
me.After = append(me.After, rr.Addr...) |
|
data[rr.Name] = me |
|
} |
|
|
|
return data |
|
}
|
|
|