Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incremental (delta) update #928

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/datachain/delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING, Optional

from datachain.error import DatasetNotFoundError

if TYPE_CHECKING:
from datachain.lib.dc import DataChain


def delta_update(dc: "DataChain", name: str) -> Optional["DataChain"]:
    """Build a chain that combines the latest saved version of the delta
    dataset with the (modified) diff computed against its source.

    Rather than re-running the whole chain from the source again (applying
    all the DataChain methods like filters, mappers, generators etc.), only
    the rows that changed since the last saved version are processed, which
    is very important for performance.

    Returns None when the dataset does not exist yet (first creation).
    """
    from datachain.lib.dc import DataChain

    file_signal = dc.signals_schema.get_file_signal()
    if not file_signal:
        raise ValueError("Datasets without file signals cannot have delta updates")

    try:
        latest_version = dc.session.catalog.get_dataset(name).latest_version
    except DatasetNotFoundError:
        # first creation of delta update dataset
        return None

    starting = dc._query.starting_step
    source_chain = DataChain.from_dataset(
        starting.dataset_name, version=starting.dataset_version
    )
    previous_chain = DataChain.from_dataset(name, version=latest_version)

    # Rows added/changed in the source since the last saved version, with
    # all of the original chain's steps (filters, mappers, ...) re-applied
    # on top of just that diff.
    changed = source_chain.diff(
        previous_chain,
        on=file_signal,
        sys=True,
    ).append_steps(dc)

    # Merge the processed diff with the latest version of the dataset.
    latest = DataChain.from_dataset(name, latest_version)
    return latest.diff(changed, added=True, modified=False, sys=True).union(changed)
5 changes: 5 additions & 0 deletions src/datachain/diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _compare( # noqa: C901
modified: bool = True,
same: bool = True,
status_col: Optional[str] = None,
sys: Optional[bool] = False,
) -> "DataChain":
"""Comparing two chains by identifying rows that are added, deleted, modified
or same"""
Expand Down Expand Up @@ -140,6 +141,10 @@ def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
.select_except(ldiff_col, rdiff_col)
)

if sys:
# making sure we have sys signals in final diff chain
dc_diff = dc_diff.settings(sys=True)

if not added:
dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.ADDED)
if not modified:
Expand Down
35 changes: 34 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tqdm import tqdm

from datachain.dataset import DatasetRecord
from datachain.delta import delta_update
from datachain.func import literal
from datachain.func.base import Function
from datachain.func.func import Func
Expand Down Expand Up @@ -333,6 +334,15 @@ def clone(self) -> "Self":
"""Make a copy of the chain in a new table."""
return self._evolve(query=self._query.clone(new_table=True))

def append_steps(self, chain: "DataChain") -> "Self":
    """Returns cloned chain with appended steps from other chain.

    Steps are all those modification methods applied like filters,
    mappers etc.
    """
    clone = self.clone()
    # Extend the clone's query steps in place with a copy of the other
    # chain's steps, then merge in any signals the clone doesn't have yet.
    clone._query.steps.extend(chain._query.steps.copy())
    clone.signals_schema = clone.signals_schema.append(chain.signals_schema)
    return clone

def _evolve(
self,
*,
Expand Down Expand Up @@ -760,16 +770,33 @@ def listings(
)

def save( # type: ignore[override]
self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
self,
name: Optional[str] = None,
version: Optional[int] = None,
delta: Optional[bool] = False,
**kwargs,
) -> "Self":
"""Save to a Dataset. It returns the chain itself.

Parameters:
name : dataset name. Empty name saves to a temporary dataset that will be
removed after process ends. Temp dataset are useful for optimization.
version : version of a dataset. Default - the last version that exist.
delta : If True, we optimize on creation of the new dataset versions
by calculating diff between source and the last version and applying
all needed modifications (mappers, filters etc.) only on that diff.
At the end, we merge modified diff with last version of dataset to
create new version.
"""
schema = self.signals_schema.clone_without_sys_signals().serialize()
if delta and name:
delta_ds = delta_update(self, name)
if delta_ds:
return self._evolve(
query=delta_ds._query.save(
name=name, version=version, feature_schema=schema, **kwargs
)
)
return self._evolve(
query=self._query.save(
name=name, version=version, feature_schema=schema, **kwargs
Expand Down Expand Up @@ -1621,6 +1648,7 @@ def compare(
modified: bool = True,
same: bool = False,
status_col: Optional[str] = None,
sys: Optional[bool] = False,
) -> "DataChain":
"""Comparing two chains by identifying rows that are added, deleted, modified
or same. Result is the new chain that has additional column with possible
Expand Down Expand Up @@ -1653,6 +1681,7 @@ def compare(
same (bool): Whether to return unchanged rows in resulting chain.
status_col (str): Name of the new column that is created in resulting chain
representing diff status.
sys (bool): Whether to have sys columns in returned diff chain or not.

Example:
```py
Expand Down Expand Up @@ -1683,6 +1712,7 @@ def compare(
modified=modified,
same=same,
status_col=status_col,
sys=sys,
)

def diff(
Expand All @@ -1695,6 +1725,7 @@ def diff(
deleted: bool = False,
same: bool = False,
status_col: Optional[str] = None,
sys: Optional[bool] = False,
) -> "DataChain":
"""Similar to `.compare()`, which is more generic method to calculate difference
between two chains. Unlike `.compare()`, this method works only on those chains
Expand All @@ -1717,6 +1748,7 @@ def diff(
same (bool): Whether to return unchanged rows in resulting chain.
status_col (str): Optional name of the new column that is created in
resulting chain representing diff status.
sys (bool): Whether to have sys columns in returned diff chain or not.

Example:
```py
Expand Down Expand Up @@ -1756,6 +1788,7 @@ def get_file_signals(file: str, signals):
modified=modified,
same=same,
status_col=status_col,
sys=sys,
)

@classmethod
Expand Down
18 changes: 12 additions & 6 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,14 +454,13 @@ def row_to_objs(self, row: Sequence[Any]) -> list[DataValue]:
pos += 1
return objs

def contains_file(self) -> bool:
for type_ in self.values.values():
if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass(
def get_file_signal(self) -> Optional[str]:
    """Return the name of the first signal whose type resolves to a `File`
    subclass, or None when the schema contains no file-based signals.
    """

    def _is_file_type(signal_type) -> bool:
        model = ModelStore.to_pydantic(signal_type)
        return model is not None and issubclass(model, File)

    return next(
        (name for name, type_ in self.values.items() if _is_file_type(type_)),
        None,
    )

def slice(
self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
Expand Down Expand Up @@ -647,6 +646,13 @@ def merge(

return SignalSchema(self.values | schema_right)

def append(self, right: "SignalSchema") -> "SignalSchema":
    """Return a new schema extended with signals from `right` that are not
    already present in this schema.

    Signals that exist in both schemas keep the types from `self` — a plain
    `self.values | right.values` would instead let `right` override them,
    which is why only the missing keys are merged in.
    """
    missing_schema = {
        name: type_
        for name, type_ in right.values.items()
        if name not in self.values
    }
    return SignalSchema(self.values | missing_schema)

def get_signals(self, target_type: type[DataModel]) -> Iterator[str]:
for path, type_, has_subtree, _ in self.get_flat_tree():
if has_subtree and issubclass(type_, target_type):
Expand Down
15 changes: 5 additions & 10 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,6 @@ def step_result(
)


class StartingStep(ABC):
"""An initial query processing step, referencing a data source."""

@abstractmethod
def apply(self) -> "StepResult": ...


@frozen
class Step(ABC):
"""A query processing step (filtering, mutation, etc.)"""
Expand All @@ -172,12 +165,14 @@ def apply(


@frozen
class QueryStep(StartingStep):
class QueryStep:
"""A query that returns all rows from specific dataset version"""

catalog: "Catalog"
dataset_name: str
dataset_version: int

def apply(self):
def apply(self) -> "StepResult":
def q(*columns):
return sqlalchemy.select(*columns)

Expand Down Expand Up @@ -1095,7 +1090,7 @@ def __init__(
self.temp_table_names: list[str] = []
self.dependencies: set[DatasetDependencyType] = set()
self.table = self.get_table()
self.starting_step: StartingStep
self.starting_step: QueryStep
self.name: Optional[str] = None
self.version: Optional[int] = None
self.feature_schema: Optional[dict] = None
Expand Down
Loading
Loading