diff --git a/calculator.go b/calculator.go index 0f21770..480862c 100644 --- a/calculator.go +++ b/calculator.go @@ -24,7 +24,11 @@ func (c *calculator) modulus(a Amount, d int64) Amount { return a % d } -func (c *calculator) allocate(a Amount, r, s int) Amount { +func (c *calculator) allocate(a Amount, r, s uint) Amount { + if a == 0 || s == 0 { + return 0 + } + return a * int64(r) / int64(s) } diff --git a/money.go b/money.go index de58e89..28fc431 100644 --- a/money.go +++ b/money.go @@ -265,16 +265,19 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { } // Calculate sum of ratios. - var sum int + var sum uint for _, r := range rs { - sum += r + if r < 0 { + return nil, errors.New("negative ratios not allowed") + } + sum += uint(r) } var total int64 ms := make([]*Money, 0, len(rs)) for _, r := range rs { party := &Money{ - amount: mutate.calc.allocate(m.amount, r, sum), + amount: mutate.calc.allocate(m.amount, uint(r), sum), currency: m.currency, } @@ -282,6 +285,12 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { total += party.amount } + // if the sum of all ratios is zero, then we just returns zeros and don't do anything + // with the leftover + if sum == 0 { + return ms, nil + } + // Calculate leftover value and divide to first parties. lo := m.amount - total sub := int64(1) diff --git a/money_test.go b/money_test.go index 0d9ac17..e336998 100644 --- a/money_test.go +++ b/money_test.go @@ -470,6 +470,10 @@ func TestMoney_Allocate(t *testing.T) { {100, []int{30, 30, 30}, []int64{34, 33, 33}}, {200, []int{25, 25, 50}, []int64{50, 50, 100}}, {5, []int{50, 25, 25}, []int64{3, 1, 1}}, + {0, []int{0, 0, 0, 0}, []int64{0, 0, 0, 0}}, + {0, []int{50, 10}, []int64{0, 0}}, + {10, []int{0, 100}, []int64{0, 10}}, + {10, []int{0, 0}, []int64{0, 0}}, } for _, tc := range tcs {