Compare commits

...

21 Commits

Author SHA1 Message Date
amery 7ca01aa1e4 zones: Machine.RemoveWireguardConfig()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-25 15:36:58 +00:00
amery 8b72667f4d zones: Machine.RemoveWireguardKeys()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-25 15:20:07 +00:00
amery 49694eb7cb zones: Machine.WriteWireguardKeys()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-25 14:53:31 +00:00
amery 15a98c05ec zones: Machine.WriteStringFile()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-25 14:53:24 +00:00
amery a005823d44 zones: Machine.CreateFile() and Machine.CreateTruncFile()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-25 14:53:24 +00:00
amery 7af8484acc zones: introduce Machine.OpenFile()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 21:14:33 +00:00
amery 0f1f1ce968 zones: introduce Machine.RemoveFile()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 21:14:33 +00:00
amery 5058f286c6 zones: switch to using hackpadfs/os.FS as the standard os.FS is incomplete
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 21:14:33 +00:00
amery 86075eb47f zones: move Machine.ReadFile to a dedicated machine_file.go
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 21:14:31 +00:00
amery c81b782b26 zones: Machine.IsGateway()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 21:07:55 +00:00
amery 0f62ee2e53 zones: rename Machine.RingAddresses to Machine.Rings
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 20:52:08 +00:00
amery 30a7bceda3 wireguard: make KeyPairs solid
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 19:49:59 +00:00
amery 60e2687d04 wireguard: make keys arrays instead of slices
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 19:35:11 +00:00
amery 1419e55d5b zones: remove useless RingInfo.Address
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 19:06:33 +00:00
amery ffdacb833b zones: add Port information to RingAddressEncoder
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 17:56:19 +00:00
amery aca0a5e834 zones: calculate Machine.ID on init
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 17:55:43 +00:00
amery 61374d4cc5 zones: load wireguard key pairs on Machine.init()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 14:35:40 +00:00
amery 975e166da7 zones: allow RingInfo.Merge() to enable, but not disable
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-24 14:34:00 +00:00
amery b16c648f2c zones: introduce Machine.GetWireguardKeys()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-23 21:03:03 +00:00
amery 47d79f7576 wireguard: introduce KeyPair.Validate()
it will also set the PublicKey field is empty

Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-23 00:29:15 +00:00
amery e2f831fd6a wireguard: introduce NewKeyPair, NewPrivateKey, and PrivateKey.Public()
Signed-off-by: Alejandro Mery <amery@jpi.io>
2023-08-23 00:18:31 +00:00
11 changed files with 415 additions and 133 deletions
+2
View File
@@ -8,8 +8,10 @@ require (
darvaza.org/sidecar v0.0.0-20230721122716-b9c54b8adbaf darvaza.org/sidecar v0.0.0-20230721122716-b9c54b8adbaf
darvaza.org/slog v0.5.2 darvaza.org/slog v0.5.2
github.com/burntSushi/toml v0.3.1 github.com/burntSushi/toml v0.3.1
github.com/hack-pad/hackpadfs v0.2.1
github.com/mgechev/revive v1.3.2 github.com/mgechev/revive v1.3.2
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
golang.org/x/crypto v0.12.0
gopkg.in/gcfg.v1 v1.2.3 gopkg.in/gcfg.v1 v1.2.3
) )
+4
View File
@@ -26,6 +26,8 @@ github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBD
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/hack-pad/hackpadfs v0.2.1 h1:FelFhIhv26gyjujoA/yeFO+6YGlqzmc9la/6iKMIxMw=
github.com/hack-pad/hackpadfs v0.2.1/go.mod h1:khQBuCEwGXWakkmq8ZiFUvUZz84ZkJ2KNwKvChs4OrU=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
@@ -70,6 +72,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+103 -18
View File
@@ -2,8 +2,11 @@ package wireguard
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"golang.org/x/crypto/curve25519"
) )
const ( const (
@@ -16,64 +19,79 @@ const (
var ( var (
// ErrInvalidKeySize indicates the key size is wrong // ErrInvalidKeySize indicates the key size is wrong
ErrInvalidKeySize = errors.New("invalid key size") ErrInvalidKeySize = errors.New("invalid key size")
// ErrInvalidPrivateKey indicates the private key is invalid
ErrInvalidPrivateKey = errors.New("invalid private key")
// ErrInvalidPublicKey indicates the public key is invalid
ErrInvalidPublicKey = errors.New("invalid public key")
) )
type ( type (
// PrivateKey is a binary Wireguard Private Key // PrivateKey is a binary Wireguard Private Key
PrivateKey []byte PrivateKey [PrivateKeySize]byte
// PublicKey is a binary Wireguard Public Key // PublicKey is a binary Wireguard Public Key
PublicKey []byte PublicKey [PublicKeySize]byte
) )
func (key PrivateKey) String() string { func (key PrivateKey) String() string {
return encodeKey(key) switch {
case key.IsZero():
return ""
default:
return base64.StdEncoding.EncodeToString(key[:])
}
} }
func (pub PublicKey) String() string { func (pub PublicKey) String() string {
return encodeKey(pub) switch {
case pub.IsZero():
return ""
default:
return base64.StdEncoding.EncodeToString(pub[:])
}
} }
// IsZero tells if the key hasn't been set // IsZero tells if the key hasn't been set
func (key PrivateKey) IsZero() bool { func (key PrivateKey) IsZero() bool {
return len(key) == 0 var zero PrivateKey
return key.Equal(zero)
} }
// IsZero tells if the key hasn't been set // IsZero tells if the key hasn't been set
func (pub PublicKey) IsZero() bool { func (pub PublicKey) IsZero() bool {
return len(pub) == 0 var zero PublicKey
return pub.Equal(zero)
} }
// Equal checks if two private keys are identical // Equal checks if two private keys are identical
func (key PrivateKey) Equal(alter PrivateKey) bool { func (key PrivateKey) Equal(alter PrivateKey) bool {
return bytes.Equal(key, alter) return bytes.Equal(key[:], alter[:])
} }
// Equal checks if two public keys are identical // Equal checks if two public keys are identical
func (pub PublicKey) Equal(alter PublicKey) bool { func (pub PublicKey) Equal(alter PublicKey) bool {
return bytes.Equal(pub, alter) return bytes.Equal(pub[:], alter[:])
} }
// PrivateKeyFromBase64 decodes a base64-based string into // PrivateKeyFromBase64 decodes a base64-based string into
// a [PrivateKey] // a [PrivateKey]
func PrivateKeyFromBase64(data string) (PrivateKey, error) { func PrivateKeyFromBase64(data string) (PrivateKey, error) {
b, err := decodeKey(data, PrivateKeySize) b, err := decodeKey(data, PrivateKeySize)
return b, err if err != nil {
var zero PrivateKey
return zero, err
}
return *(*[PrivateKeySize]byte)(b), nil
} }
// PublicKeyFromBase64 decodes a base64-based string into // PublicKeyFromBase64 decodes a base64-based string into
// a [PublicKey] // a [PublicKey]
func PublicKeyFromBase64(data string) (PublicKey, error) { func PublicKeyFromBase64(data string) (PublicKey, error) {
b, err := decodeKey(data, PublicKeySize) b, err := decodeKey(data, PublicKeySize)
return b, err if err != nil {
} var zero PublicKey
return zero, err
func encodeKey(b []byte) string {
switch {
case len(b) == 0:
return ""
default:
return base64.StdEncoding.EncodeToString(b)
} }
return *(*[PublicKeySize]byte)(b), nil
} }
func decodeKey(data string, size int) ([]byte, error) { func decodeKey(data string, size int) ([]byte, error) {
@@ -89,8 +107,75 @@ func decodeKey(data string, size int) ([]byte, error) {
} }
} }
// NewPrivateKey creates a new PrivateKey
func NewPrivateKey() (PrivateKey, error) {
var s [PrivateKeySize]byte
_, err := rand.Read(s[:])
if err != nil {
var zero PrivateKey
return zero, err
}
// apply same clamping as wireguard-go/device/noise-helpers.go
s[0] &= 0xf8
s[31] = (s[31] & 0x7f) | 0x40
return s, nil
}
// Public generates the corresponding PublicKey
func (key PrivateKey) Public() PublicKey {
var pub PublicKey
if !key.IsZero() {
in := (*[PrivateKeySize]byte)(&key)
out := (*[PublicKeySize]byte)(&pub)
curve25519.ScalarBaseMult(out, in)
}
return pub
}
// KeyPair holds a Key pair // KeyPair holds a Key pair
type KeyPair struct { type KeyPair struct {
PrivateKey PrivateKey PrivateKey PrivateKey
PublicKey PublicKey PublicKey PublicKey
} }
// Validate checks the PublicKey matches the PrivateKey,
// and sets the PublicKey if missing
func (kp *KeyPair) Validate() error {
keyLen := len(kp.PrivateKey)
pubLen := len(kp.PublicKey)
switch {
case keyLen != PrivateKeySize:
// bad private key
return ErrInvalidPrivateKey
case pubLen == 0:
// no public key, set it
kp.PublicKey = kp.PrivateKey.Public()
return nil
case pubLen != PublicKeySize:
// bad public key
return ErrInvalidPublicKey
case !kp.PrivateKey.Public().Equal(kp.PublicKey):
// wrong public key
return ErrInvalidPublicKey
default:
// correct public key
return nil
}
}
// NewKeyPair creates a new KeyPair for Wireguard
func NewKeyPair() (KeyPair, error) {
var out KeyPair
key, err := NewPrivateKey()
if err == nil {
out.PrivateKey = key
out.PublicKey = key.Public()
}
return out, nil
}
+3 -5
View File
@@ -82,14 +82,12 @@ func getRingZeroGatewayID(z *Zone) int {
var firstNodeID, gatewayID int var firstNodeID, gatewayID int
z.ForEachMachine(func(p *Machine) bool { z.ForEachMachine(func(p *Machine) bool {
nodeID := p.ID()
if firstNodeID == 0 { if firstNodeID == 0 {
firstNodeID = nodeID firstNodeID = p.ID
} }
if _, found := p.getRingInfo(0); found { if p.IsGateway() {
gatewayID = nodeID gatewayID = p.ID
} }
return gatewayID != 0 return gatewayID != 0
+6 -50
View File
@@ -1,52 +1,24 @@
package zones package zones
import ( import (
"fmt"
"io/fs"
"net/netip" "net/netip"
"path/filepath"
"strconv"
"strings" "strings"
"sync"
) )
// A Machine is a machine on a Zone // A Machine is a machine on a Zone
type Machine struct { type Machine struct {
mu sync.Mutex
zone *Zone zone *Zone
id int ID int
Name string `toml:"name"` Name string `toml:"name"`
PublicAddresses []netip.Addr `toml:"public,omitempty"` PublicAddresses []netip.Addr `toml:"public,omitempty"`
RingAddresses []*RingInfo `toml:"rings,omitempty"` Rings []*RingInfo `toml:"rings,omitempty"`
} }
func (m *Machine) String() string { func (m *Machine) String() string {
return m.Name return m.Name
} }
// ID return the index within the [Zone] associated to this [Machine]
func (m *Machine) ID() int {
m.mu.Lock()
defer m.mu.Unlock()
if m.id == 0 {
zoneName := m.zone.Name
s := m.Name[len(zoneName)+1:]
id, err := strconv.ParseInt(s, 10, 8)
if err != nil {
panic(err)
}
m.id = int(id)
}
return m.id
}
// FullName returns the Name of the machine including domain name // FullName returns the Name of the machine including domain name
func (m *Machine) FullName() string { func (m *Machine) FullName() string {
if domain := m.zone.zones.domain; domain != "" { if domain := m.zone.zones.domain; domain != "" {
@@ -61,26 +33,10 @@ func (m *Machine) FullName() string {
return m.Name return m.Name
} }
// ReadFile reads a file from the machine's config directory // IsGateway tells if the Machine is a ring0 gateway
func (m *Machine) ReadFile(name string, args ...any) ([]byte, error) { func (m *Machine) IsGateway() bool {
base := m.zone.zones.dir _, ok := m.getRingInfo(0)
fullName := m.getFilename(name, args...) return ok
return fs.ReadFile(base, fullName)
}
func (m *Machine) getFilename(name string, args ...any) string {
if len(args) > 0 {
name = fmt.Sprintf(name, args...)
}
s := []string{
m.zone.Name,
m.Name,
name,
}
return filepath.Join(s...)
} }
func (m *Machine) getPeerByName(name string) (*Machine, bool) { func (m *Machine) getPeerByName(name string) (*Machine, bool) {
+87
View File
@@ -0,0 +1,87 @@
package zones
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
fs "github.com/hack-pad/hackpadfs"
)
// OpenFile opens a file on the machine's config directory with the specified flags
func (m *Machine) OpenFile(name string, flags int, args ...any) (fs.File, error) {
base := m.zone.zones.dir
fullName := m.getFilename(name, args...)
return fs.OpenFile(base, fullName, flags, 0644)
}
// CreateTruncFile creates or truncates a file on the machine's config directory
func (m *Machine) CreateTruncFile(name string, args ...any) (io.WriteCloser, error) {
return m.openWriter(name, os.O_CREATE|os.O_TRUNC, args...)
}
// CreateFile creates a file on the machine's config directory
func (m *Machine) CreateFile(name string, args ...any) (io.WriteCloser, error) {
return m.openWriter(name, os.O_CREATE, args...)
}
func (m *Machine) openWriter(name string, flags int, args ...any) (io.WriteCloser, error) {
f, err := m.OpenFile(name, os.O_WRONLY|flags, args...)
if err != nil {
return nil, err
}
return f.(io.WriteCloser), nil
}
// RemoveFile deletes a file from the machine's config directory
func (m *Machine) RemoveFile(name string, args ...any) error {
base := m.zone.zones.dir
fullName := m.getFilename(name, args...)
err := fs.Remove(base, fullName)
switch {
case os.IsNotExist(err):
return nil
default:
return err
}
}
// ReadFile reads a file from the machine's config directory
func (m *Machine) ReadFile(name string, args ...any) ([]byte, error) {
base := m.zone.zones.dir
fullName := m.getFilename(name, args...)
return fs.ReadFile(base, fullName)
}
// WriteStringFile writes the given content to a file on the machine's config directory
func (m *Machine) WriteStringFile(value string, name string, args ...any) error {
f, err := m.CreateTruncFile(name, args...)
if err != nil {
return err
}
defer f.Close()
buf := bytes.NewBufferString(value)
_, err = buf.WriteTo(f)
return err
}
func (m *Machine) getFilename(name string, args ...any) string {
if len(args) > 0 {
name = fmt.Sprintf(name, args...)
}
s := []string{
m.zone.Name,
m.Name,
name,
}
return filepath.Join(s...)
}
+137 -10
View File
@@ -3,6 +3,7 @@ package zones
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io/fs"
"os" "os"
"darvaza.org/core" "darvaza.org/core"
@@ -10,6 +11,122 @@ import (
"git.jpi.io/amery/jpictl/pkg/wireguard" "git.jpi.io/amery/jpictl/pkg/wireguard"
) )
// GetWireguardKeys reads a wgN.key/wgN.pub files
func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) {
var (
data []byte
err error
out wireguard.KeyPair
)
data, err = m.ReadFile("wg%v.key", ring)
if err != nil {
// failed to read
return out, err
}
out.PrivateKey, err = wireguard.PrivateKeyFromBase64(string(data))
if err != nil {
// bad key
err = core.Wrapf(err, "wg%v.key", ring)
return out, err
}
data, err = m.ReadFile("wg%v.pub", ring)
switch {
case os.IsNotExist(err):
// no wgN.pub is fine
case err != nil:
// failed to read
return out, err
default:
// good read
out.PublicKey, err = wireguard.PublicKeyFromBase64(string(data))
if err != nil {
// bad key
err = core.Wrapf(err, "wg%v.pub", ring)
return out, err
}
}
err = out.Validate()
return out, err
}
func (m *Machine) tryReadWireguardKeys(ring int) error {
kp, err := m.GetWireguardKeys(ring)
switch {
case os.IsNotExist(err):
// ignore
return nil
case err != nil:
// something went wrong
return err
default:
// import keys
ri := &RingInfo{
Ring: ring,
Keys: kp,
}
return m.applyRingInfo(ring, ri)
}
}
// WriteWireguardKeys writes the wgN.key/wgN.pub files
func (m *Machine) WriteWireguardKeys(ring int) error {
var err error
var key, pub string
var ri *RingInfo
ri, _ = m.getRingInfo(ring)
if ri != nil {
key = ri.Keys.PrivateKey.String()
pub = ri.Keys.PublicKey.String()
}
switch {
case key == "":
return fs.ErrNotExist
case pub == "":
pub = ri.Keys.PrivateKey.Public().String()
}
err = m.WriteStringFile(key, "wg%v.key", ring)
if err != nil {
return err
}
err = m.WriteStringFile(pub, "wg%v.pub", ring)
if err != nil {
return err
}
return nil
}
// RemoveWireguardKeys deletes wgN.key and wgN.pub from
// the machine's config directory
func (m *Machine) RemoveWireguardKeys(ring int) error {
var err error
err = m.RemoveFile("wg%v.pub", ring)
switch {
case os.IsNotExist(err):
// ignore
case err != nil:
return err
}
err = m.RemoveFile("wg%v.key", ring)
if os.IsNotExist(err) {
// ignore
err = nil
}
return err
}
// GetWireguardConfig reads a wgN.conf file // GetWireguardConfig reads a wgN.conf file
func (m *Machine) GetWireguardConfig(ring int) (*wireguard.Config, error) { func (m *Machine) GetWireguardConfig(ring int) (*wireguard.Config, error) {
data, err := m.ReadFile("wg%v.conf", ring) data, err := m.ReadFile("wg%v.conf", ring)
@@ -61,9 +178,9 @@ func (m *Machine) applyWireguardConfig(ring int, wg *wireguard.Config) error {
} }
func (m *Machine) getRingInfo(ring int) (*RingInfo, bool) { func (m *Machine) getRingInfo(ring int) (*RingInfo, bool) {
for _, ri := range m.RingAddresses { for _, ri := range m.Rings {
if ri.Ring == ring { if ri.Ring == ring {
return ri, true return ri, ri.Enabled
} }
} }
@@ -71,10 +188,10 @@ func (m *Machine) getRingInfo(ring int) (*RingInfo, bool) {
} }
func (m *Machine) applyRingInfo(ring int, new *RingInfo) error { func (m *Machine) applyRingInfo(ring int, new *RingInfo) error {
cur, found := m.getRingInfo(ring) cur, _ := m.getRingInfo(ring)
if !found { if cur == nil {
// first, append // first, append
m.RingAddresses = append(m.RingAddresses, new) m.Rings = append(m.Rings, new)
return nil return nil
} }
@@ -86,8 +203,7 @@ func (m *Machine) applyWireguardInterfaceConfig(ring int, data wireguard.Interfa
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: ring,
Enabled: true, Enabled: true,
Address: data.Address, Keys: wireguard.KeyPair{
Keys: &wireguard.KeyPair{
PrivateKey: data.PrivateKey, PrivateKey: data.PrivateKey,
}, },
} }
@@ -107,7 +223,7 @@ func (m *Machine) applyWireguardPeerConfig(ring int, pc wireguard.PeerConfig) er
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: ring,
Enabled: true, Enabled: true,
Keys: &wireguard.KeyPair{ Keys: wireguard.KeyPair{
PublicKey: pc.PublicKey, PublicKey: pc.PublicKey,
}, },
} }
@@ -124,8 +240,8 @@ func (m *Machine) applyZoneNodeID(zoneID, nodeID int) error {
return fmt.Errorf("invalid %s", "zoneID") return fmt.Errorf("invalid %s", "zoneID")
case nodeID == 0: case nodeID == 0:
return fmt.Errorf("invalid %s", "nodeID") return fmt.Errorf("invalid %s", "nodeID")
case m.ID() != nodeID: case m.ID != nodeID:
return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.ID(), nodeID) return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.ID, nodeID)
case m.zone.ID != 0 && m.zone.ID != zoneID: case m.zone.ID != 0 && m.zone.ID != zoneID:
return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.zone.ID, zoneID) return fmt.Errorf("invalid %s: %v ≠ %v", "zoneID", m.zone.ID, zoneID)
case m.zone.ID == 0: case m.zone.ID == 0:
@@ -134,3 +250,14 @@ func (m *Machine) applyZoneNodeID(zoneID, nodeID int) error {
return nil return nil
} }
// RemoveWireguardConfig deletes wgN.conf from the machine's
// config directory.
func (m *Machine) RemoveWireguardConfig(ring int) error {
err := m.RemoveFile("wg%v.conf", ring)
if os.IsNotExist(err) {
err = nil
}
return err
}
+27
View File
@@ -3,6 +3,7 @@ package zones
import ( import (
"context" "context"
"net/netip" "net/netip"
"strconv"
"time" "time"
) )
@@ -25,6 +26,32 @@ func (m *Machine) updatePublicAddresses() error {
return nil return nil
} }
func (m *Machine) init() error {
if err := m.setID(); err != nil {
return err
}
for i := 0; i < RingsCount; i++ {
if err := m.tryReadWireguardKeys(i); err != nil {
return err
}
}
return nil
}
func (m *Machine) setID() error {
zoneName := m.zone.Name
suffix := m.Name[len(zoneName)+1:]
id, err := strconv.ParseInt(suffix, 10, 8)
if err != nil {
return err
}
m.ID = int(id)
return nil
}
func (m *Machine) scan() error { func (m *Machine) scan() error {
for i := 0; i < RingsCount; i++ { for i := 0; i < RingsCount; i++ {
if err := m.tryApplyWireguardConfig(i); err != nil { if err := m.tryApplyWireguardConfig(i); err != nil {
+29 -43
View File
@@ -14,75 +14,58 @@ const (
MaxNodeID = 0xff - 1 MaxNodeID = 0xff - 1
// RingsCount indicates how many wireguard rings we have // RingsCount indicates how many wireguard rings we have
RingsCount = 2 RingsCount = 2
// RingZeroPort is the port wireguard uses for ring0
RingZeroPort = 51800
// RingOnePort is the port wireguard uses for ring1
RingOnePort = 51810
) )
// RingInfo contains represents the Wireguard endpoint details // RingInfo contains represents the Wireguard endpoint details
// for a Machine on a particular ring // for a Machine on a particular ring
type RingInfo struct { type RingInfo struct {
Ring int `toml:"ring"` Ring int `toml:"ring"`
Enabled bool `toml:"enabled,omitempty"` Enabled bool `toml:"enabled,omitempty"`
Keys *wireguard.KeyPair `toml:"keys,omitempty"` Keys wireguard.KeyPair `toml:"keys,omitempty"`
Address netip.Addr `toml:"address,omitempty"`
} }
// Merge attempts to combine two RingInfo structs // Merge attempts to combine two RingInfo structs
func (ri *RingInfo) Merge(alter *RingInfo) error { func (ri *RingInfo) Merge(alter *RingInfo) error {
switch { switch {
case alter == nil:
return nil
case ri.Ring != alter.Ring: case ri.Ring != alter.Ring:
// different ring // different ring
return fmt.Errorf("invalid %s: %v ≠ %v", "ring", ri.Ring, alter.Ring) return fmt.Errorf("invalid %s: %v ≠ %v", "ring", ri.Ring, alter.Ring)
case ri.Enabled != alter.Enabled: case ri.Enabled && !alter.Enabled:
// different state // can't disable via Merge
return fmt.Errorf("invalid %s: %v %v", "enabled", ri.Enabled, alter.Enabled) return fmt.Errorf("invalid %s: %v %v", "enabled", ri.Enabled, alter.Enabled)
case !canMergeAddress(ri.Address, alter.Address):
// different address
return fmt.Errorf("invalid %s: %v ≠ %v", "address", ri.Address, alter.Address)
case !canMergeKeyPairs(ri.Keys, alter.Keys): case !canMergeKeyPairs(ri.Keys, alter.Keys):
// incompatible keypairs // incompatible keypairs
return fmt.Errorf("invalid %s: %s ≠ %s", "keys", ri.Keys, alter.Keys) return fmt.Errorf("invalid %s: %s ≠ %s", "keys", ri.Keys, alter.Keys)
} }
switch { return ri.unsafeMerge(alter)
case ri.Keys == nil: }
// assign keypair
ri.Keys = alter.Keys func (ri *RingInfo) unsafeMerge(alter *RingInfo) error {
case alter.Keys != nil: // enable via Merge
// fill the gaps on our keypair if alter.Enabled {
if ri.Keys.PrivateKey.IsZero() { ri.Enabled = true
ri.Keys.PrivateKey = alter.Keys.PrivateKey
}
if ri.Keys.PublicKey.IsZero() {
ri.Keys.PublicKey = alter.Keys.PublicKey
}
} }
if addressEqual(ri.Address, netip.Addr{}) { // fill the gaps on our keypair
// assign address if ri.Keys.PrivateKey.IsZero() {
ri.Address = alter.Address ri.Keys.PrivateKey = alter.Keys.PrivateKey
}
if ri.Keys.PublicKey.IsZero() {
ri.Keys.PublicKey = alter.Keys.PublicKey
} }
return nil return nil
} }
func canMergeAddress(ip1, ip2 netip.Addr) bool { func canMergeKeyPairs(p1, p2 wireguard.KeyPair) bool {
var zero netip.Addr
switch { switch {
case addressEqual(ip1, zero) || addressEqual(ip2, zero) || addressEqual(ip1, ip2):
return true
default:
return false
}
}
func addressEqual(ip1, ip2 netip.Addr) bool {
return ip1.Compare(ip2) == 0
}
func canMergeKeyPairs(p1, p2 *wireguard.KeyPair) bool {
switch {
case p1 == nil || p2 == nil:
return true
case !p1.PrivateKey.IsZero() && !p2.PrivateKey.IsZero() && !p1.PrivateKey.Equal(p2.PrivateKey): case !p1.PrivateKey.IsZero() && !p2.PrivateKey.IsZero() && !p1.PrivateKey.Equal(p2.PrivateKey):
return false return false
case !p1.PublicKey.IsZero() && !p2.PublicKey.IsZero() && !p1.PublicKey.Equal(p2.PublicKey): case !p1.PublicKey.IsZero() && !p2.PublicKey.IsZero() && !p1.PublicKey.Equal(p2.PublicKey):
@@ -96,6 +79,7 @@ func canMergeKeyPairs(p1, p2 *wireguard.KeyPair) bool {
// Wireguard ring // Wireguard ring
type RingAddressEncoder struct { type RingAddressEncoder struct {
ID int ID int
Port uint16
Encode func(zoneID, nodeID int) (netip.Addr, bool) Encode func(zoneID, nodeID int) (netip.Addr, bool)
Decode func(addr netip.Addr) (zoneID, nodeID int, ok bool) Decode func(addr netip.Addr) (zoneID, nodeID int, ok bool)
} }
@@ -104,12 +88,14 @@ var (
// RingZero is a wg0 address encoder/decoder // RingZero is a wg0 address encoder/decoder
RingZero = RingAddressEncoder{ RingZero = RingAddressEncoder{
ID: 0, ID: 0,
Port: RingZeroPort,
Decode: ParseRingZeroAddress, Decode: ParseRingZeroAddress,
Encode: RingZeroAddress, Encode: RingZeroAddress,
} }
// RingOne is a wg1 address encoder/decoder // RingOne is a wg1 address encoder/decoder
RingOne = RingAddressEncoder{ RingOne = RingAddressEncoder{
ID: 1, ID: 1,
Port: RingOnePort,
Decode: ParseRingOneAddress, Decode: ParseRingOneAddress,
Encode: RingOneAddress, Encode: RingOneAddress,
} }
+9 -5
View File
@@ -93,8 +93,8 @@ func (m *Zones) scanSort() error {
m.ForEachZone(func(z *Zone) bool { m.ForEachZone(func(z *Zone) bool {
sort.SliceStable(z.Machines, func(i, j int) bool { sort.SliceStable(z.Machines, func(i, j int) bool {
id1 := z.Machines[i].ID() id1 := z.Machines[i].ID
id2 := z.Machines[j].ID() id2 := z.Machines[j].ID
return id1 < id2 return id1 < id2
}) })
@@ -102,9 +102,9 @@ func (m *Zones) scanSort() error {
}) })
m.ForEachMachine(func(p *Machine) bool { m.ForEachMachine(func(p *Machine) bool {
sort.SliceStable(p.RingAddresses, func(i, j int) bool { sort.SliceStable(p.Rings, func(i, j int) bool {
ri1 := p.RingAddresses[i] ri1 := p.Rings[i]
ri2 := p.RingAddresses[j] ri2 := p.Rings[j]
return ri1.Ring < ri2.Ring return ri1.Ring < ri2.Ring
}) })
@@ -129,6 +129,10 @@ func (z *Zone) scan() error {
Name: e.Name(), Name: e.Name(),
} }
if err := m.init(); err != nil {
return err
}
z.Machines = append(z.Machines, m) z.Machines = append(z.Machines, m)
} }
} }
+8 -2
View File
@@ -3,7 +3,8 @@ package zones
import ( import (
"io/fs" "io/fs"
"os"
"github.com/hack-pad/hackpadfs/os"
"darvaza.org/resolver" "darvaza.org/resolver"
) )
@@ -104,5 +105,10 @@ func NewFS(dir fs.FS, domain string) (*Zones, error) {
// New builds a [Zones] tree using the given directory // New builds a [Zones] tree using the given directory
func New(dir, domain string) (*Zones, error) { func New(dir, domain string) (*Zones, error) {
return NewFS(os.DirFS(dir), domain) base, err := os.NewFS().Sub(dir)
if err != nil {
return nil, err
}
return NewFS(base, domain)
} }