@@ -7,10 +7,9 @@ use anyhow::{Context, Result};
7
7
use axum:: body:: { self , Body , Bytes } ;
8
8
use axum:: extract:: ws:: WebSocketUpgrade ;
9
9
use axum:: http:: header:: { HeaderName , CONTENT_LENGTH , CONTENT_TYPE , HOST } ;
10
- use axum:: http:: response:: Parts ;
11
10
use axum:: http:: { HeaderValue , Request , StatusCode } ;
12
11
use axum:: middleware:: Next ;
13
- use axum:: response:: Response ;
12
+ use axum:: response:: { IntoResponse , Response } ;
14
13
use axum:: routing:: { get, get_service, Router } ;
15
14
use axum_server:: tls_rustls:: RustlsConfig ;
16
15
use axum_server:: Handle ;
@@ -362,36 +361,61 @@ fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Result<Router> {
362
361
async fn html_address_middleware < B : std:: fmt:: Debug > (
363
362
request : Request < B > ,
364
363
next : Next < B > ,
365
- ) -> ( Parts , Bytes ) {
364
+ ) -> Response {
366
365
let uri = request. headers ( ) . get ( HOST ) . cloned ( ) ;
367
366
let response = next. run ( request) . await ;
367
+
368
+ // if it's not a success, we don't modify it
369
+ if !response. status ( ) . is_success ( ) {
370
+ return response;
371
+ }
372
+
373
+ // if it doesn't look like HTML, we ignore it too
374
+ let is_html = response
375
+ . headers ( )
376
+ . get ( CONTENT_TYPE )
377
+ . map ( |t| t == "text/html" )
378
+ . unwrap_or_default ( ) ;
379
+ if !is_html {
380
+ return response;
381
+ }
382
+
383
+ // split into parts and body
368
384
let ( parts, body) = response. into_parts ( ) ;
369
385
386
+ // turn the body into bytes
370
387
match hyper:: body:: to_bytes ( body) . await {
371
- Err ( _) => ( parts, Bytes :: default ( ) ) ,
388
+ Err ( err) => {
389
+ tracing:: debug!( "Unable to intercept: {err}" ) ;
390
+ ( parts, Bytes :: default ( ) ) . into_response ( )
391
+ }
372
392
Ok ( bytes) => {
373
- let ( mut parts, mut bytes) = ( parts, bytes) ;
374
-
375
- // turn into a string literal, or replace with "current host" on the client side
376
- let uri = uri
377
- . and_then ( |uri| uri. to_str ( ) . map ( |s| format ! ( "'{}'" , s) ) . ok ( ) )
378
- . unwrap_or_else ( || "window.location.host" . into ( ) ) ;
379
-
380
- if parts
381
- . headers
382
- . get ( CONTENT_TYPE )
383
- . map ( |t| t == "text/html" )
384
- . unwrap_or ( false )
385
- {
386
- if let Ok ( data_str) = std:: str:: from_utf8 ( & bytes) {
387
- let data_str = data_str. replace ( "'{{__TRUNK_ADDRESS__}}'" , & uri) ;
393
+ let mut parts = parts;
394
+ let mut bytes = bytes;
395
+
396
+ match std:: str:: from_utf8 ( & bytes) {
397
+ Ok ( data_str) => {
398
+ tracing:: debug!( "Replacing variable" ) ;
399
+
400
+ // turn into a string literal, or replace with "current host" on the client side
401
+ let uri = uri
402
+ . and_then ( |uri| uri. to_str ( ) . map ( |s| format ! ( "'{}'" , s) ) . ok ( ) )
403
+ . unwrap_or_else ( || "window.location.host" . into ( ) ) ;
404
+
405
+ let data_str = data_str
406
+ . replace ( "'{{__TRUNK_ADDRESS__}}'" , & uri)
407
+ // minification will turn that into backticks
408
+ . replace ( "`{{__TRUNK_ADDRESS__}}`" , & uri) ;
388
409
let bytes_vec = data_str. as_bytes ( ) . to_vec ( ) ;
389
410
parts. headers . insert ( CONTENT_LENGTH , bytes_vec. len ( ) . into ( ) ) ;
390
411
bytes = Bytes :: from ( bytes_vec) ;
391
412
}
413
+ Err ( err) => {
414
+ tracing:: debug!( "Unable to parse for injecting: {err}" ) ;
415
+ }
392
416
}
393
417
394
- ( parts, bytes)
418
+ ( parts, bytes) . into_response ( )
395
419
}
396
420
}
397
421
}
0 commit comments