Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generate a dictionary of schema matching information when interrogating a col_schema_match() validation step #42

Merged
merged 17 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
378 changes: 376 additions & 2 deletions pointblank/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from dataclasses import dataclass

import narwhals as nw

from pointblank._utils import _get_tbl_type, _is_lib_present
from pointblank._constants import IBIS_BACKENDS

Expand Down Expand Up @@ -680,3 +678,379 @@ def _process_columns(
return columns

return list(kwargs.items())


def _schema_info_generate_colname_dict(
colname_matched: bool,
matched_to: str | None,
dtype_present: bool,
dtype_input: str | list[str],
dtype_matched: bool,
dtype_multiple: bool,
dtype_matched_pos: int,
) -> dict[str, any]:

return {
"colname_matched": colname_matched,
"matched_to": matched_to,
"dtype_present": dtype_present,
"dtype_input": dtype_input,
"dtype_matched": dtype_matched,
"dtype_multiple": dtype_multiple,
"dtype_matched_pos": dtype_matched_pos,
}


def _schema_info_generate_columns_dict(
colnames: list[str] | None,
colname_dict: list[dict[str, any]] | None,
) -> dict[str, dict[str, any]]:
"""
Generate the columns dictionary for the schema information dictionary.

Parameters
----------
colnames
A list of column names. The columns included are those of the user-supplied schema.
colname_dict
A list of dictionaries containing column name information. The columns included are
those of the user-supplied schema.

Returns
-------
dict[str, dict[str, any]]
The columns dictionary.
"""
return {colnames[i]: colname_dict[i] for i in range(len(colnames))}


def _schema_info_generate_params_dict(
complete: bool,
in_order: bool,
case_sensitive_colnames: bool,
case_sensitive_dtypes: bool,
full_match_dtypes: bool,
) -> dict[str, any]:
"""
Generate the parameters dictionary for the schema information dictionary.

Parameters
----------
complete
Whether the schema is complete.
in_order
Whether the schema is in order.
case_sensitive_colnames
Whether column names are case-sensitive.
case_sensitive_dtypes
Whether data types are case-sensitive.
full_match_dtypes
Whether data types must match exactly.

Returns
-------
dict[str, any]
The parameters dictionary.
"""

return {
"complete": complete,
"in_order": in_order,
"case_sensitive_colnames": case_sensitive_colnames,
"case_sensitive_dtypes": case_sensitive_dtypes,
"full_match_dtypes": full_match_dtypes,
}


def _get_schema_validation_info(
data_tbl: any,
schema: Schema,
passed: bool,
complete: bool,
in_order: bool,
case_sensitive_colnames: bool,
case_sensitive_dtypes: bool,
full_match_dtypes: bool,
) -> dict[str, any]:
"""
Get the schema validation information dictionary.

Parameters
----------
schema_exp
The expected schema.
schema_tgt
The target schema.

Returns
-------
dict[str, any]
The schema validation information dictionary.

Explanation of the schema validation information dictionary
----------------------------------------------------------

This is how the schema validation information dictionary is structured:

- passed: bool # Whether the schema validation passed
- params: dict[str, any] # Parameters used in the schema validation
- complete: bool # Whether the schema should be complete
- in_order: bool # Whether the schema should be in order
- case_sensitive_colnames: bool # Whether column names are case-sensitive
- case_sensitive_dtypes: bool # Whether data types are case-sensitive
- full_match_dtypes: bool # Whether data types must match exactly or partially
- columns_found: list[str] # Columns in the target table found in the schema
- columns_not_found: list[str] # Columns not found in the target table (from schema)
- columns_unmatched: list[str] # Columns in the schema unmatched in the target table
- columns_full_set: bool # Full set of columns is matched (w/ no extra columns)
- columns_subset: bool # Subset of columns is matched (w/ no extra columns)
- columns_matched_in_order: bool # Whether columns are matched in order
- columns_matched_any_order: bool # Whether columns are matched in any order
- columns: dict[str, dict[str, any]] # Column information dictionary
- {colname}: str # Column name in the expected schema
- colname_matched: bool # Whether the column name is matched to the target table
- matched_to: str # Column name in the target table
- dtype_present: bool # Whether a dtype is present in the expected schema
- dtype_input: [dtype] # dtypes provided in the expected schema
- dtype_matched: bool # Is there a dtype match to the target table column?
- dtype_multiple: bool # Are there multiple dtypes in the expected schema?
- dtype_matched_pos: int # Position of the matched dtype in the expected schema
"""

schema_exp = schema
schema_tgt = Schema(tbl=data_tbl)

# Initialize the schema information dictionary
schema_info = {
"passed": passed,
"params": {},
"columns_found": [],
"columns_not_found": [],
"columns_unmatched": [],
"columns_full_set": False,
"columns_subset": False,
"columns_matched_in_order": False,
"columns_matched_any_order": False,
}

# Generate the parameters dictionary
schema_info["params"] = _schema_info_generate_params_dict(
complete=complete,
in_order=in_order,
case_sensitive_colnames=case_sensitive_colnames,
case_sensitive_dtypes=case_sensitive_dtypes,
full_match_dtypes=full_match_dtypes,
)

# Get the columns of the target table
tgt_colnames = schema_tgt.get_column_list()

# Get the columns of the expected schema
exp_colnames = schema_exp.get_column_list()

# Create a mapping of lowercased column names to original names in the target table schema
tgt_colname_mapping = {col.lower(): col for col in tgt_colnames}

if case_sensitive_colnames:

# Which columns are in both the target table and the expected schema?
columns_found = [col for col in exp_colnames if col in tgt_colnames]

# Which columns from the expected schema aren't in the target table?
columns_unmatched = [col for col in exp_colnames if col not in tgt_colnames]

# Which columns are in the target table but not in the expected schema?
columns_not_found = [col for col in tgt_colnames if col not in exp_colnames]

else:

# Convert expected column names to lowercase for case-insensitive comparison
exp_colnames_lower = [col.lower() for col in exp_colnames]

# Which columns are in both the target table and the expected schema?
columns_found = [
tgt_colname_mapping[col.lower()] for col in exp_colnames if col.lower() in tgt_colnames
]

# Which columns from the expected schema aren't in the target table?
columns_unmatched = [col for col in exp_colnames if col.lower() not in tgt_colnames]

# Which columns are in the target table but not in the expected schema?
columns_not_found = [col for col in tgt_colnames if col.lower() not in exp_colnames_lower]

# Sort `columns_found` based on the order of tgt_colnames
columns_found_sorted = sorted(columns_found, key=lambda col: tgt_colnames.index(col))

# Update the schema information dictionary
schema_info["columns_found"] = columns_found_sorted
schema_info["columns_not_found"] = columns_not_found
schema_info["columns_unmatched"] = columns_unmatched

# If the number of columns matched is the same as the number of columns in the expected schema,
# test if:
# - all columns are matched in the target table in the same order
# - all columns are matched in the target table in any order
if (
len(columns_found) == len(exp_colnames)
and len(columns_unmatched) == 0
and len(columns_not_found) == 0
):
# CASE I: Expected columns are the same as the target columns
schema_info["columns_full_set"] = True

if columns_found == tgt_colnames:
# Check if the columns are matched in order
schema_info["columns_matched_in_order"] = True

elif set(columns_found) == set(tgt_colnames):
# Check if the columns are matched in any order
schema_info["columns_matched_any_order"] = True

elif (
len(columns_found) == len(exp_colnames)
and len(columns_found) > 0
and len(columns_unmatched) == 0
):
# CASE II: Expected columns are a subset of the target columns
schema_info["columns_subset"] = True

# Filter the columns in the target table that are matched
tgt_colnames_matched = [col for col in tgt_colnames if col in columns_found]

# If the columns are matched in order, set `columns_matched_in_order` to True; do this
# for case-sensitive and case-insensitive comparisons
if case_sensitive_colnames:

if columns_found == tgt_colnames_matched:
schema_info["columns_matched_in_order"] = True

elif set(columns_found) == set(tgt_colnames_matched):
schema_info["columns_matched_any_order"] = True

else:

if [col.lower() for col in columns_found] == [
col.lower() for col in tgt_colnames_matched
]:
schema_info["columns_matched_in_order"] = True

elif set([col.lower() for col in columns_found]) == set(
[col.lower() for col in tgt_colnames_matched]
):
schema_info["columns_matched_any_order"] = True

# For each column in the expected schema, determine if the column name is matched
# and if the dtype is matched
colname_dict = []

for col in exp_colnames:

#
# Phase I: Determine if the column name is matched
#

if case_sensitive_colnames:

# Does the column name have a match in the expected schema?
colname_matched = col in columns_found

# If the column name is matched, get the column name in the target table
if colname_matched:
matched_to = col
else:
matched_to = None
else:

# Does the column name have a match in the expected schema? A lowercase comparison
# is used here to determine if the column name is matched
colname_matched = col.lower() in columns_found

# If the column name is matched, get the column name in the target table; this involves
# mapping the lowercase column name to the original column name in the target table
if colname_matched:
matched_to = tgt_colname_mapping[
columns_found[[col.lower() for col in columns_found].index(col.lower())]
]
else:
matched_to = None

# Get the dtype of the column in the expected schema
# If there is a dtype for the column in the expected schema, get it
if len(schema_exp.columns[exp_colnames.index(col)]) == 1:
dtype_input = None
else:
dtype_input = schema_exp.columns[exp_colnames.index(col)][1]

if isinstance(dtype_input, str):
dtype_input = [dtype_input]

# Is a dtype present in the expected schema column?
dtype_present = dtype_input is not None

#
# Phase II: Determine if the dtype of the column in the target table is matched
#

if colname_matched and dtype_present:

# Get the dtype of the column in the target table
dtype_tgt = schema_tgt.columns[tgt_colnames.index(matched_to)][1]

# Determine if the dtype of the column in the expected schema is matched
dtype_matches = []
dtype_matches_pos = []

# Iterate through the dtypes of the column in the expected schema and determine if
# any of them match the dtype of the column in the target table
for i in range(len(dtype_input)):

if not case_sensitive_dtypes:
dtype_input[i] = dtype_input[i].lower()
dtype_tgt = dtype_tgt.lower()

if full_match_dtypes and dtype_input[i] == dtype_tgt:
dtype_matches.append(True)
dtype_matches_pos.append(i)

if not full_match_dtypes and dtype_input[i] in dtype_tgt:
dtype_matches.append(True)
dtype_matches_pos.append(i)

# If there are no matches for any of the dtypes provided, set `dtype_matched` to False
dtype_matched = any(dtype_matches)

# If there are multiple dtypes for a column, set `dtype_multiple` to True
dtype_multiple = len(dtype_input) > 1

# Even if there are multiple matches for the dtype, we simply get the first position
# of the matched dtype
if dtype_matched:
dtype_matched_pos = dtype_matches_pos[0]
else:
dtype_matched_pos = None

else:

dtype_tgt = None
dtype_matched = False
dtype_multiple = False
dtype_matched_pos = None

colname_dict.append(
_schema_info_generate_colname_dict(
colname_matched=colname_matched,
matched_to=matched_to,
dtype_present=dtype_present,
dtype_input=dtype_input,
dtype_matched=dtype_matched,
dtype_multiple=dtype_multiple,
dtype_matched_pos=dtype_matched_pos,
)
)

# Generate the columns dictionary
schema_info["columns"] = _schema_info_generate_columns_dict(
colnames=exp_colnames, colname_dict=colname_dict
)

return schema_info
Loading
Loading