Skip to content

Commit 1d82add

Browse files
committed
refactor: add missing test cases
* Validate that the trace middleware adds the expected traces * Validate that certificate hashes with both '/' and '+' characters are converted correctly
1 parent 1be65fb commit 1d82add

File tree

5 files changed

+146
-5
lines changed

5 files changed

+146
-5
lines changed

.talismanrc

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ fileignoreconfig:
1717
checksum: bcdef78dfc66e140acd32d12df1806b95c3541f51cc0208a43abff49552fdcd8
1818
- filename: gateway/registry/remote_test.go
1919
checksum: f5aa4dbb5e14d772613612eeb02df83ae4458875487e3c408eff2950b460c298
20+
- filename: gateway/server/tracing.go
21+
checksum: 10c205849723d591f5c90fbf0068fa5cf77b8e545e351bc538610b950bc18c3f
2022
- filename: gateway/server/ws.go
2123
checksum: 4cbde936242380603e07cf8bd049dbca9d1c3108843d10e58f588540176c6d23
2224
- filename: gateway/server/ws_test.go

gateway/registry/remote_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,48 @@ func TestLookupCertificate(t *testing.T) {
6666
assert.Equal(t, want.Raw, got.Raw)
6767
}
6868

69+
func TestLookupCertificateWithSlashesAndPlusesInHash(t *testing.T) {
70+
var want *x509.Certificate
71+
var certHash [32]byte
72+
var b64CertHash string
73+
74+
count := 0
75+
for {
76+
count++
77+
want = generateCertificate(t)
78+
certHash = sha256.Sum256(want.Raw)
79+
b64CertHash = base64.StdEncoding.EncodeToString(certHash[:])
80+
81+
if strings.Contains(b64CertHash, "/") && strings.Contains(b64CertHash, "+") {
82+
break
83+
}
84+
}
85+
t.Logf("Generated %d certificates before finding one with slashes and pluses in the hash", count)
86+
87+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88+
hashToMatch := strings.Replace(b64CertHash, "/", "_", -1)
89+
hashToMatch = strings.Replace(hashToMatch, "+", "-", -1)
90+
if r.URL.Path != fmt.Sprintf("/api/v0/certificate/%s", hashToMatch) {
91+
http.NotFound(w, r)
92+
return
93+
}
94+
block := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: want.Raw})
95+
blockWithNewlinesReplaced := strings.Replace(string(block), "\n", "\\n", -1)
96+
_, _ = w.Write([]byte(fmt.Sprintf(`{"certificate":"%s"}`, blockWithNewlinesReplaced)))
97+
}))
98+
defer server.Close()
99+
100+
reg := registry.RemoteRegistry{
101+
ManagerApiAddr: server.URL,
102+
}
103+
104+
got, err := reg.LookupCertificate(b64CertHash)
105+
require.NoError(t, err)
106+
require.NotNil(t, got)
107+
108+
assert.Equal(t, want.Raw, got.Raw)
109+
}
110+
69111
func generateCertificate(t *testing.T) *x509.Certificate {
70112
keyPair, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
71113
require.NoError(t, err)

gateway/server/middleware.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@ import (
2020
func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler {
2121
return func(h http.Handler) http.Handler {
2222
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23-
slog.Info("websocket connection received", "path", r.URL.Path, "method", r.Method)
24-
slog.Info("processing connection", "uri", r.RequestURI)
25-
26-
newCtx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer),
23+
newCtx, span := tracer.Start(r.Context(),
24+
fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer),
2725
trace.WithAttributes(
2826
semconv.HTTPScheme(getScheme(r)),
2927
semconv.HTTPMethod(r.Method),
@@ -35,11 +33,11 @@ func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler {
3533
routePattern := chi.RouteContext(r.Context()).RoutePattern()
3634
if routePattern != "" {
3735
span.SetName(fmt.Sprintf("%s %s", r.Method, routePattern))
36+
span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern()))
3837
} else {
3938
span.SetStatus(codes.Error, "not found")
4039
span.SetAttributes(semconv.HTTPStatusCode(http.StatusNotFound))
4140
}
42-
span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern()))
4341
})
4442
}
4543
}
@@ -77,6 +75,9 @@ func TLSOffload(registry registry.DeviceRegistry) func(http.Handler) http.Handle
7775
span.SetAttributes(attribute.String("cert.lookup.error", "NotFound"))
7876
slog.Warn("certificate not found", "clientCertHashHeader", clientCertHashHeader)
7977
}
78+
} else {
79+
clientCertErrHeader := r.Header.Get("X-Client-Cert-Error")
80+
span.SetAttributes(attribute.String("cert.valid.error", clientCertErrHeader))
8081
}
8182
}
8283
}

gateway/server/middleware_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,48 @@ import (
1212
"testing"
1313
)
1414

15+
func TestTraceMatchedRequest(t *testing.T) {
16+
tracer, traceExporter := server.GetTracer()
17+
18+
r := chi.NewRouter()
19+
r.Use(server.TraceRequest(tracer))
20+
r.HandleFunc("/id/{id}", func(w http.ResponseWriter, r *http.Request) {
21+
w.WriteHeader(http.StatusOK)
22+
})
23+
24+
req := httptest.NewRequest(http.MethodGet, "/id/1234", nil)
25+
rr := httptest.NewRecorder()
26+
r.ServeHTTP(rr, req)
27+
28+
server.AssertSpan(t, &traceExporter.GetSpans()[0], "GET /id/{id}", map[string]any{
29+
"http.scheme": "ws",
30+
"http.method": "GET",
31+
"http.url": "/id/1234",
32+
"http.route": "/id/{id}",
33+
})
34+
}
35+
36+
func TestTraceUnmatchedRequest(t *testing.T) {
37+
tracer, traceExporter := server.GetTracer()
38+
39+
r := chi.NewRouter()
40+
r.Use(server.TraceRequest(tracer))
41+
r.HandleFunc("/something", func(w http.ResponseWriter, r *http.Request) {
42+
w.WriteHeader(http.StatusOK)
43+
})
44+
45+
req := httptest.NewRequest(http.MethodGet, "/other", nil)
46+
rr := httptest.NewRecorder()
47+
r.ServeHTTP(rr, req)
48+
49+
server.AssertSpan(t, &traceExporter.GetSpans()[0], "GET /other", map[string]any{
50+
"http.scheme": "ws",
51+
"http.method": "GET",
52+
"http.url": "/other",
53+
"http.status_code": http.StatusNotFound,
54+
})
55+
}
56+
1557
func TestTLSOffloadWithNoClientCert(t *testing.T) {
1658
r := chi.NewRouter()
1759
r.Use(server.TLSOffload(registry.NewMockRegistry()))

gateway/server/tracing.go

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package server
4+
5+
import (
6+
"github.com/stretchr/testify/assert"
7+
"go.opentelemetry.io/otel"
8+
tracesdk "go.opentelemetry.io/otel/sdk/trace"
9+
"go.opentelemetry.io/otel/sdk/trace/tracetest"
10+
"go.opentelemetry.io/otel/trace"
11+
"golang.org/x/exp/maps"
12+
"sort"
13+
"testing"
14+
)
15+
16+
func GetTracer() (trace.Tracer, *tracetest.InMemoryExporter) {
17+
traceExporter := tracetest.NewInMemoryExporter()
18+
tracerProvider := tracesdk.NewTracerProvider(
19+
tracesdk.WithSampler(tracesdk.AlwaysSample()),
20+
tracesdk.WithSyncer(traceExporter),
21+
)
22+
otel.SetTracerProvider(tracerProvider)
23+
24+
return tracerProvider.Tracer("test"), traceExporter
25+
}
26+
27+
func AssertSpan(t *testing.T, span *tracetest.SpanStub, name string, attributes map[string]any) {
28+
assert.Equal(t, name, span.Name)
29+
assert.Len(t, span.Attributes, len(attributes))
30+
var gotKeys []string
31+
for _, attr := range span.Attributes {
32+
gotKeys = append(gotKeys, string(attr.Key))
33+
want, ok := attributes[string(attr.Key)]
34+
if !ok {
35+
t.Errorf("unexpected attribute %s", attr.Key)
36+
}
37+
switch want.(type) {
38+
case string:
39+
assert.Equal(t, want, attr.Value.AsString())
40+
case int:
41+
assert.Equal(t, want, int(attr.Value.AsInt64()))
42+
case bool:
43+
assert.Equal(t, want, attr.Value.AsBool())
44+
case float64:
45+
assert.Equal(t, want, attr.Value.AsFloat64())
46+
default:
47+
t.Errorf("unsupported attribute type %T", want)
48+
}
49+
}
50+
sort.Strings(gotKeys)
51+
wantKeys := maps.Keys(attributes)
52+
sort.Strings(wantKeys)
53+
assert.Equal(t, wantKeys, gotKeys)
54+
}

0 commit comments

Comments
 (0)