diff --git a/pkg/dns/record.go b/pkg/dns/record.go index 2b98905..1b9fc51 100644 --- a/pkg/dns/record.go +++ b/pkg/dns/record.go @@ -6,6 +6,7 @@ import ( "io" "net/netip" "sort" + "strings" "time" "darvaza.org/core" @@ -38,6 +39,53 @@ func SortAddrRecords(s []AddrRecord) []AddrRecord { return s } +// SortRecords sorts a slice of [libdns.Record], by Name, Type and Value +func SortRecords(s []libdns.Record) []libdns.Record { + sort.Slice(s, func(i, j int) bool { + return lessRecord(s[i], s[j]) + }) + return s +} + +func lessRecord(a, b libdns.Record) bool { + aName := strings.ToLower(a.Name) + bName := strings.ToLower(b.Name) + + switch { + case aName < bName: + return true + case aName > bName: + return false + } + + aType := strings.ToUpper(a.Type) + bType := strings.ToUpper(b.Type) + + switch { + case aType < bType: + return true + case aType > bType: + return false + case aType == "A", aType == "AAAA": + // IP Addresses + var aa, ba netip.Addr + + switch { + case aa.UnmarshalText([]byte(a.Value)) != nil: + // bad address on a + return true + case ba.UnmarshalText([]byte(b.Value)) != nil: + // bad address on b + return false + default: + return aa.Less(ba) + } + default: + // text + return a.Value < b.Value + } +} + // SortRegions sorts regions. first by length those 3-character // or shorter, and then by length. It's mostly aimed at // supporting ISO-3166 order diff --git a/pkg/dns/show.go b/pkg/dns/show.go index 24cba56..e6a7fbb 100644 --- a/pkg/dns/show.go +++ b/pkg/dns/show.go @@ -1,11 +1,15 @@ package dns import ( + "bytes" "context" "fmt" + "io" + "os" "time" "darvaza.org/core" + "github.com/libdns/libdns" ) // Show shows current DNS entries @@ -15,15 +19,40 @@ func (mgr *Manager) Show(ctx context.Context, names ...string) error { return core.Wrap(err, "GetRecords") } + SortRecords(recs) + return writeRecords(recs, os.Stdout) +} + +func writeRecords(recs []libdns.Record, w io.Writer) error { + var buf bytes.Buffer + for _, rr := range recs { - _, _ = fmt.Printf("%s\t%v\tIN\t%s\t%s\t; %s\n", - rr.Name, - int(rr.TTL/time.Second), - rr.Type, - rr.Value, - rr.ID) + _ = fmtRecord(&buf, rr) + _, _ = buf.WriteRune('\n') + } + _, _ = fmt.Fprintf(&buf, "; %v records\n", len(recs)) + + _, err := buf.WriteTo(w) + return err +} + +func fmtRecord(w io.Writer, rr libdns.Record) error { + ttl := int(rr.TTL / time.Second) + if ttl < 1 { + ttl = 1 + } + + _, err := fmt.Fprintf(w, "%s\t%v\tIN\t%s\t%s", + rr.Name, + ttl, + rr.Type, + rr.Value) + + if err == nil { + if rr.ID != "" { + _, err = fmt.Fprintf(w, "\t; %s", rr.ID) + } } - _, _ = fmt.Printf("; %v records\n", len(recs)) - return nil + return err }