diff --git a/cmd/jpictl/config.go b/cmd/jpictl/config.go index a649b53..ec20b7e 100644 --- a/cmd/jpictl/config.go +++ b/cmd/jpictl/config.go @@ -14,6 +14,8 @@ var cfg = &Config{ } // LoadZones loads all zones and machines in the config directory -func (cfg *Config) LoadZones() (*zones.Zones, error) { - return zones.New(cfg.Base, cfg.Domain) +func (cfg *Config) LoadZones(resolve bool) (*zones.Zones, error) { + return zones.New(cfg.Base, cfg.Domain, + zones.ResolvePublicAddresses(resolve), + ) } diff --git a/cmd/jpictl/dump.go b/cmd/jpictl/dump.go index eae0801..fb892ba 100644 --- a/cmd/jpictl/dump.go +++ b/cmd/jpictl/dump.go @@ -58,7 +58,7 @@ var dumpCmd = &cobra.Command{ var buf bytes.Buffer var enc Encoder - m, err := cfg.LoadZones() + m, err := cfg.LoadZones(true) if err != nil { return err } diff --git a/cmd/jpictl/env.go b/cmd/jpictl/env.go index 659862a..a968725 100644 --- a/cmd/jpictl/env.go +++ b/cmd/jpictl/env.go @@ -11,7 +11,7 @@ var envCmd = &cobra.Command{ Use: "env", Short: "generates environment variables for shell scripts", RunE: func(_ *cobra.Command, _ []string) error { - m, err := cfg.LoadZones() + m, err := cfg.LoadZones(false) if err != nil { return err } diff --git a/cmd/jpictl/write.go b/cmd/jpictl/write.go index 93fffb7..acb79c2 100644 --- a/cmd/jpictl/write.go +++ b/cmd/jpictl/write.go @@ -9,7 +9,7 @@ var writeCmd = &cobra.Command{ Use: "write", Short: "rewrites all config files", RunE: func(_ *cobra.Command, _ []string) error { - m, err := cfg.LoadZones() + m, err := cfg.LoadZones(false) if err != nil { return err } diff --git a/pkg/zones/machine_scan.go b/pkg/zones/machine_scan.go index ee46203..52efe6e 100644 --- a/pkg/zones/machine_scan.go +++ b/pkg/zones/machine_scan.go @@ -7,8 +7,9 @@ import ( "time" ) -func (m *Machine) lookupNetIP() ([]netip.Addr, error) { - timeout := 2 * time.Second +// LookupNetIP uses the DNS Resolver to get the public addresses associated +// to a Machine +func (m *Machine) LookupNetIP(timeout time.Duration) ([]netip.Addr, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -16,8 +17,9 @@ func (m *Machine) lookupNetIP() ([]netip.Addr, error) { return m.zone.zones.resolver.LookupNetIP(ctx, "ip", m.FullName()) } -func (m *Machine) updatePublicAddresses() error { - addrs, err := m.lookupNetIP() +// UpdatePublicAddresses uses the DNS Resolver to set Machine.PublicAddresses +func (m *Machine) UpdatePublicAddresses() error { + addrs, err := m.LookupNetIP(2 * time.Second) if err != nil { return err } @@ -52,12 +54,16 @@ func (m *Machine) setID() error { return nil } -func (m *Machine) scan() error { +func (m *Machine) scan(opts *ScanOptions) error { for i := 0; i < RingsCount; i++ { if err := m.tryApplyWireguardConfig(i); err != nil { return err } } - return m.updatePublicAddresses() + if !opts.DontResolvePublicAddresses { + return m.UpdatePublicAddresses() + } + + return nil } diff --git a/pkg/zones/options.go b/pkg/zones/options.go new file mode 100644 index 0000000..4c1e806 --- /dev/null +++ b/pkg/zones/options.go @@ -0,0 +1,109 @@ +package zones + +import ( + "io/fs" + "path/filepath" + + "darvaza.org/resolver" + "github.com/hack-pad/hackpadfs/os" +) + +// A ScanOption preconfigures the Zones before scanning +type ScanOption func(*Zones, *ScanOptions) error + +// ScanOptions contains flags used by the initial scan +type ScanOptions struct { + // DontResolvePublicAddresses indicates we shouldn't + // pre-populate Machine.PublicAddresses during the + // initial scan + DontResolvePublicAddresses bool +} + +// ResolvePublicAddresses instructs the scanner to use +// the DNS resolver to get PublicAddresses of nodes. +// Default is true +func ResolvePublicAddresses(resolve bool) ScanOption { + return func(m *Zones, opt *ScanOptions) error { + opt.DontResolvePublicAddresses = !resolve + return nil + } +} + +// WithLookuper specifies what resolver.Lookuper to use to +// find public addresses +func WithLookuper(h resolver.Lookuper) ScanOption { + return func(m *Zones, opt *ScanOptions) error { + if h == nil { + return fs.ErrInvalid + } + m.resolver = resolver.NewResolver(h) + return nil + } +} + +// WithResolver specifies what resolver to use to find +// public addresses. if nil is passed, the [net.Resolver] will be used. +// The default is using Cloudflare's 1.1.1.1. +func WithResolver(h resolver.Resolver) ScanOption { + return func(m *Zones, opt *ScanOptions) error { + if h == nil { + h = resolver.SystemResolver(true) + } + + m.resolver = h + return nil + } +} + +func (m *Zones) setDefaults(opt *ScanOptions) error { + if m.resolver == nil { + h := resolver.NewCloudflareLookuper() + + if err := WithLookuper(h)(m, opt); err != nil { + return err + } + } + + return nil +} + +// NewFS builds a [Zones] tree using the given directory +func NewFS(dir fs.FS, domain string, opts ...ScanOption) (*Zones, error) { + var scanOptions ScanOptions + + z := &Zones{ + dir: dir, + domain: domain, + } + + for _, opt := range opts { + if err := opt(z, &scanOptions); err != nil { + return nil, err + } + } + + if err := z.setDefaults(&scanOptions); err != nil { + return nil, err + } + + if err := z.scan(&scanOptions); err != nil { + return nil, err + } + + return z, nil +} + +// New builds a [Zones] tree using the given directory +func New(dir, domain string, opts ...ScanOption) (*Zones, error) { + dir, err := filepath.Abs(dir) + if err != nil { + return nil, err + } + + base, err := os.NewFS().Sub(dir[1:]) + if err != nil { + return nil, err + } + + return NewFS(base, domain, opts...) +} diff --git a/pkg/zones/scan.go b/pkg/zones/scan.go index 8faf8a8..08734a3 100644 --- a/pkg/zones/scan.go +++ b/pkg/zones/scan.go @@ -5,15 +5,15 @@ import ( "sort" ) -func (m *Zones) scan() error { - for _, fn := range []func() error{ +func (m *Zones) scan(opts *ScanOptions) error { + for _, fn := range []func(*ScanOptions) error{ m.scanDirectory, m.scanMachines, m.scanZoneIDs, m.scanSort, m.scanGateways, } { - if err := fn(); err != nil { + if err := fn(opts); err != nil { return err } } @@ -21,7 +21,7 @@ func (m *Zones) scan() error { return nil } -func (m *Zones) scanDirectory() error { +func (m *Zones) scanDirectory(_ *ScanOptions) error { // each directory is a zone entries, err := fs.ReadDir(m.dir, ".") if err != nil { @@ -46,16 +46,16 @@ func (m *Zones) scanDirectory() error { return nil } -func (m *Zones) scanMachines() error { +func (m *Zones) scanMachines(opts *ScanOptions) error { var err error m.ForEachMachine(func(p *Machine) bool { - err = p.scan() + err = p.scan(opts) return err != nil }) return err } -func (m *Zones) scanZoneIDs() error { +func (m *Zones) scanZoneIDs(_ *ScanOptions) error { var hasMissing bool var lastZoneID int @@ -85,7 +85,7 @@ func (m *Zones) scanZoneIDs() error { return nil } -func (m *Zones) scanSort() error { +func (m *Zones) scanSort(_ *ScanOptions) error { sort.SliceStable(m.Zones, func(i, j int) bool { id1 := m.Zones[i].ID id2 := m.Zones[j].ID @@ -111,7 +111,7 @@ func (m *Zones) scanSort() error { return nil } -func (m *Zones) scanGateways() error { +func (m *Zones) scanGateways(_ *ScanOptions) error { var err error m.ForEachZone(func(z *Zone) bool { diff --git a/pkg/zones/zones.go b/pkg/zones/zones.go index ae9eb33..f25f264 100644 --- a/pkg/zones/zones.go +++ b/pkg/zones/zones.go @@ -3,11 +3,8 @@ package zones import ( "io/fs" - "path/filepath" "sort" - "github.com/hack-pad/hackpadfs/os" - "darvaza.org/resolver" ) @@ -190,35 +187,3 @@ func (m *Zones) GetMachineByName(name string) (*Machine, bool) { return out, out != nil } - -// NewFS builds a [Zones] tree using the given directory -func NewFS(dir fs.FS, domain string) (*Zones, error) { - lockuper := resolver.NewCloudflareLookuper() - - z := &Zones{ - dir: dir, - resolver: resolver.NewResolver(lockuper), - domain: domain, - } - - if err := z.scan(); err != nil { - return nil, err - } - - return z, nil -} - -// New builds a [Zones] tree using the given directory -func New(dir, domain string) (*Zones, error) { - dir, err := filepath.Abs(dir) - if err != nil { - return nil, err - } - - base, err := os.NewFS().Sub(dir[1:]) - if err != nil { - return nil, err - } - - return NewFS(base, domain) -}