Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial implementation for plpgsql ast traversal #190

Merged
merged 1 commit into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions src/pgspot/path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# pylint: disable=fixme


class Path:
    """One execution path through a PLpgSQL function.

    ``steps`` holds the nodes already processed on this path and ``stack``
    holds the nodes that still have to be processed. Both lists are copied
    on construction, so a path never shares them with its source.
    """

    def __init__(self, root, steps=None, stack=None):
        self.root = root

        # Nodes that have been processed so far.
        if steps:
            self.steps = steps.copy()
        else:
            self.steps = []

        # Nodes that are yet to be processed.
        if stack:
            self.stack = stack.copy()
        else:
            self.stack = []

    def copy(self):
        """Return a new Path with the same root and copies of both lists."""
        return Path(root=self.root, steps=self.steps, stack=self.stack)

    def __str__(self):
        """Render the processed steps joined by `` -> ``."""
        return " -> ".join(map(str, self.steps))


def paths(root):
    """Yield every execution path that ``dfs`` discovers below ``root``.

    The first path yielded is the one ``dfs`` builds directly from ``root``.
    Any alternative paths that ``dfs`` queued while traversing are then
    resumed one at a time, in the order they were queued, and yielded once
    their remaining stack has been processed. Resuming an alternative may
    queue further alternatives, which are handled the same way.
    """
    current = Path(root)
    pending = []
    dfs(root, current, pending)
    yield current

    while pending:
        current = pending.pop(0)
        # A queued alternative with nothing left to process is already
        # complete; popping from its empty stack would raise IndexError.
        if current.stack:
            dfs(current.stack.pop(0), current, pending)
        yield current


def dfs(node, path, pathes):
"""traverse tree depth first similar to how it would get executed"""
if not node:
return
if node:
match node.type:
case "PLpgSQL_function":
# This should be top level node and so stack should be empty
assert not path.stack
path.stack = [node.action] + path.stack
case "PLpgSQL_stmt_block":
# FIXME: Add support for exception handling
path.stack = node.body + path.stack
case "PLpgSQL_stmt_if":
path.steps.append(node)
if node.elsif_list:
for elsif in node.elsif_list:
alt = path.copy()
alt.stack = elsif.stmts + alt.stack
pathes.append(alt)
if node.else_body:
alt = path.copy()
alt.stack = node.else_body + alt.stack
pathes.append(alt)

path.stack = node.then_body + path.stack

# different types of loops
# FIXME: Add support for loop exit
case (
"PLpgSQL_stmt_loop"
| "PLpgSQL_stmt_while"
| "PLpgSQL_stmt_forc"
| "PLpgSQL_stmt_fori"
| "PLpgSQL_stmt_fors"
| "PLpgSQL_stmt_dynfors"
):
path.stack = node.body + path.stack

# nodes with no children
case (
"PLpgSQL_stmt_assert"
| "PLpgSQL_stmt_assign"
| "PLpgSQL_stmt_call"
| "PLpgSQL_stmt_close"
| "PLpgSQL_stmt_commit"
| "PLpgSQL_stmt_dynexecute"
| "PLpgSQL_stmt_execsql"
| "PLpgSQL_stmt_fetch"
| "PLpgSQL_stmt_getdiag"
| "PLpgSQL_stmt_open"
| "PLpgSQL_stmt_perform"
| "PLpgSQL_stmt_raise"
| "PLpgSQL_stmt_rollback"
):
path.steps.append(node)

# nodes not yet implemented
case (
"PLpgSQL_stmt_case"
| "PLpgSQL_stmt_exit"
| "PLpgSQL_stmt_forc"
| "PLpgSQL_stmt_foreach_a"
):
raise Exception(f"Not yet implemented {node.type}")

# nodes that will end current path
case (
"PLpgSQL_stmt_return"
| "PLpgSQL_stmt_return_next"
| "PLpgSQL_stmt_return_query"
):
path.steps.append(node)
path.stack.clear()
return

case _:
raise Exception(f"Unknown node type {node.type}")

while path.stack:
t = path.stack.pop(0)
dfs(t, path, pathes)
23 changes: 20 additions & 3 deletions src/pgspot/plpgsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
#
# Disable pylint warnings about class names cause we are trying to match the
# AST names used by PLpgSQL parser.
# pylint: disable-msg=C0103
# pylint: disable-msg=R0903
# pylint: disable-msg=invalid-name,too-few-public-methods


class PLpgSQLNode:
def __init__(self, raw):
self.type = list(raw.keys())[0]
self.lineno = None
self.lineno = ""
for k, v in raw[self.type].items():
setattr(self, k, build_node(v))

Expand All @@ -24,6 +23,24 @@ def __repr__(self):
return f"{self.type}({fields})"


class PLpgSQL_stmt_if(PLpgSQLNode):
    """IF statement node whose three branch attributes always exist.

    ``then_body``, ``elsif_list`` and ``else_body`` are preset to ``None``
    so they are defined even when ``raw`` does not carry them.
    """

    def __init__(self, raw):
        # The defaults must be set before the base initializer runs, because
        # it overwrites them with whatever fields ``raw`` provides.
        self.then_body = self.elsif_list = self.else_body = None
        super().__init__(raw)


class PLpgSQL_row(PLpgSQLNode):
    """Row node whose ``fields`` entry is stored as raw data."""

    def __init__(self, raw):
        # PLpgSQL_row has a fields attribute which is a list of dicts that
        # don't have the same structure as other node dicts. So we take it
        # out and set it as an attribute directly instead of having it
        # handled by recursion.
        #
        # Work on a shallow copy so the caller's parse tree is left intact;
        # popping from ``raw`` itself would make a second build of the same
        # tree fail with KeyError.
        body = dict(raw["PLpgSQL_row"])
        self.fields = body.pop("fields")
        super().__init__({**raw, "PLpgSQL_row": body})


class PLpgSQL_var(PLpgSQLNode):
def __init__(self, raw):
self.refname = None
Expand Down
149 changes: 149 additions & 0 deletions tests/plpgsql_path_if_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from pglast import parse_plpgsql
from pgspot.plpgsql import build_node
from pgspot.path import paths


def test_if_minimal_stmt():
    """An IF without ELSIF/ELSE yields exactly one path, through THEN."""
    sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
END IF;
END
$$;
"""
    func_node = build_node(parse_plpgsql(sql)[0])
    assert func_node.type == "PLpgSQL_function"

    found = list(paths(func_node))
    assert len(found) == 1

    rendered = [str(path) for path in found]
    assert rendered == [
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_return()",
    ]


def test_if_else():
    """THEN, three ELSIF branches and ELSE each yield their own path."""
    sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
ELSIF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '2';
ELSIF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '3';
ELSIF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '4';
ELSE
EXECUTE cmd || '5';
END IF;
END
$$;
"""
    func_node = build_node(parse_plpgsql(sql)[0])
    assert func_node.type == "PLpgSQL_function"

    found = list(paths(func_node))
    assert len(found) == 5

    # Paths come out in branch order: THEN first, ELSE last.
    rendered = [str(path) for path in found]
    assert rendered == [
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_return()",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(6) -> PLpgSQL_stmt_return()",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(8) -> PLpgSQL_stmt_return()",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(10) -> PLpgSQL_stmt_return()",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(12) -> PLpgSQL_stmt_return()",
    ]


def test_if_stmt():
    """Two consecutive IFs, the first with an early RETURN in its ELSE."""
    sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT 1 FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
ELSE
EXECUTE cmd || '2';
RETURN 'foo';
END IF;
IF EXISTS (SELECT 1 FROM pg_stat_activity) THEN
EXECUTE cmd;
ELSE
EXECUTE cmd;
END IF;
END
$$;
"""
    func_node = build_node(parse_plpgsql(sql)[0])
    assert func_node.type == "PLpgSQL_function"

    found = list(paths(func_node))
    assert len(found) == 3

    rendered = [str(path) for path in found]
    assert rendered == [
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_if(9) -> PLpgSQL_stmt_dynexecute(10) -> PLpgSQL_stmt_return()",
        # The explicit RETURN ends this path before the second IF.
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(6) -> PLpgSQL_stmt_return(7)",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_if(9) -> PLpgSQL_stmt_dynexecute(12) -> PLpgSQL_stmt_return()",
    ]


def test_nested_if_stmt():
    """An IF nested inside an ELSE branch yields three paths in total."""
    sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
ELSE
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '2';
ELSE
EXECUTE cmd || '3';
END IF;
END IF;
END
$$;
"""
    func_node = build_node(parse_plpgsql(sql)[0])
    assert func_node.type == "PLpgSQL_function"

    found = list(paths(func_node))
    assert len(found) == 3

    rendered = [str(path) for path in found]
    assert rendered == [
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_return()",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_if(6) -> PLpgSQL_stmt_dynexecute(7) -> PLpgSQL_stmt_return()",
        "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_if(6) -> PLpgSQL_stmt_dynexecute(9) -> PLpgSQL_stmt_return()",
    ]
Loading