diff --git a/checks/cassandra/check.go b/checks/cassandra/check.go index d37c375..a24d2a6 100644 --- a/checks/cassandra/check.go +++ b/checks/cassandra/check.go @@ -10,31 +10,29 @@ import ( // Config is the Cassandra checker configuration settings container. type Config struct { - // Hosts is a list of Cassandra hosts. At least one is required. + // Hosts is a list of Cassandra hosts. Optional if Session is supplied. Hosts []string - // Keyspace is the Cassandra keyspace to which you want to connect. Required. + // Keyspace is the Cassandra keyspace to which you want to connect. Optional if Session is supplied. Keyspace string + // Session is a gocql session and can be used in place of Hosts and Keyspace. Recommended. + // Optional if Hosts & Keyspace are supplied. + Session *gocql.Session } -// New creates new Cassandra health check that verifies the following: -// - that a connection can be established through creating a session -// - that queries can be executed by describing keyspaces +// New creates new Cassandra health check that verifies that a connection exists and can be used to query the cluster. func New(config Config) func(ctx context.Context) error { return func(ctx context.Context) error { - if len(config.Hosts) < 1 || len(config.Keyspace) < 1 { - return errors.New("keyspace name and hosts are required to initialize cassandra health check") + shutdown, session, err := initSession(config) + if err != nil { + return fmt.Errorf("cassandra health check failed on connect: %w", err) } - cluster := gocql.NewCluster(config.Hosts...) - cluster.Keyspace = config.Keyspace + defer shutdown() - session, err := cluster.CreateSession() if err != nil { return fmt.Errorf("cassandra health check failed on connect: %w", err) } - defer session.Close() - err = session.Query("DESCRIBE KEYSPACES;").WithContext(ctx).Exec() if err != nil { return fmt.Errorf("cassandra health check failed on describe: %w", err) @@ -43,3 +41,22 @@ func New(config Config) func(ctx context.Context) error { return nil } } + +func initSession(c Config) (func(), *gocql.Session, error) { + if c.Session != nil { + return func() {}, c.Session, nil + } + + if len(c.Hosts) < 1 || len(c.Keyspace) < 1 { + return nil, nil, errors.New("cassandra cluster config or keyspace name and hosts are required to initialize cassandra health check") + } + + cluster := gocql.NewCluster(c.Hosts...) + cluster.Keyspace = c.Keyspace + session, err := cluster.CreateSession() + if err != nil { + return nil, nil, err + } + + return session.Close, session, err +} diff --git a/checks/cassandra/check_test.go b/checks/cassandra/check_test.go index ea5a25a..2bdd3e6 100644 --- a/checks/cassandra/check_test.go +++ b/checks/cassandra/check_test.go @@ -26,6 +26,21 @@ func TestNew(t *testing.T) { require.NoError(t, err) } +func TestNew_withClusterConfig(t *testing.T) { + initDB(t) + cluster := gocql.NewCluster(getHosts(t)...) + cluster.Keyspace = KEYSPACE + session, err := cluster.CreateSession() + require.NoError(t, err) + + check := New(Config{ + Session: session, + }) + + err = check(context.Background()) + require.NoError(t, err) +} + func TestNewWithError(t *testing.T) { check := New(Config{})