package cluster import ( "io/fs" "os" "git.jpi.io/amery/jpictl/pkg/rings" ) var ( _ WireguardConfigPruner = (*Cluster)(nil) _ WireguardConfigPruner = (*Zone)(nil) _ WireguardConfigPruner = (*Machine)(nil) _ WireguardConfigWriter = (*Cluster)(nil) _ WireguardConfigWriter = (*Zone)(nil) _ WireguardConfigWriter = (*Machine)(nil) _ WireguardConfigSyncer = (*Cluster)(nil) _ WireguardConfigSyncer = (*Zone)(nil) _ WireguardConfigSyncer = (*Machine)(nil) _ WireguardKeysWriter = (*Cluster)(nil) _ WireguardKeysWriter = (*Zone)(nil) _ WireguardKeysWriter = (*Machine)(nil) ) // A WireguardConfigPruner deletes wgN.conf on all machines under // its scope with the specified ring disabled type WireguardConfigPruner interface { PruneWireguardConfig(ring rings.RingID) error } // PruneWireguardConfig removes wgN.conf files of machines with // the corresponding ring disabled on all zones func (m *Cluster) PruneWireguardConfig(ring rings.RingID) error { return pruneWireguardConfig(m, ring) } // PruneWireguardConfig removes wgN.conf files of machines with // the corresponding ring disabled. func (z *Zone) PruneWireguardConfig(ring rings.RingID) error { return pruneWireguardConfig(z, ring) } func pruneWireguardConfig(m MachineIterator, ring rings.RingID) error { var err error m.ForEachMachine(func(p *Machine) bool { err = p.PruneWireguardConfig(ring) if os.IsNotExist(err) { // ignore err = nil } return err != nil }) return err } // PruneWireguardConfig deletes the wgN.conf file if its // presence on the ring is disabled func (m *Machine) PruneWireguardConfig(ring rings.RingID) error { _, ok := m.getRingInfo(ring) if !ok { return m.RemoveWireguardConfig(ring) } return nil } // A WireguardConfigWriter rewrites all wgN.conf on all machines under // its scope attached to that ring type WireguardConfigWriter interface { WriteWireguardConfig(ring rings.RingID) error } // WriteWireguardConfig rewrites all wgN.conf on all machines // attached to that ring func (m *Cluster) WriteWireguardConfig(ring rings.RingID) error { switch ring { case rings.RingZeroID: return writeWireguardConfig(m, m, ring) case rings.RingOneID: var err error m.ForEachZone(func(z *Zone) bool { err = writeWireguardConfig(m, z, ring) return err != nil }) return err default: return ErrInvalidRing(ring) } } // WriteWireguardConfig rewrites all wgN.conf on all machines // on the Zone attached to that ring func (z *Zone) WriteWireguardConfig(ring rings.RingID) error { if ring == rings.RingZeroID || ring == rings.RingOneID { return writeWireguardConfig(z.zones, z.zones, ring) } return ErrInvalidRing(ring) } func writeWireguardConfig(z ZoneIterator, m MachineIterator, ring rings.RingID) error { r, err := NewRing(z, m, ring) if err != nil { return err } r.ForEachMachine(func(p *Machine) bool { err = p.writeWireguardRingConfig(r) return err != nil }) return err } // WriteWireguardConfig rewrites the wgN.conf file of this Machine // if enabled func (m *Machine) WriteWireguardConfig(ring rings.RingID) error { r, err := NewRing(m.zone.zones, m.zone, ring) if err != nil { return err } return m.writeWireguardRingConfig(r) } func (m *Machine) writeWireguardRingConfig(r *Ring) error { ring, err := AsWireguardInterfaceID(r.ID) if err != nil { return err } wg, err := r.ExportConfig(m) if err != nil { return nil } f, err := m.CreateTruncFile(ring.ConfFile()) if err != nil { return err } defer f.Close() _, err = wg.WriteTo(f) return err } // A WireguardConfigSyncer updates all wgN.conf on all machines under // its scope reflecting the state of the ring type WireguardConfigSyncer interface { SyncWireguardConfig(ring rings.RingID) error } // SyncWireguardConfig updates all wgN.conf files for the specified // ring func (m *Cluster) SyncWireguardConfig(ring rings.RingID) error { switch ring { case rings.RingZeroID: return syncWireguardConfig(m, m, ring) case rings.RingOneID: var err error m.ForEachZone(func(z *Zone) bool { err = syncWireguardConfig(m, z, ring) return err != nil }) return err default: return ErrInvalidRing(ring) } } // SyncWireguardConfig updates all wgN.conf files for the specified // ring func (z *Zone) SyncWireguardConfig(ring rings.RingID) error { switch ring { case rings.RingZeroID: return syncWireguardConfig(z.zones, z.zones, ring) case rings.RingOneID: return syncWireguardConfig(z.zones, z, ring) default: return ErrInvalidRing(ring) } } func syncWireguardConfig(z ZoneIterator, m MachineIterator, ring rings.RingID) error { r, err := NewRing(z, m, ring) if err != nil { return err } r.ForEachMachine(func(p *Machine) bool { if _, ok := p.getRingInfo(ring); ok { err = p.writeWireguardRingConfig(r) } else { err = p.RemoveWireguardConfig(ring) } return err != nil }) return err } // SyncWireguardConfig updates all wgN.conf files for the specified // ring func (m *Machine) SyncWireguardConfig(ring rings.RingID) error { return m.zone.SyncWireguardConfig(ring) } // A WireguardKeysWriter writes the Wireguard Keys for all machines // under its scope for the specified ring type WireguardKeysWriter interface { WriteWireguardKeys(ring rings.RingID) error } // WriteWireguardKeys rewrites all wgN.{key,pub} files func (m *Cluster) WriteWireguardKeys(ring rings.RingID) error { return writeWireguardKeys(m, ring) } // WriteWireguardKeys rewrites all wgN.{key,pub} files on this zone func (z *Zone) WriteWireguardKeys(ring rings.RingID) error { return writeWireguardKeys(z, ring) } func writeWireguardKeys(m MachineIterator, ring rings.RingID) error { var err error m.ForEachMachine(func(p *Machine) bool { err = p.WriteWireguardKeys(ring) if os.IsNotExist(err) { // ignore err = nil } return err != nil }) return err } // WriteWireguardKeys writes the wgN.key/wgN.pub files func (m *Machine) WriteWireguardKeys(ringID rings.RingID) error { var err error var key, pub string var ri *RingInfo ri, _ = m.getRingInfo(ringID) if ri != nil { key = ri.Keys.PrivateKey.String() pub = ri.Keys.PublicKey.String() } switch { case key == "": return fs.ErrNotExist case pub == "": pub = ri.Keys.PrivateKey.Public().String() } keyFile, pubFile, _ := ri.Ring.Files() err = m.WriteStringFile(key+"\n", keyFile) if err != nil { return err } err = m.WriteStringFile(pub+"\n", pubFile) if err != nil { return err } return nil }