diff --git a/cmd/jpictl/gateway.go b/cmd/jpictl/gateway.go new file mode 100644 index 0000000..1004db3 --- /dev/null +++ b/cmd/jpictl/gateway.go @@ -0,0 +1,168 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "strconv" + "strings" + + "git.jpi.io/amery/jpictl/pkg/zones" + "github.com/spf13/cobra" +) + +// Command +var gatewayCmd = &cobra.Command{ + Use: "gateway", + Short: "gateway operates on ring0/ring1 gateways", +} + +// gateway set +var gatewaySetCmd = &cobra.Command{ + Use: "set", + Short: "gateway set sets machines as gateways", + RunE: func(_ *cobra.Command, args []string) error { + m, err := cfg.LoadZones(false) + if err != nil { + return err + } + + for _, arg := range args { + err = gatewaySet(m, arg) + if err != nil { + return err + } + } + return nil + }, +} + +func gatewaySet(zi zones.ZoneIterator, gw string) error { + var err error + zi.ForEachZone(func(z *zones.Zone) bool { + for _, m := range z.Machines { + if m.Name == gw { + z.SetGateway(m.ID, true) + return true + } + } + err = fmt.Errorf("machine %s not found", gw) + return false + }) + return err +} + +// gateway unset +var gatewayUnsetCmd = &cobra.Command{ + Use: "unset", + Short: "gateway unset sets machines as non-gateways", + RunE: func(_ *cobra.Command, args []string) error { + m, err := cfg.LoadZones(false) + if err != nil { + return err + } + + for _, arg := range args { + err = gatewayUnset(m, arg) + if err != nil { + return err + } + } + + return nil + }, +} + +func gatewayUnset(zi zones.ZoneIterator, ngw string) error { + var err error + zi.ForEachZone(func(z *zones.Zone) bool { + for _, m := range z.Machines { + if m.Name == ngw && m.IsGateway() { + z.SetGateway(m.ID, false) + m.RemoveWireguardConfig(0) + return true + } + } + err = fmt.Errorf("machine %s not found", ngw) + return false + }) + return err +} + +// gateway list +var gatewayListCmd = &cobra.Command{ + Use: "list", + Short: "gateway list lists gateways", + RunE: func(_ *cobra.Command, args []string) error { + m, err := cfg.LoadZones(false) + if err != nil { + return err + } + + switch { + case len(args) == 0: + return gatewayListAll(m) + default: + for _, arg := range args { + err = gatewayList(m, arg) + if err != nil { + return err + } + } + return nil + } + }, +} + +func gatewayListAll(zi zones.ZoneIterator) error { + var b bytes.Buffer + var err error + zi.ForEachZone(func(z *zones.Zone) bool { + b.WriteString(z.Name + ":") + var sIDs []string + ids, num := z.GatewayIDs() + if num == 0 { + err = fmt.Errorf("no gateway in zone %s", z.Name) + return false + } + for _, i := range ids { + sIDs = append(sIDs, strconv.Itoa(i)) + } + b.WriteString(strings.Join(sIDs, ", ")) + b.WriteString("\n") + _, err = b.WriteTo(os.Stdout) + return false + }) + return err +} + +func gatewayList(zi zones.ZoneIterator, m string) error { + var b bytes.Buffer + var err error + zi.ForEachZone(func(z *zones.Zone) bool { + if z.Name == m { + b.WriteString(z.Name + ":") + ids, num := z.GatewayIDs() + if num == 0 { + err = fmt.Errorf("no gateway in zone %s", z.Name) + return true + } + + b.WriteString(fmt.Sprint(ids)) + b.WriteString("\n") + _, err = b.WriteTo(os.Stdout) + return true + } + err = fmt.Errorf("zone %s not found", m) + return false + }) + return err +} + +func init() { + rootCmd.AddCommand(gatewayCmd) + + gatewayCmd.AddCommand(gatewaySetCmd) + gatewayCmd.AddCommand(gatewayUnsetCmd) + gatewayCmd.AddCommand(gatewayListCmd) +}