Compare commits

...

23 Commits

Author SHA1 Message Date
amery c81b782b26 zones: Machine.IsGateway()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 21:07:55 +00:00
amery 0f62ee2e53 zones: rename Machine.RingAddresses to Machine.Rings
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 20:52:08 +00:00
amery 30a7bceda3 wireguard: make KeyPairs solid
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 19:49:59 +00:00
amery 60e2687d04 wireguard: make keys arrays instead of slices
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 19:35:11 +00:00
amery 1419e55d5b zones: remove useless RingInfo.Address
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 19:06:33 +00:00
amery ffdacb833b zones: add Port information to RingAddressEncoder
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 17:56:19 +00:00
amery aca0a5e834 zones: calculate Machine.ID on init
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 17:55:43 +00:00
amery 61374d4cc5 zones: load wireguard key pairs on Machine.init()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 14:35:40 +00:00
amery 975e166da7 zones: allow RingInfo.Merge() to enable, but not disable
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 14:34:00 +00:00
amery b16c648f2c zones: introduce Machine.GetWireguardKeys()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-23 21:03:03 +00:00
amery 47d79f7576 wireguard: introduce KeyPair.Validate()
it will also set the PublicKey field is empty

Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-23 00:29:15 +00:00
amery e2f831fd6a wireguard: introduce NewKeyPair, NewPrivateKey, and PrivateKey.Public()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-23 00:18:31 +00:00
amery 1d8c818ec4 wireguard: make PrivateKey and PublicKey two distinct types
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 23:46:24 +00:00
amery 2f51a463b2 zones: reduce writeEnvZone() complexity
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 22:53:13 +00:00
amery 0c0cba6fb5 jpictl: introduce env command
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 22:27:11 +00:00
amery 75206e4fa5 zones: Zones.WriteEnv() writing env variables describing the cluster
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 22:26:48 +00:00
amery b084e103b9 zones: introduce Machine.getRingInfo()
and refactor Machine.applyRingInfo()

Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 22:25:11 +00:00
amery 223edf846b zones: introduce Zone.ForEachMachine()
and refactor Zones.ForEachMachine() using it

Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 22:25:09 +00:00
amery fdb0f0324f zones: finish scan sorting the content
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 21:29:12 +00:00
amery 9aef92f32d zones: assign zoneID to zones inferable ID
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 21:02:21 +00:00
amery e5baf53758 zones: import wireguard keys from wgN.conf files
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 20:37:10 +00:00
amery 0fe451eed0 zones: introduce RingInfo.Merge()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 20:32:29 +00:00
amery cb5ea80e66 zones: introduce Zones.GetMachineByName()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-22 20:32:26 +00:00
12 changed files with 648 additions and 61 deletions
+27
View File
@@ -0,0 +1,27 @@
package main
import (
"os"
"github.com/spf13/cobra"
"git.jpi.io/amery/jpictl/pkg/zones"
)
// Command
var envCmd = &cobra.Command{
Use: "env",
Short: "generates environment variables for shell scripts",
RunE: func(_ *cobra.Command, _ []string) error {
m, err := zones.New(cfg.Base, cfg.Domain)
if err != nil {
return err
}
return m.WriteEnv(os.Stdout)
},
}
func init() {
rootCmd.AddCommand(envCmd)
}
+1
View File
@@ -10,6 +10,7 @@ require (
github.com/burntSushi/toml v0.3.1 github.com/burntSushi/toml v0.3.1
github.com/mgechev/revive v1.3.2 github.com/mgechev/revive v1.3.2
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
golang.org/x/crypto v0.12.0
gopkg.in/gcfg.v1 v1.2.3 gopkg.in/gcfg.v1 v1.2.3
) )
+2
View File
@@ -70,6 +70,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+4 -4
View File
@@ -31,13 +31,13 @@ func (f *Config) Peers() int {
// InterfaceConfig represents the [Interface] section // InterfaceConfig represents the [Interface] section
type InterfaceConfig struct { type InterfaceConfig struct {
Address netip.Addr Address netip.Addr
PrivateKey BinaryKey PrivateKey PrivateKey
ListenPort uint16 ListenPort uint16
} }
// PeerConfig represents a [Peer] section // PeerConfig represents a [Peer] section
type PeerConfig struct { type PeerConfig struct {
PublicKey BinaryKey PublicKey PublicKey
Endpoint EndpointAddress Endpoint EndpointAddress
AllowedIPs []netip.Prefix AllowedIPs []netip.Prefix
} }
@@ -135,7 +135,7 @@ func (p interfaceConfig) Export() (InterfaceConfig, error) {
ListenPort: p.ListenPort, ListenPort: p.ListenPort,
} }
out.PrivateKey, err = BinaryKeyFromBase64(p.PrivateKey) out.PrivateKey, err = PrivateKeyFromBase64(p.PrivateKey)
if err != nil { if err != nil {
err = core.Wrap(err, "PrivateKey") err = core.Wrap(err, "PrivateKey")
return InterfaceConfig{}, err return InterfaceConfig{}, err
@@ -162,7 +162,7 @@ func (v *intermediateConfig) ExportPeer(i int) (PeerConfig, error) {
} }
// PublicKey // PublicKey
out.PublicKey, err = BinaryKeyFromBase64(v.Peer.PublicKey[i]) out.PublicKey, err = PublicKeyFromBase64(v.Peer.PublicKey[i])
if err != nil { if err != nil {
err = core.Wrap(err, "PublicKey") err = core.Wrap(err, "PublicKey")
return out, err return out, err
+160 -15
View File
@@ -2,35 +2,180 @@ package wireguard
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"golang.org/x/crypto/curve25519"
) )
// BinaryKey is a binary blob const (
type BinaryKey []byte // PrivateKeySize is the length in bytes of a Wireguard Private Key
PrivateKeySize = 32
// PublicKeySize is the length in bytes of a Wireguard Public Key
PublicKeySize = 32
)
func (k BinaryKey) String() string { var (
return base64.StdEncoding.EncodeToString(k) // ErrInvalidKeySize indicates the key size is wrong
ErrInvalidKeySize = errors.New("invalid key size")
// ErrInvalidPrivateKey indicates the private key is invalid
ErrInvalidPrivateKey = errors.New("invalid private key")
// ErrInvalidPublicKey indicates the public key is invalid
ErrInvalidPublicKey = errors.New("invalid public key")
)
type (
// PrivateKey is a binary Wireguard Private Key
PrivateKey [PrivateKeySize]byte
// PublicKey is a binary Wireguard Public Key
PublicKey [PublicKeySize]byte
)
func (key PrivateKey) String() string {
switch {
case key.IsZero():
return ""
default:
return base64.StdEncoding.EncodeToString(key[:])
}
}
func (pub PublicKey) String() string {
switch {
case pub.IsZero():
return ""
default:
return base64.StdEncoding.EncodeToString(pub[:])
}
} }
// IsZero tells if the key hasn't been set // IsZero tells if the key hasn't been set
func (k BinaryKey) IsZero() bool { func (key PrivateKey) IsZero() bool {
return len(k) == 0 var zero PrivateKey
return key.Equal(zero)
} }
// Equal checks if two keys are identical // IsZero tells if the key hasn't been set
func (k BinaryKey) Equal(alter BinaryKey) bool { func (pub PublicKey) IsZero() bool {
return bytes.Equal(k, alter) var zero PublicKey
return pub.Equal(zero)
} }
// BinaryKeyFromBase64 decodes a base64-based string into // Equal checks if two private keys are identical
// a [BinaryKey] func (key PrivateKey) Equal(alter PrivateKey) bool {
func BinaryKeyFromBase64(data string) (BinaryKey, error) { return bytes.Equal(key[:], alter[:])
}
// Equal checks if two public keys are identical
func (pub PublicKey) Equal(alter PublicKey) bool {
return bytes.Equal(pub[:], alter[:])
}
// PrivateKeyFromBase64 decodes a base64-based string into
// a [PrivateKey]
func PrivateKeyFromBase64(data string) (PrivateKey, error) {
b, err := decodeKey(data, PrivateKeySize)
if err != nil {
var zero PrivateKey
return zero, err
}
return *(*[PrivateKeySize]byte)(b), nil
}
// PublicKeyFromBase64 decodes a base64-based string into
// a [PublicKey]
func PublicKeyFromBase64(data string) (PublicKey, error) {
b, err := decodeKey(data, PublicKeySize)
if err != nil {
var zero PublicKey
return zero, err
}
return *(*[PublicKeySize]byte)(b), nil
}
func decodeKey(data string, size int) ([]byte, error) {
b, err := base64.StdEncoding.DecodeString(data) b, err := base64.StdEncoding.DecodeString(data)
return BinaryKey(b), err switch {
case err != nil:
return []byte{}, err
case len(b) != size:
err = ErrInvalidKeySize
return []byte{}, err
default:
return b, nil
}
}
// NewPrivateKey creates a new PrivateKey
func NewPrivateKey() (PrivateKey, error) {
var s [PrivateKeySize]byte
_, err := rand.Read(s[:])
if err != nil {
var zero PrivateKey
return zero, err
}
// apply same clamping as wireguard-go/device/noise-helpers.go
s[0] &= 0xf8
s[31] = (s[31] & 0x7f) | 0x40
return s, nil
}
// Public generates the corresponding PublicKey
func (key PrivateKey) Public() PublicKey {
var pub PublicKey
if !key.IsZero() {
in := (*[PrivateKeySize]byte)(&key)
out := (*[PublicKeySize]byte)(&pub)
curve25519.ScalarBaseMult(out, in)
}
return pub
} }
// KeyPair holds a Key pair // KeyPair holds a Key pair
type KeyPair struct { type KeyPair struct {
PrivateKey BinaryKey PrivateKey PrivateKey
PublicKey BinaryKey PublicKey PublicKey
}
// Validate checks the PublicKey matches the PrivateKey,
// and sets the PublicKey if missing
func (kp *KeyPair) Validate() error {
keyLen := len(kp.PrivateKey)
pubLen := len(kp.PublicKey)
switch {
case keyLen != PrivateKeySize:
// bad private key
return ErrInvalidPrivateKey
case pubLen == 0:
// no public key, set it
kp.PublicKey = kp.PrivateKey.Public()
return nil
case pubLen != PublicKeySize:
// bad public key
return ErrInvalidPublicKey
case !kp.PrivateKey.Public().Equal(kp.PublicKey):
// wrong public key
return ErrInvalidPublicKey
default:
// correct public key
return nil
}
}
// NewKeyPair creates a new KeyPair for Wireguard
func NewKeyPair() (KeyPair, error) {
var out KeyPair
key, err := NewPrivateKey()
if err == nil {
out.PrivateKey = key
out.PublicKey = key.Public()
}
return out, nil
} }
+102
View File
@@ -0,0 +1,102 @@
package zones
import (
"bytes"
"fmt"
"io"
"strings"
)
// WriteEnv generates environment variables for shell scripts
func (m *Zones) WriteEnv(w io.Writer) error {
var buf bytes.Buffer
m.writeEnvVarFn(&buf, genEnvZones, "ZONES")
m.ForEachZone(func(z *Zone) bool {
m.writeEnvZone(&buf, z)
return false
})
_, err := buf.WriteTo(w)
return err
}
func (m *Zones) writeEnvZone(w io.Writer, z *Zone) {
zoneID := z.ID
// ZONE{zoneID}
m.writeEnvVar(w, genEnvZoneNodes(z), "ZONE%v", zoneID)
// ZONE{zoneID}_NAME
m.writeEnvVar(w, z.Name, "ZONE%v_%s", zoneID, "NAME")
// ZONE{zoneID}_GW
gatewayID := getRingZeroGatewayID(z)
m.writeEnvVar(w, fmt.Sprintf("%v", gatewayID), "ZONE%v_%s", zoneID, "GW")
// ZONE{zoneID}_IP
ip, _ := RingZeroAddress(zoneID, gatewayID)
m.writeEnvVar(w, ip.String(), "ZONE%v_%s", zoneID, "IP")
}
func (m *Zones) writeEnvVarFn(w io.Writer, fn func(*Zones) string, name string, args ...any) {
var value string
if fn != nil {
value = fn(m)
}
m.writeEnvVar(w, value, name, args...)
}
func (*Zones) writeEnvVar(w io.Writer, value string, name string, args ...any) {
if len(args) > 0 {
name = fmt.Sprintf(name, args...)
}
if name != "" {
value = strings.TrimSpace(value)
_, _ = fmt.Fprintf(w, "%s=%q\n", name, value)
}
}
func genEnvZones(m *Zones) string {
s := make([]string, 0, len(m.Zones))
for _, z := range m.Zones {
s = append(s, fmt.Sprintf("%v", z.ID))
}
return strings.Join(s, " ")
}
func genEnvZoneNodes(z *Zone) string {
s := make([]string, 0, len(z.Machines))
for _, p := range z.Machines {
s = append(s, p.Name)
}
return strings.Join(s, " ")
}
func getRingZeroGatewayID(z *Zone) int {
var firstNodeID, gatewayID int
z.ForEachMachine(func(p *Machine) bool {
if firstNodeID == 0 {
firstNodeID = p.ID
}
if p.IsGateway() {
gatewayID = p.ID
}
return gatewayID != 0
})
switch {
case gatewayID == 0:
return firstNodeID
default:
return gatewayID
}
}
+12 -27
View File
@@ -5,48 +5,23 @@ import (
"io/fs" "io/fs"
"net/netip" "net/netip"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"sync"
) )
// A Machine is a machine on a Zone // A Machine is a machine on a Zone
type Machine struct { type Machine struct {
mu sync.Mutex
zone *Zone zone *Zone
id int ID int
Name string `toml:"name"` Name string `toml:"name"`
PublicAddresses []netip.Addr `toml:"public,omitempty"` PublicAddresses []netip.Addr `toml:"public,omitempty"`
RingAddresses []*RingInfo `toml:"rings,omitempty"` Rings []*RingInfo `toml:"rings,omitempty"`
} }
func (m *Machine) String() string { func (m *Machine) String() string {
return m.Name return m.Name
} }
// ID return the index within the [Zone] associated to this [Machine]
func (m *Machine) ID() int {
m.mu.Lock()
defer m.mu.Unlock()
if m.id == 0 {
zoneName := m.zone.Name
s := m.Name[len(zoneName)+1:]
id, err := strconv.ParseInt(s, 10, 8)
if err != nil {
panic(err)
}
m.id = int(id)
}
return m.id
}
// FullName returns the Name of the machine including domain name // FullName returns the Name of the machine including domain name
func (m *Machine) FullName() string { func (m *Machine) FullName() string {
if domain := m.zone.zones.domain; domain != "" { if domain := m.zone.zones.domain; domain != "" {
@@ -82,3 +57,13 @@ func (m *Machine) getFilename(name string, args ...any) string {
return filepath.Join(s...) return filepath.Join(s...)
} }
// IsGateway tells if the Machine is a ring0 gateway
func (m *Machine) IsGateway() bool {
_, ok := m.getRingInfo(0)
return ok
}
func (m *Machine) getPeerByName(name string) (*Machine, bool) {
return m.zone.zones.GetMachineByName(name)
}
+133 -2
View File
@@ -10,6 +10,68 @@ import (
"git.jpi.io/amery/jpictl/pkg/wireguard" "git.jpi.io/amery/jpictl/pkg/wireguard"
) )
// GetWireguardKeys reads a wgN.key/wgN.pub files
func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) {
var (
data []byte
err error
out wireguard.KeyPair
)
data, err = m.ReadFile("wg%v.key", ring)
if err != nil {
// failed to read
return out, err
}
out.PrivateKey, err = wireguard.PrivateKeyFromBase64(string(data))
if err != nil {
// bad key
err = core.Wrapf(err, "wg%v.key", ring)
return out, err
}
data, err = m.ReadFile("wg%v.pub", ring)
switch {
case os.IsNotExist(err):
// no wgN.pub is fine
case err != nil:
// failed to read
return out, err
default:
// good read
out.PublicKey, err = wireguard.PublicKeyFromBase64(string(data))
if err != nil {
// bad key
err = core.Wrapf(err, "wg%v.pub", ring)
return out, err
}
}
err = out.Validate()
return out, err
}
func (m *Machine) tryReadWireguardKeys(ring int) error {
kp, err := m.GetWireguardKeys(ring)
switch {
case os.IsNotExist(err):
// ignore
return nil
case err != nil:
// something went wrong
return err
default:
// import keys
ri := &RingInfo{
Ring: ring,
Keys: kp,
}
return m.applyRingInfo(ring, ri)
}
}
// GetWireguardConfig reads a wgN.conf file // GetWireguardConfig reads a wgN.conf file
func (m *Machine) GetWireguardConfig(ring int) (*wireguard.Config, error) { func (m *Machine) GetWireguardConfig(ring int) (*wireguard.Config, error) {
data, err := m.ReadFile("wg%v.conf", ring) data, err := m.ReadFile("wg%v.conf", ring)
@@ -45,17 +107,86 @@ func (m *Machine) applyWireguardConfig(ring int, wg *wireguard.Config) error {
return err return err
} }
if err := m.applyWireguardInterfaceConfig(ring, wg.Interface); err != nil {
err = core.Wrapf(err, "%s: wg%v:%s", m.Name, ring, addr)
return err
}
for _, peer := range wg.Peer {
if err := m.applyWireguardPeerConfig(ring, peer); err != nil {
err = core.Wrapf(err, "%s: wg%v:%s", m.Name, ring, addr)
return err
}
}
return nil return nil
} }
func (m *Machine) getRingInfo(ring int) (*RingInfo, bool) {
for _, ri := range m.Rings {
if ri.Ring == ring {
return ri, ri.Enabled
}
}
return nil, false
}
func (m *Machine) applyRingInfo(ring int, new *RingInfo) error {
cur, _ := m.getRingInfo(ring)
if cur == nil {
// first, append
m.Rings = append(m.Rings, new)
return nil
}
// extra, merge
return cur.Merge(new)
}
func (m *Machine) applyWireguardInterfaceConfig(ring int, data wireguard.InterfaceConfig) error {
ri := &RingInfo{
Ring: ring,
Enabled: true,
Keys: wireguard.KeyPair{
PrivateKey: data.PrivateKey,
},
}
return m.applyRingInfo(ring, ri)
}
func (m *Machine) applyWireguardPeerConfig(ring int, pc wireguard.PeerConfig) error {
peer, found := m.getPeerByName(pc.Endpoint.Name())
switch {
case !found:
// unknown
case ring == 1 && m.zone != peer.zone:
// invalid zone
default:
// apply RingInfo
ri := &RingInfo{
Ring: ring,
Enabled: true,
Keys: wireguard.KeyPair{
PublicKey: pc.PublicKey,
},
}
return peer.applyRingInfo(ring, ri)
}
return fmt.Errorf("%q: invalid peer endpoint", pc.Endpoint.Host)
}
func (m *Machine) applyZoneNodeID(zoneID, nodeID int) error { func (m *Machine) applyZoneNodeID(zoneID, nodeID int) error {
switch { switch {
case zoneID == 0: case zoneID == 0:
return fmt.Errorf("invalid %s", "zoneID") return fmt.Errorf("invalid %s", "zoneID")
case nodeID == 0: case nodeID == 0:
return fmt.Errorf("invalid %s", "nodeID") return fmt.Errorf("invalid %s", "nodeID")
case m.ID() != nodeID: case m.ID != nodeID:
return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.ID(), nodeID) return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.ID, nodeID)
case m.zone.ID != 0 && m.zone.ID != zoneID: case m.zone.ID != 0 && m.zone.ID != zoneID:
return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.zone.ID, zoneID) return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.zone.ID, zoneID)
case m.zone.ID == 0: case m.zone.ID == 0:
+27
View File
@@ -3,6 +3,7 @@ package zones
import ( import (
"context" "context"
"net/netip" "net/netip"
"strconv"
"time" "time"
) )
@@ -25,6 +26,32 @@ func (m *Machine) updatePublicAddresses() error {
return nil return nil
} }
func (m *Machine) init() error {
if err := m.setID(); err != nil {
return err
}
for i := 0; i < RingsCount; i++ {
if err := m.tryReadWireguardKeys(i); err != nil {
return err
}
}
return nil
}
func (m *Machine) setID() error {
zoneName := m.zone.Name
suffix := m.Name[len(zoneName)+1:]
id, err := strconv.ParseInt(suffix, 10, 8)
if err != nil {
return err
}
m.ID = int(id)
return nil
}
func (m *Machine) scan() error { func (m *Machine) scan() error {
for i := 0; i < RingsCount; i++ { for i := 0; i < RingsCount; i++ {
if err := m.tryApplyWireguardConfig(i); err != nil { if err := m.tryApplyWireguardConfig(i); err != nil {
+58 -4
View File
@@ -1,6 +1,7 @@
package zones package zones
import ( import (
"fmt"
"net/netip" "net/netip"
"git.jpi.io/amery/jpictl/pkg/wireguard" "git.jpi.io/amery/jpictl/pkg/wireguard"
@@ -13,21 +14,72 @@ const (
MaxNodeID = 0xff - 1 MaxNodeID = 0xff - 1
// RingsCount indicates how many wireguard rings we have // RingsCount indicates how many wireguard rings we have
RingsCount = 2 RingsCount = 2
// RingZeroPort is the port wireguard uses for ring0
RingZeroPort = 51800
// RingOnePort is the port wireguard uses for ring1
RingOnePort = 51810
) )
// RingInfo contains represents the Wireguard endpoint details // RingInfo contains represents the Wireguard endpoint details
// for a Machine on a particular ring // for a Machine on a particular ring
type RingInfo struct { type RingInfo struct {
Ring int `toml:"ring"` Ring int `toml:"ring"`
Enabled bool `toml:"enabled,omitempty"` Enabled bool `toml:"enabled,omitempty"`
Keys *wireguard.KeyPair `toml:"keys,omitempty"` Keys wireguard.KeyPair `toml:"keys,omitempty"`
Address netip.Addr `toml:"address,omitempty"` }
// Merge attempts to combine two RingInfo structs
func (ri *RingInfo) Merge(alter *RingInfo) error {
switch {
case alter == nil:
return nil
case ri.Ring != alter.Ring:
// different ring
return fmt.Errorf("invalid %s: %v ≠ %v", "ring", ri.Ring, alter.Ring)
case ri.Enabled && !alter.Enabled:
// can't disable via Merge
return fmt.Errorf("invalid %s: %v → %v", "enabled", ri.Enabled, alter.Enabled)
case !canMergeKeyPairs(ri.Keys, alter.Keys):
// incompatible keypairs
return fmt.Errorf("invalid %s: %s ≠ %s", "keys", ri.Keys, alter.Keys)
}
return ri.unsafeMerge(alter)
}
func (ri *RingInfo) unsafeMerge(alter *RingInfo) error {
// enable via Merge
if alter.Enabled {
ri.Enabled = true
}
// fill the gaps on our keypair
if ri.Keys.PrivateKey.IsZero() {
ri.Keys.PrivateKey = alter.Keys.PrivateKey
}
if ri.Keys.PublicKey.IsZero() {
ri.Keys.PublicKey = alter.Keys.PublicKey
}
return nil
}
func canMergeKeyPairs(p1, p2 wireguard.KeyPair) bool {
switch {
case !p1.PrivateKey.IsZero() && !p2.PrivateKey.IsZero() && !p1.PrivateKey.Equal(p2.PrivateKey):
return false
case !p1.PublicKey.IsZero() && !p2.PublicKey.IsZero() && !p1.PublicKey.Equal(p2.PublicKey):
return false
default:
return true
}
} }
// RingAddressEncoder provides encoder/decoder access for a particular // RingAddressEncoder provides encoder/decoder access for a particular
// Wireguard ring // Wireguard ring
type RingAddressEncoder struct { type RingAddressEncoder struct {
ID int ID int
Port uint16
Encode func(zoneID, nodeID int) (netip.Addr, bool) Encode func(zoneID, nodeID int) (netip.Addr, bool)
Decode func(addr netip.Addr) (zoneID, nodeID int, ok bool) Decode func(addr netip.Addr) (zoneID, nodeID int, ok bool)
} }
@@ -36,12 +88,14 @@ var (
// RingZero is a wg0 address encoder/decoder // RingZero is a wg0 address encoder/decoder
RingZero = RingAddressEncoder{ RingZero = RingAddressEncoder{
ID: 0, ID: 0,
Port: RingZeroPort,
Decode: ParseRingZeroAddress, Decode: ParseRingZeroAddress,
Encode: RingZeroAddress, Encode: RingZeroAddress,
} }
// RingOne is a wg1 address encoder/decoder // RingOne is a wg1 address encoder/decoder
RingOne = RingAddressEncoder{ RingOne = RingAddressEncoder{
ID: 1, ID: 1,
Port: RingOnePort,
Decode: ParseRingOneAddress, Decode: ParseRingOneAddress,
Encode: RingOneAddress, Encode: RingOneAddress,
} }
+82 -1
View File
@@ -2,9 +2,25 @@ package zones
import ( import (
"io/fs" "io/fs"
"sort"
) )
func (m *Zones) scan() error { func (m *Zones) scan() error {
for _, fn := range []func() error{
m.scanDirectory,
m.scanMachines,
m.scanZoneIDs,
m.scanSort,
} {
if err := fn(); err != nil {
return err
}
}
return nil
}
func (m *Zones) scanDirectory() error {
// each directory is a zone // each directory is a zone
entries, err := fs.ReadDir(m.dir, ".") entries, err := fs.ReadDir(m.dir, ".")
if err != nil { if err != nil {
@@ -26,7 +42,7 @@ func (m *Zones) scan() error {
} }
} }
return m.scanMachines() return nil
} }
func (m *Zones) scanMachines() error { func (m *Zones) scanMachines() error {
@@ -38,6 +54,67 @@ func (m *Zones) scanMachines() error {
return err return err
} }
func (m *Zones) scanZoneIDs() error {
var hasMissing bool
var lastZoneID int
m.ForEachZone(func(z *Zone) bool {
switch {
case z.ID == 0:
hasMissing = true
case z.ID > lastZoneID:
lastZoneID = z.ID
}
return false
})
if hasMissing {
next := lastZoneID + 1
m.ForEachZone(func(z *Zone) bool {
if z.ID == 0 {
z.ID, next = next, next+1
}
return false
})
}
return nil
}
func (m *Zones) scanSort() error {
sort.SliceStable(m.Zones, func(i, j int) bool {
id1 := m.Zones[i].ID
id2 := m.Zones[j].ID
return id1 < id2
})
m.ForEachZone(func(z *Zone) bool {
sort.SliceStable(z.Machines, func(i, j int) bool {
id1 := z.Machines[i].ID
id2 := z.Machines[j].ID
return id1 < id2
})
return false
})
m.ForEachMachine(func(p *Machine) bool {
sort.SliceStable(p.Rings, func(i, j int) bool {
ri1 := p.Rings[i]
ri2 := p.Rings[j]
return ri1.Ring < ri2.Ring
})
return false
})
return nil
}
func (z *Zone) scan() error { func (z *Zone) scan() error {
// each directory is a machine // each directory is a machine
entries, err := fs.ReadDir(z.zones.dir, z.Name) entries, err := fs.ReadDir(z.zones.dir, z.Name)
@@ -52,6 +129,10 @@ func (z *Zone) scan() error {
Name: e.Name(), Name: e.Name(),
} }
if err := m.init(); err != nil {
return err
}
z.Machines = append(z.Machines, m) z.Machines = append(z.Machines, m)
} }
} }
+40 -8
View File
@@ -22,6 +22,16 @@ func (z *Zone) String() string {
return z.Name return z.Name
} }
// ForEachMachine calls a function for each Machine in the zone
// until instructed to terminate the loop
func (z *Zone) ForEachMachine(fn func(*Machine) bool) {
for _, p := range z.Machines {
if fn(p) {
return
}
}
}
// Zones represents all zones in a cluster // Zones represents all zones in a cluster
type Zones struct { type Zones struct {
dir fs.FS dir fs.FS
@@ -32,18 +42,22 @@ type Zones struct {
} }
// ForEachMachine calls a function for each Machine in the cluster // ForEachMachine calls a function for each Machine in the cluster
// until instructed to terminate the loop
func (m *Zones) ForEachMachine(fn func(*Machine) bool) { func (m *Zones) ForEachMachine(fn func(*Machine) bool) {
for _, z := range m.Zones { m.ForEachZone(func(z *Zone) bool {
for _, p := range z.Machines { var term bool
if fn(p) {
// terminate z.ForEachMachine(func(p *Machine) bool {
return term = fn(p)
} return term
} })
}
return term
})
} }
// ForEachZone calls a function for each Zone in the cluster // ForEachZone calls a function for each Zone in the cluster
// until instructed to terminate the loop
func (m *Zones) ForEachZone(fn func(*Zone) bool) { func (m *Zones) ForEachZone(fn func(*Zone) bool) {
for _, p := range m.Zones { for _, p := range m.Zones {
if fn(p) { if fn(p) {
@@ -53,6 +67,24 @@ func (m *Zones) ForEachZone(fn func(*Zone) bool) {
} }
} }
// GetMachineByName looks for a machine with the specified
// name on any zone
func (m *Zones) GetMachineByName(name string) (*Machine, bool) {
var out *Machine
if name != "" {
m.ForEachMachine(func(p *Machine) bool {
if p.Name == name {
out = p
}
return out != nil
})
}
return out, out != nil
}
// NewFS builds a [Zones] tree using the given directory // NewFS builds a [Zones] tree using the given directory
func NewFS(dir fs.FS, domain string) (*Zones, error) { func NewFS(dir fs.FS, domain string) (*Zones, error) {
lockuper := resolver.NewCloudflareLookuper() lockuper := resolver.NewCloudflareLookuper()