aboutsummaryrefslogtreecommitdiff
path: root/acid.go
diff options
context:
space:
mode:
authorPřemysl Eric Janouch <p@janouch.name>2024-04-14 22:07:39 +0200
committerPřemysl Eric Janouch <p@janouch.name>2024-04-15 00:05:53 +0200
commiteda0f22f072b7985c6919770858bcdd566290f86 (patch)
treecdefb28cb27a2c68617b974911ad3483892c53b1 /acid.go
parent013e7eba28013f1f3c998b06e777c8293badde87 (diff)
downloadacid-eda0f22f072b7985c6919770858bcdd566290f86.tar.gz
acid-eda0f22f072b7985c6919770858bcdd566290f86.tar.xz
acid-eda0f22f072b7985c6919770858bcdd566290f86.zip
Rewrite RPC handling for wider usability
Diffstat (limited to 'acid.go')
-rw-r--r--acid.go100
1 files changed, 84 insertions, 16 deletions
diff --git a/acid.go b/acid.go
index 280323c..64d69dc 100644
--- a/acid.go
+++ b/acid.go
@@ -22,6 +22,7 @@ import (
"os/signal"
"sort"
"strconv"
+ "strings"
"sync"
"syscall"
ttemplate "text/template"
@@ -361,7 +362,23 @@ func handlePush(w http.ResponseWriter, r *http.Request) {
const rpcHeaderSignature = "X-ACID-Signature"
-func rpcRestart(w io.Writer, ids []int64) {
+var errWrongUsage = errors.New("wrong usage")
+
+func rpcRestart(ctx context.Context,
+ w io.Writer, fs *flag.FlagSet, args []string) error {
+ if err := fs.Parse(args); err != nil {
+ return err
+ }
+
+ ids := []int64{}
+ for _, arg := range fs.Args() {
+ id, err := strconv.ParseInt(arg, 10, 64)
+ if err != nil {
+ return fmt.Errorf("%w: %s", errWrongUsage, err)
+ }
+ ids = append(ids, id)
+ }
+
gRunningMutex.Lock()
defer gRunningMutex.Unlock()
@@ -373,7 +390,7 @@ func rpcRestart(w io.Writer, ids []int64) {
// The executor bumps to "running" after inserting into gRunning,
// so we should not need to exclude that state here.
- result, err := gDB.ExecContext(context.Background(), `UPDATE task
+ result, err := gDB.ExecContext(ctx, `UPDATE task
SET state = ?, detail = '', notified = 0 WHERE id = ?`,
taskStateNew, id)
if err != nil {
@@ -384,6 +401,35 @@ func rpcRestart(w io.Writer, ids []int64) {
}
notifierAwaken()
executorAwaken()
+ return nil
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+var rpcCommands = map[string]struct {
+ // handler must not write anything when returning an error.
+ handler func(context.Context, io.Writer, *flag.FlagSet, []string) error
+ usage string
+ function string
+}{
+ "restart": {rpcRestart, "ID...",
+ "Schedule tasks with the given IDs to be rerun."},
+}
+
+func rpcPrintCommands(w io.Writer) {
+ // The alphabetic ordering is unfortunate, but tolerable.
+ keys := []string{}
+ for key := range rpcCommands {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+
+ fmt.Fprintf(w, "Commands:\n")
+ for _, key := range keys {
+ cmd := rpcCommands[key]
+ fmt.Fprintf(w, " %s [OPTION...] %s\n \t%s\n",
+ key, cmd.usage, cmd.function)
+ }
}
func handleRPC(w http.ResponseWriter, r *http.Request) {
@@ -410,21 +456,43 @@ func handleRPC(w http.ResponseWriter, r *http.Request) {
return
}
+ // Our handling closely follows what the flag package does internally.
+
command, args := args[0], args[1:]
- switch command {
- case "restart":
- ids := []int64{}
- for _, arg := range args {
- id, err := strconv.ParseInt(arg, 10, 64)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- ids = append(ids, id)
- }
- rpcRestart(w, ids)
- default:
- http.Error(w, "Unknown command: "+command, http.StatusBadRequest)
+ cmd, ok := rpcCommands[command]
+ if !ok {
+ http.Error(w, "unknown command: "+command, http.StatusBadRequest)
+ rpcPrintCommands(w)
+ return
+ }
+
+ // If we redirected the FlagSet straight to the response,
+ // we would be unable to set our own HTTP status.
+ b := bytes.NewBuffer(nil)
+
+ fs := flag.NewFlagSet(command, flag.ContinueOnError)
+ fs.SetOutput(b)
+ fs.Usage = func() {
+ fmt.Fprintf(fs.Output(),
+ "Usage: %s [OPTION...] %s\n%s\n",
+ fs.Name(), cmd.usage, cmd.function)
+ fs.PrintDefaults()
+ }
+
+ err = cmd.handler(r.Context(), w, fs, args)
+
+ // Wrap this error to make it as if fs.Parse discovered the issue.
+ if errors.Is(err, errWrongUsage) {
+ fmt.Fprintln(fs.Output(), err)
+ fs.Usage()
+ }
+
+ // The flag package first prints all errors that it returns.
+ // If the buffer ends up not being empty, flush it into the request.
+ if b.Len() != 0 {
+ http.Error(w, strings.TrimSpace(b.String()), http.StatusBadRequest)
+ } else if err != nil {
+ http.Error(w, err.Error(), http.StatusUnprocessableEntity)
}
}