diff options
| -rw-r--r-- | acid.go | 100 | 
1 files changed, 84 insertions, 16 deletions
| @@ -22,6 +22,7 @@ import (  	"os/signal"  	"sort"  	"strconv" +	"strings"  	"sync"  	"syscall"  	ttemplate "text/template" @@ -361,7 +362,23 @@ func handlePush(w http.ResponseWriter, r *http.Request) {  const rpcHeaderSignature = "X-ACID-Signature" -func rpcRestart(w io.Writer, ids []int64) { +var errWrongUsage = errors.New("wrong usage") + +func rpcRestart(ctx context.Context, +	w io.Writer, fs *flag.FlagSet, args []string) error { +	if err := fs.Parse(args); err != nil { +		return err +	} + +	ids := []int64{} +	for _, arg := range fs.Args() { +		id, err := strconv.ParseInt(arg, 10, 64) +		if err != nil { +			return fmt.Errorf("%w: %s", errWrongUsage, err) +		} +		ids = append(ids, id) +	} +  	gRunningMutex.Lock()  	defer gRunningMutex.Unlock() @@ -373,7 +390,7 @@ func rpcRestart(w io.Writer, ids []int64) {  		// The executor bumps to "running" after inserting into gRunning,  		// so we should not need to exclude that state here. -		result, err := gDB.ExecContext(context.Background(), `UPDATE task +		result, err := gDB.ExecContext(ctx, `UPDATE task  			SET state = ?, detail = '', notified = 0 WHERE id = ?`,  			taskStateNew, id)  		if err != nil { @@ -384,6 +401,35 @@ func rpcRestart(w io.Writer, ids []int64) {  	}  	notifierAwaken()  	executorAwaken() +	return nil +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +var rpcCommands = map[string]struct { +	// handler must not write anything when returning an error. +	handler  func(context.Context, io.Writer, *flag.FlagSet, []string) error +	usage    string +	function string +}{ +	"restart": {rpcRestart, "ID...", +		"Schedule tasks with the given IDs to be rerun."}, +} + +func rpcPrintCommands(w io.Writer) { +	// The alphabetic ordering is unfortunate, but tolerable. +	keys := []string{} +	for key := range rpcCommands { +		keys = append(keys, key) +	} +	sort.Strings(keys) + +	fmt.Fprintf(w, "Commands:\n") +	for _, key := range keys { +		cmd := rpcCommands[key] +		fmt.Fprintf(w, "  %s [OPTION...] %s\n    \t%s\n", +			key, cmd.usage, cmd.function) +	}  }  func handleRPC(w http.ResponseWriter, r *http.Request) { @@ -410,21 +456,43 @@ func handleRPC(w http.ResponseWriter, r *http.Request) {  		return  	} +	// Our handling closely follows what the flag package does internally. +  	command, args := args[0], args[1:] -	switch command { -	case "restart": -		ids := []int64{} -		for _, arg := range args { -			id, err := strconv.ParseInt(arg, 10, 64) -			if err != nil { -				http.Error(w, err.Error(), http.StatusBadRequest) -				return -			} -			ids = append(ids, id) -		} -		rpcRestart(w, ids) -	default: -		http.Error(w, "Unknown command: "+command, http.StatusBadRequest) +	cmd, ok := rpcCommands[command] +	if !ok { +		http.Error(w, "unknown command: "+command, http.StatusBadRequest) +		rpcPrintCommands(w) +		return +	} + +	// If we redirected the FlagSet straight to the response, +	// we would be unable to set our own HTTP status. +	b := bytes.NewBuffer(nil) + +	fs := flag.NewFlagSet(command, flag.ContinueOnError) +	fs.SetOutput(b) +	fs.Usage = func() { +		fmt.Fprintf(fs.Output(), +			"Usage: %s [OPTION...] %s\n%s\n", +			fs.Name(), cmd.usage, cmd.function) +		fs.PrintDefaults() +	} + +	err = cmd.handler(r.Context(), w, fs, args) + +	// Wrap this error to make it as if fs.Parse discovered the issue. +	if errors.Is(err, errWrongUsage) { +		fmt.Fprintln(fs.Output(), err) +		fs.Usage() +	} + +	// The flag package first prints all errors that it returns. +	// If the buffer ends up not being empty, flush it into the request. +	if b.Len() != 0 { +		http.Error(w, strings.TrimSpace(b.String()), http.StatusBadRequest) +	} else if err != nil { +		http.Error(w, err.Error(), http.StatusUnprocessableEntity)  	}  } | 
