Skip to content

Commit

Permalink
feat: show model param sizes (#179)
Browse files Browse the repository at this point in the history
* feat: param sizes

* feat: param sizes

* feat: param sizes

* minor updates
  • Loading branch information
sammcj authored Mar 8, 2025
1 parent 5755965 commit 5b28791
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ echo "alias g=gollama" >> ~/.zshrc
- `m`: Sort by modified
- `k`: Sort by quantisation
- `f`: Sort by family
- `B`: Sort by parameter size
- `l`: Link model to LM Studio
- `L`: Link all models to LM Studio
- `r`: Rename model _**(Work in progress)**_
Expand Down
53 changes: 50 additions & 3 deletions app_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"sort"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -240,6 +241,8 @@ func (m *AppModel) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m.handleSortByQuantKey()
case key.Matches(msg, m.keys.SortByFamily):
return m.handleSortByFamilyKey()
case key.Matches(msg, m.keys.SortByParamSize):
return m.handleSortByParamSizeKey()
case key.Matches(msg, m.keys.RunModel):
return m.handleRunModelKey()
case key.Matches(msg, m.keys.AltScreen):
Expand Down Expand Up @@ -512,6 +515,41 @@ func (m *AppModel) handleSortByFamilyKey() (tea.Model, tea.Cmd) {
return m, nil
}

func (m *AppModel) handleSortByParamSizeKey() (tea.Model, tea.Cmd) {
logging.DebugLogger.Println("SortByParamSize key matched")
m.cfg.SortOrder = "paramsize"

// Helper function to extract numeric value from parameter size strings
getParamSizeValue := func(paramSize string) float64 {
if paramSize == "" {
return 0
}

// Remove the "B" suffix if present
numStr := paramSize
if len(paramSize) > 0 && paramSize[len(paramSize)-1] == 'B' {
numStr = paramSize[:len(paramSize)-1]
}

// Parse the numeric part
size, err := strconv.ParseFloat(numStr, 64)
if err != nil {
return 0
}
return size
}

// Sort models by parameter size (largest first)
sort.Slice(m.models, func(i, j int) bool {
sizeI := getParamSizeValue(m.models[i].ParameterSize)
sizeJ := getParamSizeValue(m.models[j].ParameterSize)
return sizeI > sizeJ
})

m.refreshList()
return m, nil
}

func (m *AppModel) handleRunModelKey() (tea.Model, tea.Cmd) {
logging.DebugLogger.Println("RunModel key matched")
if item, ok := m.list.SelectedItem().(Model); ok {
Expand Down Expand Up @@ -861,10 +899,19 @@ func (m *AppModel) inspectModelView(model Model) string {
{"Name", model.Name},
{"ID", model.ID},
{"Size (GB)", fmt.Sprintf("%.2f", model.Size)},
{"quantisation Level", model.QuantizationLevel},
}

// Add parameter size if available
if model.ParameterSize != "" {
rows = append(rows, table.Row{"Parameter Size", model.ParameterSize})
}

// Add remaining fields
rows = append(rows, []table.Row{
{"Quantisation Level", model.QuantizationLevel},
{"Modified", model.Modified.Format("2006-01-02")},
{"Family", model.Family},
}
}...)

// getModelParams returns a map of model parameters, so we need to iterate over the map and add the parameters to the rows
for key, value := range modelParams {
Expand Down Expand Up @@ -980,7 +1027,7 @@ func (m *AppModel) topView() string {
func (k KeyMap) FullHelp() [][]key.Binding {
return [][]key.Binding{
{k.Space, k.Delete, k.RunModel, k.LinkModel, k.LinkAllModels, k.CopyModel, k.PushModel}, // first column
{k.SortByName, k.SortBySize, k.SortByModified, k.SortByQuant, k.SortByFamily}, // second column
{k.SortByName, k.SortBySize, k.SortByModified, k.SortByQuant, k.SortByFamily, k.SortByParamSize}, // second column
{k.Top, k.EditModel, k.InspectModel, k.Quit}, // third column
}
}
Expand Down
50 changes: 40 additions & 10 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"strings"

"github.com/charmbracelet/lipgloss"
"github.com/sammcj/gollama/config"
"github.com/sammcj/gollama/logging"
"github.com/sammcj/gollama/styles"
Expand All @@ -26,6 +27,7 @@ func parseAPIResponse(resp *api.ListResponse) []Model {
QuantizationLevel: modelResp.Details.QuantizationLevel,
Family: modelResp.Details.Family,
Modified: modelResp.ModifiedAt,
ParameterSize: modelResp.Details.ParameterSize,
}
}
logging.DebugLogger.Println("Models:", models)
Expand All @@ -36,14 +38,18 @@ func normalizeSize(size float64) float64 {
return size // Sizes are already in GB in the API response
}

func calculateColumnWidths(totalWidth int) (nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth int) {
// Constant for parameter size column width
const minParamSizeWidth = 10

func calculateColumnWidths(totalWidth int) (nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth, paramSizeWidth int) {
// Calculate column widths
nameWidth = int(0.45 * float64(totalWidth))
nameWidth = int(0.40 * float64(totalWidth))
sizeWidth = int(0.05 * float64(totalWidth))
quantWidth = int(0.05 * float64(totalWidth))
familyWidth = int(0.05 * float64(totalWidth))
modifiedWidth = int(0.05 * float64(totalWidth))
idWidth = int(0.02 * float64(totalWidth))
paramSizeWidth = int(0.05 * float64(totalWidth))

// Set the absolute minimum width for each column
if nameWidth < minNameWidth {
Expand All @@ -64,10 +70,13 @@ func calculateColumnWidths(totalWidth int) (nameWidth, sizeWidth, quantWidth, mo
if familyWidth < minFamilyWidth {
familyWidth = minFamilyWidth
}
if paramSizeWidth < minParamSizeWidth {
paramSizeWidth = minParamSizeWidth
}

// If the total width is less than the sum of the minimum column widths, adjust the name column width and make sure all columns are aligned
if totalWidth < nameWidth+sizeWidth+quantWidth+familyWidth+modifiedWidth+idWidth {
nameWidth = totalWidth - sizeWidth - quantWidth - familyWidth - modifiedWidth - idWidth
if totalWidth < nameWidth+sizeWidth+quantWidth+familyWidth+modifiedWidth+idWidth+paramSizeWidth {
nameWidth = totalWidth - sizeWidth - quantWidth - familyWidth - modifiedWidth - idWidth - paramSizeWidth
}

return
Expand Down Expand Up @@ -109,7 +118,7 @@ func wrapText(text string, width int) string {
return wrapped
}

func calculateColumnWidthsTerminal() (nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth int) {
func calculateColumnWidthsTerminal() (nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth, paramSizeWidth int) {
// use the terminal width to calculate column widths
minWidth := 120

Expand Down Expand Up @@ -140,16 +149,17 @@ func listModels(models []Model) {
}

stripString := cfg.StripString
nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth := calculateColumnWidthsTerminal()
nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth, paramSizeWidth := calculateColumnWidthsTerminal()

// Add extra spacing between columns
colSpacing := 2
longestNameAllowed := 60

// Create the header with proper padding and alignment
header := fmt.Sprintf("%-*s%-*s%-*s%-*s%-*s%-*s",
header := fmt.Sprintf("%-*s%-*s%-*s%-*s%-*s%-*s%-*s",
nameWidth, "Name",
sizeWidth+colSpacing, "Size",
paramSizeWidth+colSpacing, "Params",
quantWidth+colSpacing, "Quant",
familyWidth+colSpacing, "Family",
modifiedWidth+colSpacing, "Modified",
Expand All @@ -163,7 +173,7 @@ func listModels(models []Model) {
}

// Prepare columns for padding
var names, sizes, quants, families, modified, ids []string
var names, sizes, quants, families, modified, ids, paramSizes []string
var longestName int
for _, model := range models {
if len(model.Name) > longestName {
Expand All @@ -175,6 +185,7 @@ func listModels(models []Model) {
}
names = append(names, model.Name)
sizes = append(sizes, fmt.Sprintf("%.2fGB", model.Size))
paramSizes = append(paramSizes, model.ParameterSize)
quants = append(quants, model.QuantizationLevel)
families = append(families, model.Family)
modified = append(modified, model.Modified.Format("2006-01-02"))
Expand All @@ -184,6 +195,7 @@ func listModels(models []Model) {
// Calculate maximum width for each column
maxNameWidth := nameWidth
maxSizeWidth := sizeWidth + colSpacing
maxParamSizeWidth := paramSizeWidth + colSpacing
maxQuantWidth := quantWidth + colSpacing
maxFamilyWidth := familyWidth + colSpacing
maxModifiedWidth := modifiedWidth + colSpacing
Expand All @@ -193,14 +205,21 @@ func listModels(models []Model) {
for i := range names {
names[i] = fmt.Sprintf("%-*s", maxNameWidth, names[i])
sizes[i] = fmt.Sprintf("%-*s", maxSizeWidth, sizes[i])
paramSizes[i] = fmt.Sprintf("%-*s", maxParamSizeWidth, paramSizes[i])
quants[i] = fmt.Sprintf("%-*s", maxQuantWidth, quants[i])
families[i] = fmt.Sprintf("%-*s", maxFamilyWidth, families[i])
modified[i] = fmt.Sprintf("%-*s", maxModifiedWidth, modified[i])
// if the longest name is more than longestNameAllowed characters, don't display the model sha
if longestName > longestNameAllowed {
ids[i] = ""
// remove the ID header
header = fmt.Sprintf("%-*s%-*s%-*s%-*s%-*s", nameWidth, "Name", sizeWidth+colSpacing, "Size", quantWidth+colSpacing, "Quant", familyWidth+colSpacing, "Family", modifiedWidth, "Modified")
header = fmt.Sprintf("%-*s%-*s%-*s%-*s%-*s%-*s",
nameWidth, "Name",
sizeWidth+colSpacing, "Size",
paramSizeWidth+colSpacing, "Params",
quantWidth+colSpacing, "Quant",
familyWidth+colSpacing, "Family",
modifiedWidth, "Modified")
} else {
ids[i] = fmt.Sprintf("%-*s", maxIdWidth, ids[i])
}
Expand All @@ -215,13 +234,24 @@ func listModels(models []Model) {
name := styles.ItemNameStyle(index).Render(names[index])
id := styles.ItemIDStyle().Render(ids[index])
size := styles.SizeStyle(model.Size).Render(sizes[index])
// Apply direct color based on parameter size
var paramSize string
if paramSizes[index] != "" {
// Format the string first
formattedParamSize := fmt.Sprintf("%-*s", maxParamSizeWidth, paramSizes[index])
// Apply color directly using paramSizeColour
paramSize = lipgloss.NewStyle().Foreground(paramSizeColour(paramSizes[index])).Render(formattedParamSize)
} else {
paramSize = fmt.Sprintf("%-*s", maxParamSizeWidth, paramSizes[index])
}
family := styles.FamilyStyle(model.Family).Render(families[index])
quant := styles.QuantStyle(model.QuantizationLevel).Render(quants[index])
modified := styles.ItemIDStyle().Render(modified[index])

row := fmt.Sprintf("%-*s%-*s%-*s%-*s%-*s%-*s",
row := fmt.Sprintf("%-*s%-*s%-*s%-*s%-*s%-*s%-*s",
maxNameWidth, name,
maxSizeWidth, size,
maxParamSizeWidth, paramSize,
maxQuantWidth, quant,
maxFamilyWidth, family,
maxModifiedWidth, modified,
Expand Down
7 changes: 4 additions & 3 deletions item_delegate.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,23 @@ func (d itemDelegate) Render(w io.Writer, m list.Model, index int, item list.Ite
quantStyle = selectedStyle.Inherit(quantStyle)
}

nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth := calculateColumnWidths(m.Width())
nameWidth, sizeWidth, quantWidth, modifiedWidth, idWidth, familyWidth, paramSizeWidth := calculateColumnWidths(m.Width())

// Ensure the text fits within the terminal width
// Add consistent padding between columns
padding := 2
name := nameStyle.Width(nameWidth).Render(truncate(model.Name, nameWidth-padding))
size := sizeStyle.Width(sizeWidth).Render(fmt.Sprintf("%*.2fGB", sizeWidth-padding-2, model.Size))
paramSize := styles.ParamSizeStyle(model.ParameterSize).Width(paramSizeWidth).Render(fmt.Sprintf("%-*s", paramSizeWidth-padding, model.ParameterSize))
quant := quantStyle.Width(quantWidth).Render(fmt.Sprintf("%-*s", quantWidth-padding, model.QuantizationLevel))
family := familyStyle.Width(familyWidth).Render(fmt.Sprintf("%-*s", familyWidth-padding, model.Family))
modified := dateStyle.Width(modifiedWidth).Render(fmt.Sprintf("%-*s", modifiedWidth-padding, model.Modified.Format("2006-01-02")))
id := shaStyle.Width(idWidth).Render(fmt.Sprintf("%-*s", idWidth-padding, model.ID))

// Add padding between columns
spacer := strings.Repeat(" ", padding)
row := fmt.Sprintf("%s%s%s%s%s%s%s%s%s%s%s",
name, spacer, size, spacer, quant, spacer, family, spacer, modified, spacer, id)
row := fmt.Sprintf("%s%s%s%s%s%s%s%s%s%s%s%s%s",
name, spacer, size, spacer, paramSize, spacer, quant, spacer, family, spacer, modified, spacer, id)

fmt.Fprint(w, row)
}
2 changes: 2 additions & 0 deletions keymap.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type KeyMap struct {
SortByModified key.Binding
SortByQuant key.Binding
SortByFamily key.Binding
SortByParamSize key.Binding
RunModel key.Binding
ConfirmYes key.Binding
ConfirmNo key.Binding
Expand Down Expand Up @@ -65,6 +66,7 @@ func NewKeyMap() *KeyMap {
SortByFamily: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "^family")),
SortByModified: key.NewBinding(key.WithKeys("m"), key.WithHelp("m", "^modified")),
SortByName: key.NewBinding(key.WithKeys("n"), key.WithHelp("n", "^name")),
SortByParamSize: key.NewBinding(key.WithKeys("B"), key.WithHelp("B", "^params")),
SortByQuant: key.NewBinding(key.WithKeys("K"), key.WithHelp("K", "^quant")),
SortBySize: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "^size")),
Top: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "top")),
Expand Down
7 changes: 6 additions & 1 deletion model.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Model struct {
Modified time.Time
Selected bool
Family string
ParameterSize string
}

func (m Model) SelectedStr() string {
Expand All @@ -24,7 +25,11 @@ func (m Model) SelectedStr() string {
}

func (m Model) Description() string {
return fmt.Sprintf("ID: %s, Size: %.2f GB, Quant: %s, Modified: %s", m.ID, m.Size, m.QuantizationLevel, m.Modified.Format("2006-01-02"))
paramSizeStr := ""
if m.ParameterSize != "" {
paramSizeStr = fmt.Sprintf(", Parameters: %s", m.ParameterSize)
}
return fmt.Sprintf("ID: %s, Size: %.2f GB, Quant: %s%s, Modified: %s", m.ID, m.Size, m.QuantizationLevel, paramSizeStr, m.Modified.Format("2006-01-02"))
}

func (m Model) FilterValue() string {
Expand Down
29 changes: 29 additions & 0 deletions styles.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main

import (
"math"
"strconv"

"github.com/charmbracelet/lipgloss"
)
Expand Down Expand Up @@ -82,6 +83,34 @@ func sizeColour(size float64) lipgloss.Color {
return lipgloss.Color(synthGradient[index])
}

func paramSizeColour(paramSize string) lipgloss.Color {
// Extract the numeric part from parameter size strings like "7.6B", "32B", etc.
if paramSize == "" {
return lipgloss.Color(synthGradient[0])
}

// Remove the "B" suffix if present
numStr := paramSize
if paramSize[len(paramSize)-1] == 'B' {
numStr = paramSize[:len(paramSize)-1]
}

// Parse the numeric part
size, err := strconv.ParseFloat(numStr, 64)
if err != nil {
// Default to first color if parsing fails
return lipgloss.Color(synthGradient[0])
}

// Use logarithmic scale similar to sizeColour but adjusted for parameter sizes
// Parameter sizes typically range from 1B to 100B+
index := int(math.Log10(size+1) * 3)
if index >= len(synthGradient) {
index = len(synthGradient) - 1
}
return lipgloss.Color(synthGradient[index])
}

func familyColour(family string, index int) lipgloss.Color {
colour, exists := familyColours[family]
if !exists {
Expand Down
Loading

0 comments on commit 5b28791

Please sign in to comment.