Files

202 lines
4.8 KiB
Go

package clientcmd
import (
"context"
"flag"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"inbox/internal/client"
)
// DefaultAPIURL is the base URL RunAPI falls back to when neither the --addr
// flag nor the INBOX_API_URL environment variable supplies one.
const DefaultAPIURL = "http://127.0.0.1:3000"
// APIClient is the request abstraction RunAPI depends on. RunAPI makes one Do
// call per invocation and reads StatusCode and Body from the returned
// client.Response.
type APIClient interface {
// Do performs a single request. RunAPI passes the upper-cased METHOD argument,
// the PATH argument, the headers parsed from --header, and the body loaded
// from --data/--file (nil when neither is given).
Do(ctx context.Context, method, path string, headers http.Header, body []byte) (client.Response, error)
}
// APIDeps carries the injectable dependencies of RunAPI. Every field may be
// left at its zero value, in which case RunAPI and its helpers fall back to
// the process-level defaults described on each field.
type APIDeps struct {
// Stdout receives the response body of a successful request; nil means os.Stdout.
Stdout io.Writer
// Stdin is read when the request body is requested via "--file -"; nil means os.Stdin.
Stdin io.Reader
// NewClient builds the APIClient for the given base URL and HTTP client;
// nil means a client created with client.New.
NewClient func(baseURL string, httpClient *http.Client) APIClient
}
// headersFlag is a flag.Value that accumulates every occurrence of a
// repeatable flag, keeping the values in the order they were supplied.
type headersFlag struct {
	values []string
}

// String renders all collected values as a single comma-separated string.
func (h *headersFlag) String() string {
	joined := strings.Join(h.values, ",")
	return joined
}

// Set appends one more value after stripping surrounding whitespace. A value
// that is empty once trimmed is rejected with an error and not recorded.
func (h *headersFlag) Set(value string) error {
	trimmed := strings.TrimSpace(value)
	if len(trimmed) == 0 {
		return fmt.Errorf("header cannot be empty")
	}
	h.values = append(h.values, trimmed)
	return nil
}
// RunAPI implements the "inbox api [flags] METHOD PATH" command. It parses
// args, sends a single request through an APIClient, and writes the response
// body to deps.Stdout (os.Stdout when nil), adding a trailing newline if the
// body does not already end with one.
//
// Flags may appear before, between, or after the positional arguments. The
// first two positionals are METHOD (upper-cased) and PATH; any further
// positionals are ignored. The base URL comes from --addr, then the
// INBOX_API_URL environment variable, then DefaultAPIURL.
//
// --timeout bounds the whole request. Zero disables the timeout, matching the
// meaning of a zero http.Client.Timeout; a negative value is rejected.
//
// An error is returned for flag or usage mistakes, body or header problems,
// transport failures, write failures, and any response whose status code is
// outside the 2xx range (the error text includes the trimmed response body).
func RunAPI(args []string, deps APIDeps) error {
	flagSet := flag.NewFlagSet("api", flag.ContinueOnError)
	// Parse errors are returned to the caller, so the flag package's own
	// usage and error output is discarded.
	flagSet.SetOutput(io.Discard)
	addr := flagSet.String("addr", strings.TrimSpace(os.Getenv("INBOX_API_URL")), "Inbox API base URL")
	data := flagSet.String("data", "", "Inline request body")
	file := flagSet.String("file", "", "Request body file path, or - for stdin")
	timeout := flagSet.Duration("timeout", 30*time.Second, "HTTP request timeout")
	var headerValues headersFlag
	flagSet.Var(&headerValues, "header", "HTTP header in 'Key: Value' format; repeatable")
	if err := flagSet.Parse(normalizeAPIArgs(args)); err != nil {
		return err
	}

	rest := flagSet.Args()
	if len(rest) < 2 {
		return fmt.Errorf("usage: inbox api [flags] METHOD PATH")
	}
	method := strings.ToUpper(strings.TrimSpace(rest[0]))
	path := strings.TrimSpace(rest[1])
	if method == "" || path == "" {
		return fmt.Errorf("usage: inbox api [flags] METHOD PATH")
	}
	if *timeout < 0 {
		return fmt.Errorf("invalid --timeout %s: must not be negative", *timeout)
	}

	// Trim the flag value the same way the environment default is trimmed, so
	// a whitespace-only --addr falls back to the default instead of producing
	// an unusable URL.
	baseURL := strings.TrimSpace(*addr)
	if baseURL == "" {
		baseURL = DefaultAPIURL
	}

	body, err := loadBody(*data, *file, deps.Stdin)
	if err != nil {
		return err
	}
	headers, err := buildHeaders(headerValues.values, body)
	if err != nil {
		return err
	}

	newClient := deps.NewClient
	if newClient == nil {
		newClient = func(baseURL string, httpClient *http.Client) APIClient {
			return client.New(baseURL, httpClient)
		}
	}

	// Only attach a deadline for a positive timeout: context.WithTimeout with
	// a zero duration yields an already-expired context, which would make
	// "--timeout 0" fail every request instead of disabling the timeout.
	ctx := context.Background()
	if *timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, *timeout)
		defer cancel()
	}

	resp, err := newClient(baseURL, &http.Client{Timeout: *timeout}).Do(ctx, method, path, headers, body)
	if err != nil {
		return err
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		payload := strings.TrimSpace(string(resp.Body))
		if payload == "" {
			return fmt.Errorf("api %s %s returned status %d", method, path, resp.StatusCode)
		}
		return fmt.Errorf("api %s %s returned status %d: %s", method, path, resp.StatusCode, payload)
	}

	if len(resp.Body) == 0 {
		return nil
	}
	stdout := deps.Stdout
	if stdout == nil {
		stdout = os.Stdout
	}
	if _, err := stdout.Write(resp.Body); err != nil {
		return fmt.Errorf("write response: %w", err)
	}
	// The body is known to be non-empty here, so indexing the last byte is safe.
	if resp.Body[len(resp.Body)-1] != '\n' {
		if _, err := io.WriteString(stdout, "\n"); err != nil {
			return fmt.Errorf("write trailing newline: %w", err)
		}
	}
	return nil
}
// loadBody resolves the request body from the --data and --file flag values.
//
// Supplying both is an error. A non-empty data is returned as-is. An empty
// file yields a nil body. A file of "-" reads everything from stdin (os.Stdin
// when stdin is nil); any other value is treated as a path and read from disk.
func loadBody(data, file string, stdin io.Reader) ([]byte, error) {
	haveData, haveFile := data != "", file != ""
	if haveData && haveFile {
		return nil, fmt.Errorf("--data and --file cannot be used together")
	}
	if haveData {
		return []byte(data), nil
	}
	if !haveFile {
		return nil, nil
	}

	if file != "-" {
		content, err := os.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("read request file: %w", err)
		}
		return content, nil
	}

	src := stdin
	if src == nil {
		src = os.Stdin
	}
	content, err := io.ReadAll(src)
	if err != nil {
		return nil, fmt.Errorf("read stdin: %w", err)
	}
	return content, nil
}
// buildHeaders converts "Key: Value" strings into an http.Header.
//
// Each item is split at its first colon and both halves are trimmed; an item
// with no colon or an empty key is rejected. Repeated keys accumulate values.
// When body is non-empty and no Content-Type value was supplied, Content-Type
// is set to application/json.
func buildHeaders(items []string, body []byte) (http.Header, error) {
	result := http.Header{}
	for _, raw := range items {
		name, val, found := strings.Cut(raw, ":")
		name = strings.TrimSpace(name)
		if !found || name == "" {
			return nil, fmt.Errorf("invalid header %q", raw)
		}
		result.Add(name, strings.TrimSpace(val))
	}

	if len(body) == 0 || result.Get("Content-Type") != "" {
		return result, nil
	}
	result.Set("Content-Type", "application/json")
	return result, nil
}
// normalizeAPIArgs reorders args so every flag (with its value, if any)
// precedes the positional arguments. The flag package stops parsing at the
// first non-flag argument, so this lets callers write flags before, between,
// or after METHOD and PATH.
//
// An argument counts as a flag when it starts with "-", with two exceptions
// that follow the flag package's own syntax rules:
//   - a bare "-" is a non-flag argument and stays positional;
//   - "--" ends flag processing, so everything after it stays positional and
//     the terminator is re-emitted between the flags and the positionals.
//
// A flag written as "name=value" carries its own value. Otherwise the next
// argument is pulled along as the value when takesFlagValue reports that the
// flag needs one. Relative order within flags and within positionals is kept.
// An empty args yields nil.
func normalizeAPIArgs(args []string) []string {
	if len(args) == 0 {
		return nil
	}

	// One extra slot leaves room for the re-emitted "--" terminator.
	flags := make([]string, 0, len(args)+1)
	positionals := make([]string, 0, len(args))
	terminated := false

	for i := 0; i < len(args); i++ {
		arg := args[i]
		if arg == "--" {
			terminated = true
			positionals = append(positionals, args[i+1:]...)
			break
		}
		if arg == "-" || !strings.HasPrefix(arg, "-") {
			positionals = append(positionals, arg)
			continue
		}

		flags = append(flags, arg)
		if strings.Contains(arg, "=") {
			continue
		}
		if takesFlagValue(arg) && i+1 < len(args) {
			flags = append(flags, args[i+1])
			i++
		}
	}

	if terminated {
		// Keep the terminator so dash-prefixed positionals are not parsed as flags.
		flags = append(flags, "--")
	}
	return append(flags, positionals...)
}
// takesFlagValue reports whether flagName names one of the api command's
// flags that consumes the following argument as its value (addr, data, file,
// timeout, header).
//
// Both the "-name" and "--name" spellings are recognised, because the flag
// package accepts either. Anything else returns false: a name without a
// leading dash, three or more dashes, unknown flags, and the "name=value" form.
func takesFlagValue(flagName string) bool {
	if !strings.HasPrefix(flagName, "-") {
		return false
	}
	// Strip at most two leading dashes; a third dash is invalid flag syntax
	// and is left in place so it fails to match below.
	name := strings.TrimPrefix(flagName, "-")
	name = strings.TrimPrefix(name, "-")
	switch name {
	case "addr", "data", "file", "timeout", "header":
		return true
	default:
		return false
	}
}