mirror of
https://github.com/prymitive/karma
synced 2026-05-07 03:26:52 +00:00
356 lines
9.1 KiB
Go
356 lines
9.1 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/beme/abide"
|
|
"github.com/rogpeppe/go-internal/testscript"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/spf13/pflag"
|
|
)
|
|
|
|
func mainShoulFail() int {
|
|
initLogger()
|
|
err := serve(pflag.ContinueOnError)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Execution failed")
|
|
return 0
|
|
}
|
|
log.Error().Msg("No error logged")
|
|
return 100
|
|
}
|
|
|
|
func mainShouldWork() int {
|
|
initLogger()
|
|
err := serve(pflag.ContinueOnError)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Execution failed")
|
|
return 100
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func TestMain(m *testing.M) {
|
|
ecode := testscript.RunMain(m, map[string]func() int{
|
|
"karma.bin-should-fail": mainShoulFail,
|
|
"karma.bin-should-work": mainShouldWork,
|
|
})
|
|
err := abide.Cleanup()
|
|
if err != nil {
|
|
fmt.Printf("abide.Cleanup() error: %v\n", err)
|
|
ecode = 1
|
|
}
|
|
os.Exit(ecode)
|
|
}
|
|
|
|
func TestScripts(t *testing.T) {
|
|
testscript.Run(t, testscript.Params{
|
|
Dir: "tests/testscript",
|
|
UpdateScripts: os.Getenv("UPDATE_SNAPSHOTS") == "1",
|
|
Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
|
|
"http": httpServer,
|
|
"cert": tlsCert,
|
|
},
|
|
Setup: func(env *testscript.Env) error {
|
|
env.Values["mocks"] = &httpMocks{responses: map[string][]httpMock{}}
|
|
return nil
|
|
},
|
|
})
|
|
}
|
|
|
|
func httpServer(ts *testscript.TestScript, _ bool, args []string) {
|
|
mocks := ts.Value("mocks").(*httpMocks)
|
|
|
|
if len(args) == 0 {
|
|
ts.Fatalf("! http command requires arguments")
|
|
}
|
|
cmd := args[0]
|
|
|
|
ts.Logf("test-script http command: %s", strings.Join(args, " "))
|
|
|
|
switch cmd {
|
|
// http response name /200 200 OK
|
|
case "response":
|
|
if len(args) < 5 {
|
|
ts.Fatalf("! http response command requires '$NAME $PATH $CODE $BODY' args, got [%s]", strings.Join(args, " "))
|
|
}
|
|
name := args[1]
|
|
path := args[2]
|
|
code, err := strconv.Atoi(args[3])
|
|
ts.Check(err)
|
|
body := strings.Join(args[4:], " ")
|
|
mocks.add(name, httpMock{pattern: path, handler: func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(code)
|
|
_, err := w.Write([]byte(body))
|
|
ts.Check(err)
|
|
}})
|
|
// http response name /200 200 OK
|
|
case "slow-response":
|
|
if len(args) < 6 {
|
|
ts.Fatalf("! http response command requires '$NAME $PATH $DELAY $CODE $BODY' args, got [%s]", strings.Join(args, " "))
|
|
}
|
|
name := args[1]
|
|
path := args[2]
|
|
delay, err := time.ParseDuration(args[3])
|
|
ts.Check(err)
|
|
code, err := strconv.Atoi(args[4])
|
|
ts.Check(err)
|
|
body := strings.Join(args[5:], " ")
|
|
mocks.add(name, httpMock{pattern: path, handler: func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(delay)
|
|
w.WriteHeader(code)
|
|
_, err := w.Write([]byte(body))
|
|
ts.Check(err)
|
|
}})
|
|
// http redirect name /foo/src /dst
|
|
case "redirect":
|
|
if len(args) != 4 {
|
|
ts.Fatalf("! http redirect command requires '$NAME $SRCPATH $DSTPATH' args, got [%s]", strings.Join(args, " "))
|
|
}
|
|
name := args[1]
|
|
srcpath := args[2]
|
|
dstpath := args[3]
|
|
mocks.add(name, httpMock{pattern: srcpath, handler: func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Location", dstpath)
|
|
w.WriteHeader(http.StatusFound)
|
|
}})
|
|
// http start name 127.0.0.1:7088 [TLS cert] [TLS key]
|
|
case "start":
|
|
if len(args) < 3 {
|
|
ts.Fatalf("! http start command requires '$NAME $LISTEN [$TLS_CERT $TLS_KEY]' args, got [%s]", strings.Join(args, " "))
|
|
}
|
|
name := args[1]
|
|
listen := args[2]
|
|
var isTLS bool
|
|
var tlsCert, tlsKey string
|
|
if len(args) == 5 {
|
|
isTLS = true
|
|
tlsCert = args[3]
|
|
tlsKey = args[4]
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/__ping__", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("pong"))
|
|
})
|
|
for n, mockList := range mocks.responses {
|
|
if n == name {
|
|
for _, mock := range mockList {
|
|
mock := mock
|
|
mux.HandleFunc(mock.pattern, mock.handler)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", listen)
|
|
ts.Check(err)
|
|
server := &http.Server{Addr: listen, Handler: mux}
|
|
sChan := make(chan struct{}, 1)
|
|
sCtx, sCancel := context.WithTimeout(context.Background(), time.Minute*5)
|
|
|
|
ts.Defer(func() {
|
|
ts.Logf("running deferred http server shutdown")
|
|
defer sCancel()
|
|
sChan <- struct{}{}
|
|
ts.Logf("http server %s shutting down", name)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_ = server.Shutdown(ctx)
|
|
})
|
|
|
|
go func() {
|
|
select {
|
|
case <-sCtx.Done():
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_ = server.Shutdown(ctx)
|
|
ts.Fatalf("http server %s running for too long, forcing a shutdown: %s", name, sCtx.Err())
|
|
case <-sChan:
|
|
return
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
if isTLS {
|
|
err = server.ServeTLS(listener, tlsCert, tlsKey)
|
|
} else {
|
|
err = server.Serve(listener)
|
|
}
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
ts.Fatalf("http server failed to start: %s", err)
|
|
}
|
|
}()
|
|
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
client := &http.Client{Transport: tr}
|
|
|
|
try := 0
|
|
for {
|
|
try++
|
|
if try > 30 {
|
|
ts.Fatalf("HTTP server didn't pass healt checks after %d check(s)", try)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
|
|
uri := fmt.Sprintf("http://%s/__ping__", listen)
|
|
if isTLS {
|
|
uri = fmt.Sprintf("https://%s/__ping__", listen)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
|
|
ts.Check(err)
|
|
|
|
resp, err := client.Do(req)
|
|
if err == nil && resp.StatusCode == http.StatusOK {
|
|
ts.Logf("http server ready after %d check(s)", try)
|
|
cancel()
|
|
return
|
|
}
|
|
|
|
cancel()
|
|
time.Sleep(time.Second)
|
|
}
|
|
default:
|
|
ts.Fatalf("! unknown http command: %v", args)
|
|
}
|
|
}
|
|
|
|
// httpMock is a single mocked endpoint: a ServeMux pattern together with
// the handler that serves it.
type httpMock struct {
	pattern string
	handler func(http.ResponseWriter, *http.Request)
}

// httpMocks groups registered mocks by server name. The zero value is ready
// to use.
type httpMocks struct {
	responses map[string][]httpMock
}

// add appends mock to the list registered under name, preserving insertion
// order. The map is created on first use so that calling add on a zero-value
// httpMocks does not panic with a nil-map write.
func (m *httpMocks) add(name string, mock httpMock) {
	if m.responses == nil {
		m.responses = make(map[string][]httpMock)
	}
	// append handles a missing key (nil slice), so no pre-seeding is needed.
	m.responses[name] = append(m.responses[name], mock)
}
|
|
|
|
func tlsCert(ts *testscript.TestScript, _ bool, args []string) {
|
|
if len(args) < 2 {
|
|
ts.Fatalf("! cert command requires '$DIRNAME $NAME' args, got [%s]", strings.Join(args, " "))
|
|
}
|
|
dirname := args[0]
|
|
name := args[1]
|
|
|
|
ts.Logf("test-script cert command: %s", strings.Join(args, " "))
|
|
|
|
ca := &x509.Certificate{
|
|
SerialNumber: big.NewInt(2019),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Company, INC."},
|
|
Country: []string{"US"},
|
|
Province: []string{""},
|
|
Locality: []string{"San Francisco"},
|
|
StreetAddress: []string{""},
|
|
PostalCode: []string{""},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().AddDate(10, 0, 0),
|
|
IsCA: true,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
if err != nil {
|
|
ts.Fatalf("failed to generate CA private key: %s", err)
|
|
}
|
|
|
|
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
|
|
if err != nil {
|
|
ts.Fatalf("failed to generate CA cert: %s", err)
|
|
}
|
|
|
|
writeCert(ts, dirname, name+"-ca.pem", &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: caBytes,
|
|
})
|
|
writeCert(ts, dirname, name+"-ca.key", &pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
|
|
})
|
|
|
|
cert := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1658),
|
|
Subject: pkix.Name{
|
|
Organization: []string{""},
|
|
Country: []string{""},
|
|
Province: []string{""},
|
|
Locality: []string{""},
|
|
StreetAddress: []string{""},
|
|
PostalCode: []string{""},
|
|
},
|
|
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().AddDate(0, 0, 1),
|
|
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
}
|
|
|
|
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
if err != nil {
|
|
ts.Fatalf("failed to generate cert private key: %s", err)
|
|
}
|
|
|
|
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
|
|
if err != nil {
|
|
ts.Fatalf("failed to generate cert: %s", err)
|
|
}
|
|
|
|
writeCert(ts, dirname, name+".pem", &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certBytes,
|
|
})
|
|
writeCert(ts, dirname, name+".key", &pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
|
})
|
|
}
|
|
|
|
func writeCert(ts *testscript.TestScript, dirname, filename string, block *pem.Block) {
|
|
fullpath := path.Join(dirname, filename)
|
|
|
|
f, err := os.Create(fullpath)
|
|
if err != nil {
|
|
ts.Fatalf("failed to write %s: %s", fullpath, err)
|
|
}
|
|
|
|
if err = pem.Encode(f, block); err != nil {
|
|
ts.Fatalf("failed to encode %s: %s", fullpath, err)
|
|
}
|
|
|
|
if err = f.Close(); err != nil {
|
|
ts.Fatalf("failed to close %s: %s", fullpath, err)
|
|
}
|
|
}
|