Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gather input stats for url pull sessions #192

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions pkg/media/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/livekit/ingress/pkg/media/urlpull"
"github.com/livekit/ingress/pkg/media/whip"
"github.com/livekit/ingress/pkg/params"
"github.com/livekit/ingress/pkg/stats"
"github.com/livekit/ingress/pkg/types"
"github.com/livekit/protocol/ingress"
"github.com/livekit/protocol/livekit"
Expand All @@ -43,6 +44,8 @@ type Source interface {
type Input struct {
lock sync.Mutex

statsReporter *stats.MediaStatsReporter

bin *gst.Bin
source Source

Expand All @@ -69,6 +72,16 @@ func NewInput(ctx context.Context, p *params.Params) (*Input, error) {
closeFuse: core.NewFuse(),
}

if p.InputType == livekit.IngressInput_URL_INPUT {
// Gather input stats from the pipeline

statsUpdater := &stats.LocalStatsUpdater{
Params: p,
}

i.statsReporter = stats.NewMediaStats(statsUpdater)
}

srcs := src.GetSources()
if len(srcs) == 0 {
return nil, errors.ErrSourceNotReady
Expand Down Expand Up @@ -178,6 +191,11 @@ func (i *Input) onPadAdded(_ *gst.Element, pad *gst.Pad) {
return
}
pad = ghostPad.Pad

if i.statsReporter != nil {
// Gather bitrate stats from pipeline itself
i.addBitrateProbe(kind)
}
} else {
var sink *gst.Element

Expand All @@ -197,3 +215,45 @@ func (i *Input) onPadAdded(_ *gst.Element, pad *gst.Pad) {
i.onOutputReady(pad, kind)
}
}

func (i *Input) addBitrateProbe(kind types.StreamKind) {
// Do a best effort to add probe to retrieve bitrate.
// The multiqueue is generally created in the pipeline before the decoders
mq, err := i.bin.GetElementByName("multiqueue0")

if err != nil {
// No multiqueue in that pipeline
logger.Debugw("could not retrieve multiqueue element from pipeline", "error", err)
return
}

pads, err := mq.GetSinkPads()
if err != nil {
logger.Errorw("failed retrieving multiqueue sink pads", err)
return
}

for _, pad := range pads {
caps := pad.GetCurrentCaps()
gstStruct := caps.GetStructureAt(0)
padKind := getKindFromGstMimeType(gstStruct)

if padKind == kind {
pad.AddProbe(gst.PadProbeTypeBuffer, func(pad *gst.Pad, info *gst.PadProbeInfo) gst.PadProbeReturn {
buffer := info.GetBuffer()
if buffer == nil {
return gst.PadProbeOK
}

size := buffer.GetSize()
i.statsReporter.MediaReceived(kind, size)

return gst.PadProbeOK
})

return
}
}

logger.Debugw("no pad on multiqueue with required kind found", "kind", kind)
}
16 changes: 2 additions & 14 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (s *Service) HandleWHIPPublishRequest(streamKey, resourceId string, ihs rpc
pRes.params.SetStatus(livekit.IngressState_ENDPOINT_PUBLISHING, "")
pRes.params.SendStateUpdate(ctx)

s.sm.IngressStarted(pRes.params.IngressInfo, &localSessionAPI{params: pRes.params})
s.sm.IngressStarted(pRes.params.IngressInfo, &localSessionAPI{stats.LocalStatsUpdater{Params: pRes.params}})
} else {
pRes.params.SetExtraParams(&params.WhipExtraParams{
MimeTypes: mimeTypes,
Expand Down Expand Up @@ -509,7 +509,7 @@ func (s *Service) HealthHandler(w http.ResponseWriter, r *http.Request) {
}

type localSessionAPI struct {
params *params.Params
stats.LocalStatsUpdater
}

func (a *localSessionAPI) GetProfileData(ctx context.Context, profileName string, timeout int, debug int) (b []byte, err error) {
Expand All @@ -521,18 +521,6 @@ func (a *localSessionAPI) GetPipelineDot(ctx context.Context) (string, error) {
return "", errors.ErrIngressNotFound
}

func (a *localSessionAPI) UpdateMediaStats(ctx context.Context, s *types.MediaStats) error {
if s.AudioAverageBitrate != nil && s.AudioCurrentBitrate != nil {
a.params.SetInputAudioBitrate(*s.AudioAverageBitrate, *s.AudioCurrentBitrate)
}

if s.VideoAverageBitrate != nil && s.VideoCurrentBitrate != nil {
a.params.SetInputVideoBitrate(*s.VideoAverageBitrate, *s.VideoCurrentBitrate)
}

return nil
}

func RegisterIngressRpcHandlers(server rpc.IngressHandlerServer, info *livekit.IngressInfo) error {
if err := server.RegisterUpdateIngressTopic(info.IngressId); err != nil {
return err
Expand Down
29 changes: 23 additions & 6 deletions pkg/stats/media.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@ import (

"github.com/frostbyte73/core"

"github.com/livekit/ingress/pkg/params"
"github.com/livekit/ingress/pkg/types"
)

type MediaStatsReporter struct {
sessionAPI types.SessionAPI
statsUpdater types.MediaStatsUpdater

lock sync.Mutex
done core.Fuse
stats map[types.StreamKind]*trackStats
}

type LocalStatsUpdater struct {
Params *params.Params
}

type trackStats struct {
totalBytes int64
startTime time.Time
Expand All @@ -26,11 +31,11 @@ type trackStats struct {
lastQueryTime time.Time
}

func NewMediaStats(sessionAPI types.SessionAPI) *MediaStatsReporter {
func NewMediaStats(statsUpdater types.MediaStatsUpdater) *MediaStatsReporter {
m := &MediaStatsReporter{
sessionAPI: sessionAPI,
stats: make(map[types.StreamKind]*trackStats),
done: core.NewFuse(),
statsUpdater: statsUpdater,
stats: make(map[types.StreamKind]*trackStats),
done: core.NewFuse(),
}

go func() {
Expand Down Expand Up @@ -97,7 +102,7 @@ func (m *MediaStatsReporter) updateIngressState(ctx context.Context) {
ms.VideoCurrentBitrate = &videoCurrentBps
}

m.sessionAPI.UpdateMediaStats(ctx, ms)
m.statsUpdater.UpdateMediaStats(ctx, ms)
}

func (s *trackStats) mediaReceived(size int64) {
Expand All @@ -123,3 +128,15 @@ func (s *trackStats) getStats() (uint32, uint32) {

return averageBps, currentBps
}

func (a *LocalStatsUpdater) UpdateMediaStats(ctx context.Context, s *types.MediaStats) error {
if s.AudioAverageBitrate != nil && s.AudioCurrentBitrate != nil {
a.Params.SetInputAudioBitrate(*s.AudioAverageBitrate, *s.AudioCurrentBitrate)
}

if s.VideoAverageBitrate != nil && s.VideoCurrentBitrate != nil {
a.Params.SetInputVideoBitrate(*s.VideoAverageBitrate, *s.VideoCurrentBitrate)
}

return nil
}
11 changes: 9 additions & 2 deletions pkg/types/session_api.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package types

import "context"
import (
"context"
)

type MediaStatsUpdater interface {
UpdateMediaStats(ctx context.Context, stats *MediaStats) error
}

type SessionAPI interface {
MediaStatsUpdater

GetProfileData(ctx context.Context, profileName string, timeout int, debug int) (b []byte, err error)
GetPipelineDot(ctx context.Context) (string, error)
UpdateMediaStats(ctx context.Context, stats *MediaStats) error
}

type MediaStats struct {
Expand Down