Browse Source

cluster: decouple RingID from WireguardInterfaceID

Signed-off-by: Alejandro Mery <amery@jpi.io>
pull/51/head
Alejandro Mery 4 months ago
parent
commit
187149c129
  1. 2
      go.mod
  2. 4
      go.sum
  3. 15
      pkg/cluster/errors.go
  4. 6
      pkg/cluster/machine.go
  5. 93
      pkg/cluster/machine_rings.go
  6. 11
      pkg/cluster/machine_scan.go
  7. 90
      pkg/cluster/rings.go
  8. 6
      pkg/cluster/sync.go
  9. 80
      pkg/cluster/wireguard.go

2
go.mod

@ -4,7 +4,7 @@ go 1.19
require ( require (
asciigoat.org/ini v0.2.5 asciigoat.org/ini v0.2.5
darvaza.org/core v0.13.1 darvaza.org/core v0.13.3
darvaza.org/resolver v0.9.2 darvaza.org/resolver v0.9.2
darvaza.org/sidecar v0.4.0 darvaza.org/sidecar v0.4.0
darvaza.org/slog v0.5.7 darvaza.org/slog v0.5.7

4
go.sum

@ -4,8 +4,8 @@ asciigoat.org/ini v0.2.5 h1:4gRIp9rU+XQt8+HMqZO5R7GavMv9Yl2+N+je6djDIAE=
asciigoat.org/ini v0.2.5/go.mod h1:gmXzJ9XFqf1NLk5nQkj04USQ4tMtdRJHNQX6vp3DzjU= asciigoat.org/ini v0.2.5/go.mod h1:gmXzJ9XFqf1NLk5nQkj04USQ4tMtdRJHNQX6vp3DzjU=
darvaza.org/cache/x/simplelru v0.1.8 h1:rvFucut4wKYbsYc994yR3P0M08NqlsvZxr5G4QK82tw= darvaza.org/cache/x/simplelru v0.1.8 h1:rvFucut4wKYbsYc994yR3P0M08NqlsvZxr5G4QK82tw=
darvaza.org/cache/x/simplelru v0.1.8/go.mod h1:Mv1isOJTcXYK+aK0AvUe+/3KpRTXDsYga6rdTS/upNs= darvaza.org/cache/x/simplelru v0.1.8/go.mod h1:Mv1isOJTcXYK+aK0AvUe+/3KpRTXDsYga6rdTS/upNs=
darvaza.org/core v0.13.1 h1:ZoAfZ3OLnw+t28qMQQxXrDIkETmT2h5gAO6F1XuBpwg= darvaza.org/core v0.13.3 h1:DOsidY49WXsWiJulOIxDq578h/3ekgx0trWxbvgv5bc=
darvaza.org/core v0.13.1/go.mod h1:47Ydh67KnzjLNu1mzX3r2zpphbxQqEaihMsUq5GflQ4= darvaza.org/core v0.13.3/go.mod h1:47Ydh67KnzjLNu1mzX3r2zpphbxQqEaihMsUq5GflQ4=
darvaza.org/resolver v0.9.2 h1:sUX6LZ1eN5TzJW7L4m7HM+BvwBeWl8dYYDGVSe+AIhk= darvaza.org/resolver v0.9.2 h1:sUX6LZ1eN5TzJW7L4m7HM+BvwBeWl8dYYDGVSe+AIhk=
darvaza.org/resolver v0.9.2/go.mod h1:XWqPhrxoOKNzRuSozOwmE1M6QVqQL28jEdxylnIO8Nw= darvaza.org/resolver v0.9.2/go.mod h1:XWqPhrxoOKNzRuSozOwmE1M6QVqQL28jEdxylnIO8Nw=
darvaza.org/sidecar v0.4.0 h1:wHghxzLsiT82WDBBUf34aTqtOvRBg4UbxVIJgKNXRVA= darvaza.org/sidecar v0.4.0 h1:wHghxzLsiT82WDBBUf34aTqtOvRBg4UbxVIJgKNXRVA=

15
pkg/cluster/errors.go

@ -1,6 +1,13 @@
package cluster package cluster
import "errors" import (
"errors"
"io/fs"
"darvaza.org/core"
"git.jpi.io/amery/jpictl/pkg/rings"
)
var ( var (
// ErrInvalidName indicates the name isn't valid // ErrInvalidName indicates the name isn't valid
@ -14,3 +21,9 @@ var (
// the intended purpose // the intended purpose
ErrInvalidNode = errors.New("invalid node") ErrInvalidNode = errors.New("invalid node")
) )
// ErrInvalidRing returns an error indicating the [rings.RingID]
// can't be used for the intended purpose
func ErrInvalidRing(ringID rings.RingID) error {
return core.QuietWrap(fs.ErrInvalid, "invalid ring %v", ringID)
}

6
pkg/cluster/machine.go

@ -53,13 +53,13 @@ func (m *Machine) IsActive() bool {
// IsGateway tells if the Machine is a ring0 gateway // IsGateway tells if the Machine is a ring0 gateway
func (m *Machine) IsGateway() bool { func (m *Machine) IsGateway() bool {
_, ok := m.getRingInfo(0) _, ok := m.getRingInfo(rings.RingZeroID)
return ok return ok
} }
// SetGateway enables/disables a Machine ring0 integration // SetGateway enables/disables a Machine ring0 integration
func (m *Machine) SetGateway(enabled bool) error { func (m *Machine) SetGateway(enabled bool) error {
ri, found := m.getRingInfo(0) ri, found := m.getRingInfo(rings.RingZeroID)
switch { switch {
case !found && !enabled: case !found && !enabled:
return nil return nil
@ -72,7 +72,7 @@ func (m *Machine) SetGateway(enabled bool) error {
} }
ri.Enabled = enabled ri.Enabled = enabled
return m.SyncWireguardConfig(0) return m.SyncWireguardConfig(rings.RingZeroID)
} }
// Zone indicates the [Zone] this machine belongs to // Zone indicates the [Zone] this machine belongs to

93
pkg/cluster/machine_rings.go

@ -13,14 +13,21 @@ import (
) )
// GetWireguardKeys reads a wgN.key/wgN.pub files // GetWireguardKeys reads a wgN.key/wgN.pub files
func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) { func (m *Machine) GetWireguardKeys(ringID rings.RingID) (wireguard.KeyPair, error) {
var ( var (
data []byte data []byte
err error
out wireguard.KeyPair out wireguard.KeyPair
) )
data, err = m.ReadFile("wg%v.key", ring) ring, err := AsWireguardInterfaceID(ringID)
if err != nil {
// invalid ring
return out, err
}
keyFile, pubFile, _ := ring.Files()
data, err = m.ReadFile(keyFile)
if err != nil { if err != nil {
// failed to read // failed to read
return out, err return out, err
@ -29,11 +36,11 @@ func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) {
out.PrivateKey, err = wireguard.PrivateKeyFromBase64(string(data)) out.PrivateKey, err = wireguard.PrivateKeyFromBase64(string(data))
if err != nil { if err != nil {
// bad key // bad key
err = core.Wrap(err, "wg%v.key", ring) err = core.Wrap(err, keyFile)
return out, err return out, err
} }
data, err = m.ReadFile("wg%v.pub", ring) data, err = m.ReadFile(pubFile)
switch { switch {
case os.IsNotExist(err): case os.IsNotExist(err):
// no wgN.pub is fine // no wgN.pub is fine
@ -45,7 +52,7 @@ func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) {
out.PublicKey, err = wireguard.PublicKeyFromBase64(string(data)) out.PublicKey, err = wireguard.PublicKeyFromBase64(string(data))
if err != nil { if err != nil {
// bad key // bad key
err = core.Wrap(err, "wg%v.pub", ring) err = core.Wrap(err, pubFile)
return out, err return out, err
} }
} }
@ -54,8 +61,8 @@ func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) {
return out, err return out, err
} }
func (m *Machine) tryReadWireguardKeys(ring int) error { func (m *Machine) tryReadWireguardKeys(ringID rings.RingID) error {
kp, err := m.GetWireguardKeys(ring) kp, err := m.GetWireguardKeys(ringID)
switch { switch {
case os.IsNotExist(err): case os.IsNotExist(err):
// ignore // ignore
@ -66,20 +73,25 @@ func (m *Machine) tryReadWireguardKeys(ring int) error {
default: default:
// import keys // import keys
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: MustWireguardInterfaceID(ringID),
Keys: kp, Keys: kp,
} }
return m.applyRingInfo(ring, ri) return m.applyRingInfo(ringID, ri)
} }
} }
// RemoveWireguardKeys deletes wgN.key and wgN.pub from // RemoveWireguardKeys deletes wgN.key and wgN.pub from
// the machine's config directory // the machine's config directory
func (m *Machine) RemoveWireguardKeys(ring int) error { func (m *Machine) RemoveWireguardKeys(ringID rings.RingID) error {
var err error ring, err := AsWireguardInterfaceID(ringID)
if err != nil {
return err
}
err = m.RemoveFile("wg%v.pub", ring) keyFile, pubFile, _ := ring.Files()
err = m.RemoveFile(pubFile)
switch { switch {
case os.IsNotExist(err): case os.IsNotExist(err):
// ignore // ignore
@ -87,7 +99,7 @@ func (m *Machine) RemoveWireguardKeys(ring int) error {
return err return err
} }
err = m.RemoveFile("wg%v.key", ring) err = m.RemoveFile(keyFile)
if os.IsNotExist(err) { if os.IsNotExist(err) {
// ignore // ignore
err = nil err = nil
@ -97,8 +109,13 @@ func (m *Machine) RemoveWireguardKeys(ring int) error {
} }
// GetWireguardConfig reads a wgN.conf file // GetWireguardConfig reads a wgN.conf file
func (m *Machine) GetWireguardConfig(ring int) (*wireguard.Config, error) { func (m *Machine) GetWireguardConfig(ringID rings.RingID) (*wireguard.Config, error) {
data, err := m.ReadFile("wg%v.conf", ring) ring, err := AsWireguardInterfaceID(ringID)
if err != nil {
return nil, err
}
data, err := m.ReadFile(ring.ConfFile())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,7 +124,7 @@ func (m *Machine) GetWireguardConfig(ring int) (*wireguard.Config, error) {
return wireguard.NewConfigFromReader(r) return wireguard.NewConfigFromReader(r)
} }
func (m *Machine) tryApplyWireguardConfig(ring int) error { func (m *Machine) tryApplyWireguardConfig(ring rings.RingID) error {
wg, err := m.GetWireguardConfig(ring) wg, err := m.GetWireguardConfig(ring)
switch { switch {
case os.IsNotExist(err): case os.IsNotExist(err):
@ -119,7 +136,7 @@ func (m *Machine) tryApplyWireguardConfig(ring int) error {
} }
} }
func (m *Machine) applyWireguardConfigNode(ring int, wg *wireguard.Config) error { func (m *Machine) applyWireguardConfigNode(ring rings.RingID, wg *wireguard.Config) error {
addr := wg.GetAddress() addr := wg.GetAddress()
if !core.IsZero(addr) { if !core.IsZero(addr) {
zoneID, nodeID, ok := Rings[ring].Decode(addr) zoneID, nodeID, ok := Rings[ring].Decode(addr)
@ -139,7 +156,7 @@ func (m *Machine) applyWireguardConfigNode(ring int, wg *wireguard.Config) error
return nil return nil
} }
func (m *Machine) applyWireguardConfig(ring int, wg *wireguard.Config) error { func (m *Machine) applyWireguardConfig(ring rings.RingID, wg *wireguard.Config) error {
if err := m.applyWireguardConfigNode(ring, wg); err != nil { if err := m.applyWireguardConfigNode(ring, wg); err != nil {
return err return err
} }
@ -153,7 +170,7 @@ func (m *Machine) applyWireguardConfig(ring int, wg *wireguard.Config) error {
WithField("subsystem", "wireguard"). WithField("subsystem", "wireguard").
WithField("node", m.Name). WithField("node", m.Name).
WithField("peer", peer.Endpoint.Host). WithField("peer", peer.Endpoint.Host).
WithField("ring", ring). WithField("ring", MustWireguardInterfaceID(ring)).
Print("ignoring unknown endpoint") Print("ignoring unknown endpoint")
case err != nil: case err != nil:
return core.Wrap(err, "peer") return core.Wrap(err, "peer")
@ -163,9 +180,9 @@ func (m *Machine) applyWireguardConfig(ring int, wg *wireguard.Config) error {
return nil return nil
} }
func (m *Machine) getRingInfo(ring int) (*RingInfo, bool) { func (m *Machine) getRingInfo(ring rings.RingID) (*RingInfo, bool) {
for _, ri := range m.Rings { for _, ri := range m.Rings {
if ri.Ring == ring { if ri.RingID() == ring {
return ri, ri.Enabled return ri, ri.Enabled
} }
} }
@ -173,13 +190,13 @@ func (m *Machine) getRingInfo(ring int) (*RingInfo, bool) {
return nil, false return nil, false
} }
func (m *Machine) applyRingInfo(ring int, new *RingInfo) error { func (m *Machine) applyRingInfo(ring rings.RingID, new *RingInfo) error {
cur, _ := m.getRingInfo(ring) cur, _ := m.getRingInfo(ring)
if cur == nil { if cur == nil {
// first, append // first, append
m.debug(). m.debug().
WithField("node", m.Name). WithField("node", m.Name).
WithField("ring", ring). WithField("ring", MustWireguardInterfaceID(ring)).
Print("found") Print("found")
m.Rings = append(m.Rings, new) m.Rings = append(m.Rings, new)
return nil return nil
@ -189,9 +206,11 @@ func (m *Machine) applyRingInfo(ring int, new *RingInfo) error {
return cur.Merge(new) return cur.Merge(new)
} }
func (m *Machine) applyWireguardInterfaceConfig(ring int, data wireguard.InterfaceConfig) error { func (m *Machine) applyWireguardInterfaceConfig(ring rings.RingID,
data wireguard.InterfaceConfig) error {
//
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: MustWireguardInterfaceID(ring),
Enabled: true, Enabled: true,
Keys: wireguard.KeyPair{ Keys: wireguard.KeyPair{
PrivateKey: data.PrivateKey, PrivateKey: data.PrivateKey,
@ -201,7 +220,7 @@ func (m *Machine) applyWireguardInterfaceConfig(ring int, data wireguard.Interfa
return m.applyRingInfo(ring, ri) return m.applyRingInfo(ring, ri)
} }
func (m *Machine) applyWireguardPeerConfig(ring int, pc wireguard.PeerConfig) error { func (m *Machine) applyWireguardPeerConfig(ring rings.RingID, pc wireguard.PeerConfig) error {
peer, found := m.getPeerByName(pc.Endpoint.Name()) peer, found := m.getPeerByName(pc.Endpoint.Name())
switch { switch {
case !found: case !found:
@ -213,7 +232,7 @@ func (m *Machine) applyWireguardPeerConfig(ring int, pc wireguard.PeerConfig) er
default: default:
// apply RingInfo // apply RingInfo
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: MustWireguardInterfaceID(ring),
Enabled: true, Enabled: true,
Keys: wireguard.KeyPair{ Keys: wireguard.KeyPair{
PublicKey: pc.PublicKey, PublicKey: pc.PublicKey,
@ -260,8 +279,13 @@ func (m *Machine) setRingDefaults(ri *RingInfo) error {
// RemoveWireguardConfig deletes wgN.conf from the machine's // RemoveWireguardConfig deletes wgN.conf from the machine's
// config directory. // config directory.
func (m *Machine) RemoveWireguardConfig(ring int) error { func (m *Machine) RemoveWireguardConfig(ringID rings.RingID) error {
err := m.RemoveFile("wg%v.conf", ring) ring, err := AsWireguardInterfaceID(ringID)
if err != nil {
return err
}
err = m.RemoveFile(ring.ConfFile())
if os.IsNotExist(err) { if os.IsNotExist(err) {
err = nil err = nil
} }
@ -269,7 +293,12 @@ func (m *Machine) RemoveWireguardConfig(ring int) error {
return err return err
} }
func (m *Machine) createRingInfo(ring int, enabled bool) (*RingInfo, error) { func (m *Machine) createRingInfo(ringID rings.RingID, enabled bool) (*RingInfo, error) {
ring, err := AsWireguardInterfaceID(ringID)
if err != nil {
return nil, err
}
keys, err := wireguard.NewKeyPair() keys, err := wireguard.NewKeyPair()
if err != nil { if err != nil {
return nil, err return nil, err
@ -281,7 +310,7 @@ func (m *Machine) createRingInfo(ring int, enabled bool) (*RingInfo, error) {
Keys: keys, Keys: keys,
} }
err = m.applyRingInfo(ring, ri) err = m.applyRingInfo(ringID, ri)
if err != nil { if err != nil {
return nil, err return nil, err
} }

11
pkg/cluster/machine_scan.go

@ -9,6 +9,7 @@ import (
"time" "time"
"darvaza.org/core" "darvaza.org/core"
"git.jpi.io/amery/jpictl/pkg/rings" "git.jpi.io/amery/jpictl/pkg/rings"
) )
@ -38,8 +39,8 @@ func (m *Machine) init() error {
return core.Wrap(err, m.Name) return core.Wrap(err, m.Name)
} }
for i := 0; i < RingsCount; i++ { for _, ring := range Rings {
if err := m.tryReadWireguardKeys(i); err != nil { if err := m.tryReadWireguardKeys(ring.ID); err != nil {
return core.Wrap(err, m.Name) return core.Wrap(err, m.Name)
} }
} }
@ -72,12 +73,12 @@ func (m *Machine) setID() error {
// scan is called once we know about all zones and machine names // scan is called once we know about all zones and machine names
func (m *Machine) scan(_ *ScanOptions) error { func (m *Machine) scan(_ *ScanOptions) error {
for i := 0; i < RingsCount; i++ { for _, ring := range Rings {
if err := m.tryApplyWireguardConfig(i); err != nil { if err := m.tryApplyWireguardConfig(ring.ID); err != nil {
m.error(err). m.error(err).
WithField("subsystem", "wireguard"). WithField("subsystem", "wireguard").
WithField("node", m.Name). WithField("node", m.Name).
WithField("ring", i). WithField("ring", MustWireguardInterfaceID(ring.ID)).
Print() Print()
return err return err
} }

90
pkg/cluster/rings.go

@ -4,28 +4,86 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net/netip" "net/netip"
"strconv"
"git.jpi.io/amery/jpictl/pkg/rings" "git.jpi.io/amery/jpictl/pkg/rings"
"git.jpi.io/amery/jpictl/pkg/wireguard" "git.jpi.io/amery/jpictl/pkg/wireguard"
) )
const ( const (
// RingsCount indicates how many wireguard rings we have
RingsCount = 2
// RingZeroPort is the port wireguard uses for ring0 // RingZeroPort is the port wireguard uses for ring0
RingZeroPort = 51800 RingZeroPort = 51800
// RingOnePort is the port wireguard uses for ring1 // RingOnePort is the port wireguard uses for ring1
RingOnePort = 51810 RingOnePort = 51810
) )
// WireguardInterfaceID represents the number in the `wg%v`
// interface name.
type WireguardInterfaceID uint
// AsWireguardInterfaceID returns the [WireguardInterfaceID] for
// a valid [rings.RingID].
func AsWireguardInterfaceID(ring rings.RingID) (WireguardInterfaceID, error) {
switch ring {
case rings.RingZeroID:
return 0, nil
case rings.RingOneID:
return 1, nil
default:
return 0, ErrInvalidRing(ring)
}
}
// MustWireguardInterfaceID returns the [WireguardInterfaceID] for
// a valid [rings.RingID], and panics if it's not.
func MustWireguardInterfaceID(ring rings.RingID) WireguardInterfaceID {
id, err := AsWireguardInterfaceID(ring)
if err != nil {
panic(err)
}
return id
}
// RingID tells the [rings.RingID] of the [WireguardInterfaceID].
func (wi WireguardInterfaceID) RingID() rings.RingID {
return rings.RingID(wi + 1)
}
// PubFile returns "wgN.pub"
func (wi WireguardInterfaceID) PubFile() string {
return fmt.Sprintf("wg%v.pub", wi)
}
// KeyFile returns "wgN.key"
func (wi WireguardInterfaceID) KeyFile() string {
return fmt.Sprintf("wg%v.key", wi)
}
// ConfFile returns "wgN.conf"
func (wi WireguardInterfaceID) ConfFile() string {
return fmt.Sprintf("wg%v.conf", wi)
}
// Files returns all wgN.ext file names.
func (wi WireguardInterfaceID) Files() (keyFile, pubFile, confFile string) {
prefix := "wg" + strconv.Itoa(int(wi))
return prefix + ".key", prefix + ".pub", prefix + ".conf"
}
// RingInfo contains represents the Wireguard endpoint details // RingInfo contains represents the Wireguard endpoint details
// for a Machine on a particular ring // for a Machine on a particular ring
type RingInfo struct { type RingInfo struct {
Ring int Ring WireguardInterfaceID
Enabled bool Enabled bool
Keys wireguard.KeyPair Keys wireguard.KeyPair
} }
// RingID returns the [rings.RingID] for this [RingInfo].
func (ri *RingInfo) RingID() rings.RingID {
return rings.RingID(ri.Ring + 1)
}
// Merge attempts to combine two RingInfo structs // Merge attempts to combine two RingInfo structs
func (ri *RingInfo) Merge(alter *RingInfo) error { func (ri *RingInfo) Merge(alter *RingInfo) error {
switch { switch {
@ -76,7 +134,7 @@ func canMergeKeyPairs(p1, p2 wireguard.KeyPair) bool {
// RingAddressEncoder provides encoder/decoder access for a particular // RingAddressEncoder provides encoder/decoder access for a particular
// Wireguard ring // Wireguard ring
type RingAddressEncoder struct { type RingAddressEncoder struct {
ID int ID rings.RingID
Port uint16 Port uint16
Encode func(zoneID rings.ZoneID, nodeID rings.NodeID) (netip.Addr, bool) Encode func(zoneID rings.ZoneID, nodeID rings.NodeID) (netip.Addr, bool)
Decode func(addr netip.Addr) (zoneID rings.ZoneID, nodeID rings.NodeID, ok bool) Decode func(addr netip.Addr) (zoneID rings.ZoneID, nodeID rings.NodeID, ok bool)
@ -85,20 +143,20 @@ type RingAddressEncoder struct {
var ( var (
// RingZero is a wg0 address encoder/decoder // RingZero is a wg0 address encoder/decoder
RingZero = RingAddressEncoder{ RingZero = RingAddressEncoder{
ID: 0, ID: rings.RingZeroID,
Port: RingZeroPort, Port: RingZeroPort,
Decode: ParseRingZeroAddress, Decode: ParseRingZeroAddress,
Encode: RingZeroAddress, Encode: RingZeroAddress,
} }
// RingOne is a wg1 address encoder/decoder // RingOne is a wg1 address encoder/decoder
RingOne = RingAddressEncoder{ RingOne = RingAddressEncoder{
ID: 1, ID: rings.RingOneID,
Port: RingOnePort, Port: RingOnePort,
Decode: ParseRingOneAddress, Decode: ParseRingOneAddress,
Encode: RingOneAddress, Encode: RingOneAddress,
} }
// Rings provides indexed access to the ring address encoders // Rings provides indexed access to the ring address encoders
Rings = [RingsCount]RingAddressEncoder{ Rings = []RingAddressEncoder{
RingZero, RingZero,
RingOne, RingOne,
} }
@ -199,7 +257,7 @@ func (r *Ring) AddPeer(p *Machine) bool {
Address: addr, Address: addr,
PrivateKey: ri.Keys.PrivateKey, PrivateKey: ri.Keys.PrivateKey,
PeerConfig: wireguard.PeerConfig{ PeerConfig: wireguard.PeerConfig{
Name: fmt.Sprintf("%s-%v", p.Name, r.ID), Name: fmt.Sprintf("%s-%v", p.Name, ri.Ring),
PublicKey: ri.Keys.PublicKey, PublicKey: ri.Keys.PublicKey,
Endpoint: wireguard.EndpointAddress{ Endpoint: wireguard.EndpointAddress{
Host: p.FullName(), Host: p.FullName(),
@ -323,11 +381,21 @@ func (rp *RingPeer) AllowCIDR(addr netip.Addr, bits int) {
} }
// NewRing composes a new Ring for Wireguard setup // NewRing composes a new Ring for Wireguard setup
func NewRing(z ZoneIterator, m MachineIterator, ring int) (*Ring, error) { func NewRing(z ZoneIterator, m MachineIterator, ringID rings.RingID) (*Ring, error) {
r := &Ring{ var r *Ring
RingAddressEncoder: Rings[ring], for _, ring := range Rings {
if ringID == ring.ID {
r = &Ring{
RingAddressEncoder: ring,
ZoneIterator: z, ZoneIterator: z,
} }
break
}
}
if r == nil {
return nil, ErrInvalidRing(ringID)
}
m.ForEachMachine(func(p *Machine) bool { m.ForEachMachine(func(p *Machine) bool {
r.AddPeer(p) r.AddPeer(p)

6
pkg/cluster/sync.go

@ -35,13 +35,13 @@ func (m *Cluster) SyncMkdirAll() error {
func (m *Cluster) SyncAllWireguard() error { func (m *Cluster) SyncAllWireguard() error {
var err error var err error
for ring := 0; ring < RingsCount; ring++ { for _, ring := range Rings {
err = m.WriteWireguardKeys(ring) err = m.WriteWireguardKeys(ring.ID)
if err != nil { if err != nil {
return err return err
} }
err = m.SyncWireguardConfig(ring) err = m.SyncWireguardConfig(ring.ID)
if err != nil { if err != nil {
return err return err
} }

80
pkg/cluster/wireguard.go

@ -3,6 +3,8 @@ package cluster
import ( import (
"io/fs" "io/fs"
"os" "os"
"git.jpi.io/amery/jpictl/pkg/rings"
) )
var ( var (
@ -26,22 +28,22 @@ var (
// A WireguardConfigPruner deletes wgN.conf on all machines under // A WireguardConfigPruner deletes wgN.conf on all machines under
// its scope with the specified ring disabled // its scope with the specified ring disabled
type WireguardConfigPruner interface { type WireguardConfigPruner interface {
PruneWireguardConfig(ring int) error PruneWireguardConfig(ring rings.RingID) error
} }
// PruneWireguardConfig removes wgN.conf files of machines with // PruneWireguardConfig removes wgN.conf files of machines with
// the corresponding ring disabled on all zones // the corresponding ring disabled on all zones
func (m *Cluster) PruneWireguardConfig(ring int) error { func (m *Cluster) PruneWireguardConfig(ring rings.RingID) error {
return pruneWireguardConfig(m, ring) return pruneWireguardConfig(m, ring)
} }
// PruneWireguardConfig removes wgN.conf files of machines with // PruneWireguardConfig removes wgN.conf files of machines with
// the corresponding ring disabled. // the corresponding ring disabled.
func (z *Zone) PruneWireguardConfig(ring int) error { func (z *Zone) PruneWireguardConfig(ring rings.RingID) error {
return pruneWireguardConfig(z, ring) return pruneWireguardConfig(z, ring)
} }
func pruneWireguardConfig(m MachineIterator, ring int) error { func pruneWireguardConfig(m MachineIterator, ring rings.RingID) error {
var err error var err error
m.ForEachMachine(func(p *Machine) bool { m.ForEachMachine(func(p *Machine) bool {
@ -59,7 +61,7 @@ func pruneWireguardConfig(m MachineIterator, ring int) error {
// PruneWireguardConfig deletes the wgN.conf file if its // PruneWireguardConfig deletes the wgN.conf file if its
// presence on the ring is disabled // presence on the ring is disabled
func (m *Machine) PruneWireguardConfig(ring int) error { func (m *Machine) PruneWireguardConfig(ring rings.RingID) error {
_, ok := m.getRingInfo(ring) _, ok := m.getRingInfo(ring)
if !ok { if !ok {
return m.RemoveWireguardConfig(ring) return m.RemoveWireguardConfig(ring)
@ -71,16 +73,16 @@ func (m *Machine) PruneWireguardConfig(ring int) error {
// A WireguardConfigWriter rewrites all wgN.conf on all machines under // A WireguardConfigWriter rewrites all wgN.conf on all machines under
// its scope attached to that ring // its scope attached to that ring
type WireguardConfigWriter interface { type WireguardConfigWriter interface {
WriteWireguardConfig(ring int) error WriteWireguardConfig(ring rings.RingID) error
} }
// WriteWireguardConfig rewrites all wgN.conf on all machines // WriteWireguardConfig rewrites all wgN.conf on all machines
// attached to that ring // attached to that ring
func (m *Cluster) WriteWireguardConfig(ring int) error { func (m *Cluster) WriteWireguardConfig(ring rings.RingID) error {
switch ring { switch ring {
case 0: case rings.RingZeroID:
return writeWireguardConfig(m, m, ring) return writeWireguardConfig(m, m, ring)
case 1: case rings.RingOneID:
var err error var err error
m.ForEachZone(func(z *Zone) bool { m.ForEachZone(func(z *Zone) bool {
err = writeWireguardConfig(m, z, ring) err = writeWireguardConfig(m, z, ring)
@ -88,24 +90,24 @@ func (m *Cluster) WriteWireguardConfig(ring int) error {
}) })
return err return err
default: default:
return fs.ErrInvalid return ErrInvalidRing(ring)
} }
} }
// WriteWireguardConfig rewrites all wgN.conf on all machines // WriteWireguardConfig rewrites all wgN.conf on all machines
// on the Zone attached to that ring // on the Zone attached to that ring
func (z *Zone) WriteWireguardConfig(ring int) error { func (z *Zone) WriteWireguardConfig(ring rings.RingID) error {
switch ring { switch ring {
case 0: case rings.RingZeroID:
return writeWireguardConfig(z.zones, z.zones, ring) return writeWireguardConfig(z.zones, z.zones, ring)
case 1: case rings.RingOneID:
return writeWireguardConfig(z.zones, z, ring) return writeWireguardConfig(z.zones, z, ring)
default: default:
return fs.ErrInvalid return ErrInvalidRing(ring)
} }
} }
func writeWireguardConfig(z ZoneIterator, m MachineIterator, ring int) error { func writeWireguardConfig(z ZoneIterator, m MachineIterator, ring rings.RingID) error {
r, err := NewRing(z, m, ring) r, err := NewRing(z, m, ring)
if err != nil { if err != nil {
return err return err
@ -121,7 +123,7 @@ func writeWireguardConfig(z ZoneIterator, m MachineIterator, ring int) error {
// WriteWireguardConfig rewrites the wgN.conf file of this Machine // WriteWireguardConfig rewrites the wgN.conf file of this Machine
// if enabled // if enabled
func (m *Machine) WriteWireguardConfig(ring int) error { func (m *Machine) WriteWireguardConfig(ring rings.RingID) error {
r, err := NewRing(m.zone.zones, m.zone, ring) r, err := NewRing(m.zone.zones, m.zone, ring)
if err != nil { if err != nil {
return err return err
@ -131,12 +133,17 @@ func (m *Machine) WriteWireguardConfig(ring int) error {
} }
func (m *Machine) writeWireguardRingConfig(r *Ring) error { func (m *Machine) writeWireguardRingConfig(r *Ring) error {
ring, err := AsWireguardInterfaceID(r.ID)
if err != nil {
return err
}
wg, err := r.ExportConfig(m) wg, err := r.ExportConfig(m)
if err != nil { if err != nil {
return nil return nil
} }
f, err := m.CreateTruncFile("wg%v.conf", r.ID) f, err := m.CreateTruncFile(ring.ConfFile())
if err != nil { if err != nil {
return err return err
} }
@ -149,16 +156,16 @@ func (m *Machine) writeWireguardRingConfig(r *Ring) error {
// A WireguardConfigSyncer updates all wgN.conf on all machines under // A WireguardConfigSyncer updates all wgN.conf on all machines under
// its scope reflecting the state of the ring // its scope reflecting the state of the ring
type WireguardConfigSyncer interface { type WireguardConfigSyncer interface {
SyncWireguardConfig(ring int) error SyncWireguardConfig(ring rings.RingID) error
} }
// SyncWireguardConfig updates all wgN.conf files for the specified // SyncWireguardConfig updates all wgN.conf files for the specified
// ring // ring
func (m *Cluster) SyncWireguardConfig(ring int) error { func (m *Cluster) SyncWireguardConfig(ring rings.RingID) error {
switch ring { switch ring {
case 0: case rings.RingZeroID:
return syncWireguardConfig(m, m, ring) return syncWireguardConfig(m, m, ring)
case 1: case rings.RingOneID:
var err error var err error
m.ForEachZone(func(z *Zone) bool { m.ForEachZone(func(z *Zone) bool {
err = syncWireguardConfig(m, z, ring) err = syncWireguardConfig(m, z, ring)
@ -166,24 +173,24 @@ func (m *Cluster) SyncWireguardConfig(ring int) error {
}) })
return err return err
default: default:
return fs.ErrInvalid return ErrInvalidRing(ring)
} }
} }
// SyncWireguardConfig updates all wgN.conf files for the specified // SyncWireguardConfig updates all wgN.conf files for the specified
// ring // ring
func (z *Zone) SyncWireguardConfig(ring int) error { func (z *Zone) SyncWireguardConfig(ring rings.RingID) error {
switch ring { switch ring {
case 0: case rings.RingZeroID:
return syncWireguardConfig(z.zones, z.zones, ring) return syncWireguardConfig(z.zones, z.zones, ring)
case 1: case rings.RingOneID:
return syncWireguardConfig(z.zones, z, ring) return syncWireguardConfig(z.zones, z, ring)
default: default:
return fs.ErrInvalid return ErrInvalidRing(ring)
} }
} }
func syncWireguardConfig(z ZoneIterator, m MachineIterator, ring int) error { func syncWireguardConfig(z ZoneIterator, m MachineIterator, ring rings.RingID) error {
r, err := NewRing(z, m, ring) r, err := NewRing(z, m, ring)
if err != nil { if err != nil {
return err return err
@ -203,27 +210,27 @@ func syncWireguardConfig(z ZoneIterator, m MachineIterator, ring int) error {
// SyncWireguardConfig updates all wgN.conf files for the specified // SyncWireguardConfig updates all wgN.conf files for the specified
// ring // ring
func (m *Machine) SyncWireguardConfig(ring int) error { func (m *Machine) SyncWireguardConfig(ring rings.RingID) error {
return m.zone.SyncWireguardConfig(ring) return m.zone.SyncWireguardConfig(ring)
} }
// A WireguardKeysWriter writes the Wireguard Keys for all machines // A WireguardKeysWriter writes the Wireguard Keys for all machines
// under its scope for the specified ring // under its scope for the specified ring
type WireguardKeysWriter interface { type WireguardKeysWriter interface {
WriteWireguardKeys(ring int) error WriteWireguardKeys(ring rings.RingID) error
} }
// WriteWireguardKeys rewrites all wgN.{key,pub} files // WriteWireguardKeys rewrites all wgN.{key,pub} files
func (m *Cluster) WriteWireguardKeys(ring int) error { func (m *Cluster) WriteWireguardKeys(ring rings.RingID) error {
return writeWireguardKeys(m, ring) return writeWireguardKeys(m, ring)
} }
// WriteWireguardKeys rewrites all wgN.{key,pub} files on this zone // WriteWireguardKeys rewrites all wgN.{key,pub} files on this zone
func (z *Zone) WriteWireguardKeys(ring int) error { func (z *Zone) WriteWireguardKeys(ring rings.RingID) error {
return writeWireguardKeys(z, ring) return writeWireguardKeys(z, ring)
} }
func writeWireguardKeys(m MachineIterator, ring int) error { func writeWireguardKeys(m MachineIterator, ring rings.RingID) error {
var err error var err error
m.ForEachMachine(func(p *Machine) bool { m.ForEachMachine(func(p *Machine) bool {
@ -240,12 +247,12 @@ func writeWireguardKeys(m MachineIterator, ring int) error {
} }
// WriteWireguardKeys writes the wgN.key/wgN.pub files // WriteWireguardKeys writes the wgN.key/wgN.pub files
func (m *Machine) WriteWireguardKeys(ring int) error { func (m *Machine) WriteWireguardKeys(ringID rings.RingID) error {
var err error var err error
var key, pub string var key, pub string
var ri *RingInfo var ri *RingInfo
ri, _ = m.getRingInfo(ring) ri, _ = m.getRingInfo(ringID)
if ri != nil { if ri != nil {
key = ri.Keys.PrivateKey.String() key = ri.Keys.PrivateKey.String()
pub = ri.Keys.PublicKey.String() pub = ri.Keys.PublicKey.String()
@ -258,12 +265,13 @@ func (m *Machine) WriteWireguardKeys(ring int) error {
pub = ri.Keys.PrivateKey.Public().String() pub = ri.Keys.PrivateKey.Public().String()
} }
err = m.WriteStringFile(key+"\n", "wg%v.key", ring) keyFile, pubFile, _ := ri.Ring.Files()
err = m.WriteStringFile(key+"\n", keyFile)
if err != nil { if err != nil {
return err return err
} }
err = m.WriteStringFile(pub+"\n", "wg%v.pub", ring) err = m.WriteStringFile(pub+"\n", pubFile)
if err != nil { if err != nil {
return err return err
} }

Loading…
Cancel
Save