diff --git a/src/datachain/delta.py b/src/datachain/delta.py
new file mode 100644
index 000000000..b525252ca
--- /dev/null
+++ b/src/datachain/delta.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING, Optional
+
+from datachain.error import DatasetNotFoundError
+
+if TYPE_CHECKING:
+    from datachain.lib.dc import DataChain
+
+
+def delta_update(dc: "DataChain", name: str) -> Optional["DataChain"]:
+    """
+    Creates a new chain that consists of the latest version of the current delta
+    dataset plus the diff from the source with all needed modifications.
+    This way we don't need to re-calculate the whole chain from the source again
+    (i.e. apply all the DataChain methods like filters, mappers, generators etc.),
+    but only the diff part, which is very important for performance.
+    """
+    from datachain.lib.dc import DataChain
+
+    file_signal = dc.signals_schema.get_file_signal()
+    if not file_signal:
+        raise ValueError("Datasets without file signals cannot have delta updates")
+    try:
+        latest_version = dc.session.catalog.get_dataset(name).latest_version
+    except DatasetNotFoundError:
+        # first creation of the delta update dataset
+        return None
+
+    source_ds_name = dc._query.starting_step.dataset_name
+    source_ds_version = dc._query.starting_step.dataset_version
+
+    diff = (
+        DataChain.from_dataset(source_ds_name, version=source_ds_version)
+        .diff(
+            DataChain.from_dataset(name, version=latest_version),
+            on=file_signal,
+            sys=True,
+        )
+        # Append all the steps from the original chain to the diff, e.g. filters, mappers.
+        .append_steps(dc)
+    )
+
+    # merge the diff with the latest version of the dataset
+    return (
+        DataChain.from_dataset(name, latest_version)
+        .diff(diff, added=True, modified=False, sys=True)
+        .union(diff)
+    )
diff --git a/src/datachain/diff/__init__.py b/src/datachain/diff/__init__.py
index b325a2d29..d09931851 100644
--- a/src/datachain/diff/__init__.py
+++ b/src/datachain/diff/__init__.py
@@ -42,6 +42,7 @@ def _compare(  # noqa: C901
     modified: bool = True,
     same: bool = True,
     status_col: Optional[str] = None,
+    sys: Optional[bool] = False,
 ) -> "DataChain":
     """Comparing two chains by identifying rows that are added, deleted, modified
     or same"""
@@ -140,6 +141,10 @@ def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
         .select_except(ldiff_col, rdiff_col)
     )
 
+    if sys:
+        # make sure sys signals are present in the final diff chain
+        dc_diff = dc_diff.settings(sys=True)
+
    if not added:
        dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.ADDED)
    if not modified:
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index e6471af79..8ef4d69a3 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -26,6 +26,7 @@ from tqdm import tqdm
 
 from datachain.dataset import DatasetRecord
+from datachain.delta import delta_update
 from datachain.func import literal
 from datachain.func.base import Function
 from datachain.func.func import Func
@@ -333,6 +334,15 @@ def clone(self) -> "Self":
         """Make a copy of the chain in a new table."""
         return self._evolve(query=self._query.clone(new_table=True))
 
+    def append_steps(self, chain: "DataChain") -> "Self":
+        """Returns a cloned chain with the steps of another chain appended.
+        Steps are the modification methods applied to a chain, e.g. filters, mappers.
+ """ + dc = self.clone() + dc._query.steps += chain._query.steps.copy() + dc.signals_schema = dc.signals_schema.append(chain.signals_schema) + return dc + def _evolve( self, *, @@ -760,7 +770,11 @@ def listings( ) def save( # type: ignore[override] - self, name: Optional[str] = None, version: Optional[int] = None, **kwargs + self, + name: Optional[str] = None, + version: Optional[int] = None, + delta: Optional[bool] = False, + **kwargs, ) -> "Self": """Save to a Dataset. It returns the chain itself. @@ -768,8 +782,21 @@ def save( # type: ignore[override] name : dataset name. Empty name saves to a temporary dataset that will be removed after process ends. Temp dataset are useful for optimization. version : version of a dataset. Default - the last version that exist. + delta : If True, we optimize on creation of the new dataset versions + by calculating diff between source and the last version and applying + all needed modifications (mappers, filters etc.) only on that diff. + At the end, we merge modified diff with last version of dataset to + create new version. """ schema = self.signals_schema.clone_without_sys_signals().serialize() + if delta and name: + delta_ds = delta_update(self, name) + if delta_ds: + return self._evolve( + query=delta_ds._query.save( + name=name, version=version, feature_schema=schema, **kwargs + ) + ) return self._evolve( query=self._query.save( name=name, version=version, feature_schema=schema, **kwargs @@ -1621,6 +1648,7 @@ def compare( modified: bool = True, same: bool = False, status_col: Optional[str] = None, + sys: Optional[bool] = False, ) -> "DataChain": """Comparing two chains by identifying rows that are added, deleted, modified or same. Result is the new chain that has additional column with possible @@ -1653,6 +1681,7 @@ def compare( same (bool): Whether to return unchanged rows in resulting chain. status_col (str): Name of the new column that is created in resulting chain representing diff status. + sys (bool): Whether to have sys columns in returned diff chain or not. Example: ```py @@ -1683,6 +1712,7 @@ def compare( modified=modified, same=same, status_col=status_col, + sys=sys, ) def diff( @@ -1695,6 +1725,7 @@ def diff( deleted: bool = False, same: bool = False, status_col: Optional[str] = None, + sys: Optional[bool] = False, ) -> "DataChain": """Similar to `.compare()`, which is more generic method to calculate difference between two chains. Unlike `.compare()`, this method works only on those chains @@ -1717,6 +1748,7 @@ def diff( same (bool): Whether to return unchanged rows in resulting chain. status_col (str): Optional name of the new column that is created in resulting chain representing diff status. + sys (bool): Whether to have sys columns in returned diff chain or not. 
 
         Example:
             ```py
@@ -1756,6 +1788,7 @@ def get_file_signals(file: str, signals):
             modified=modified,
             same=same,
             status_col=status_col,
+            sys=sys,
         )
 
     @classmethod
diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py
index f3b0875ad..c32bf10d7 100644
--- a/src/datachain/lib/signal_schema.py
+++ b/src/datachain/lib/signal_schema.py
@@ -454,14 +454,13 @@ def row_to_objs(self, row: Sequence[Any]) -> list[DataValue]:
             pos += 1
         return objs
 
-    def contains_file(self) -> bool:
-        for type_ in self.values.values():
-            if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass(
+    def get_file_signal(self) -> Optional[str]:
+        for signal_name, signal_type in self.values.items():
+            if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
                 fr, File
             ):
-                return True
-
-        return False
+                return signal_name
+        return None
 
     def slice(
         self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
@@ -647,6 +646,13 @@ def merge(
 
         return SignalSchema(self.values | schema_right)
 
+    def append(self, right: "SignalSchema") -> "SignalSchema":
+        missing_schema = {
+            key: right.values[key]
+            for key in [k for k in right.values if k not in self.values]
+        }
+        return SignalSchema(self.values | missing_schema)
+
     def get_signals(self, target_type: type[DataModel]) -> Iterator[str]:
         for path, type_, has_subtree, _ in self.get_flat_tree():
             if has_subtree and issubclass(type_, target_type):
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index 3b0eb420e..0093bcd0d 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -153,13 +153,6 @@ def step_result(
     )
 
 
-class StartingStep(ABC):
-    """An initial query processing step, referencing a data source."""
-
-    @abstractmethod
-    def apply(self) -> "StepResult": ...
-
-
 @frozen
 class Step(ABC):
     """A query processing step (filtering, mutation, etc.)"""
@@ -172,12 +165,14 @@ def apply(
 
 
 @frozen
-class QueryStep(StartingStep):
+class QueryStep:
+    """A query step that returns all rows from a specific dataset version"""
+
     catalog: "Catalog"
     dataset_name: str
     dataset_version: int
 
-    def apply(self):
+    def apply(self) -> "StepResult":
         def q(*columns):
             return sqlalchemy.select(*columns)
@@ -1095,7 +1090,7 @@ def __init__(
         self.temp_table_names: list[str] = []
         self.dependencies: set[DatasetDependencyType] = set()
         self.table = self.get_table()
-        self.starting_step: StartingStep
+        self.starting_step: QueryStep
         self.name: Optional[str] = None
         self.version: Optional[int] = None
         self.feature_schema: Optional[dict] = None
diff --git a/tests/func/test_delta.py b/tests/func/test_delta.py
new file mode 100644
index 000000000..96cd9792e
--- /dev/null
+++ b/tests/func/test_delta.py
@@ -0,0 +1,226 @@
+import os
+
+import pytest
+import regex as re
+from PIL import Image
+
+from datachain import func
+from datachain.lib.dc import C, DataChain
+from datachain.lib.file import File, ImageFile
+
+
+def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path):
+    starting_ds_name = "starting_ds"
+    ds_name = "delta_ds"
+
+    images = [
+        {"name": "img1.jpg", "data": Image.new(mode="RGB", size=(64, 64))},
+        {"name": "img2.jpg", "data": Image.new(mode="RGB", size=(128, 128))},
+        {"name": "img3.jpg", "data": Image.new(mode="RGB", size=(64, 64))},
+        {"name": "img4.jpg", "data": Image.new(mode="RGB", size=(128, 128))},
+    ]
+
+    def create_image_dataset(ds_name, images):
+        DataChain.from_values(
+            file=[
+                ImageFile(path=img["name"], source=f"file://{tmp_path}")
+                for img in images
+            ],
+            session=test_session,
+        ).save(ds_name)
+
+    def create_delta_dataset(ds_name):
+        DataChain.from_dataset(
+            starting_ds_name,
+            session=test_session,
+        ).save(ds_name, delta=True)
+
+    # first version of the starting dataset
+    create_image_dataset(starting_ds_name, images[:2])
+    # first version of the delta dataset
+    create_delta_dataset(ds_name)
+    # second version of the starting dataset
+    create_image_dataset(starting_ds_name, images[2:])
+    # second version of the delta dataset
+    create_delta_dataset(ds_name)
+
+    assert list(
+        DataChain.from_dataset(ds_name, version=1)
+        .order_by("file.path")
+        .collect("file.path")
+    ) == [
+        "img1.jpg",
+        "img2.jpg",
+    ]
+
+    assert list(
+        DataChain.from_dataset(ds_name, version=2)
+        .order_by("file.path")
+        .collect("file.path")
+    ) == [
+        "img1.jpg",
+        "img2.jpg",
+        "img3.jpg",
+        "img4.jpg",
+    ]
+
+
+def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
+    ds_name = "delta_ds"
+    path = tmp_dir.as_uri()
+    tmp_dir = tmp_dir / "images"
+    os.mkdir(tmp_dir)
+
+    images = [
+        {
+            "name": f"img{i}.{'jpg' if i % 2 == 0 else 'png'}",
+            "data": Image.new(mode="RGB", size=((i + 1) * 10, (i + 1) * 10)),
+        }
+        for i in range(20)
+    ]
+
+    # save only half of the images for now
+    for img in images[:10]:
+        img["data"].save(tmp_dir / img["name"])
+
+    def create_delta_dataset():
+        def my_embedding(file: File) -> list[float]:
+            return [0.5, 0.5]
+
+        def get_index(file: File) -> int:
+            r = r".+\/img(\d+)\.jpg"
+            return int(re.search(r, file.path).group(1))  # type: ignore[union-attr]
+
+        (
+            DataChain.from_storage(path, update=True, session=test_session)
+            .filter(C("file.path").glob("*.jpg"))
+            .map(emb=my_embedding)
+            .mutate(dist=func.cosine_distance("emb", (0.1, 0.2)))
+            .map(index=get_index)
+            .filter(C("index") > 3)
+            .save(ds_name, delta=True)
+        )
+
+    # first version of the delta dataset
+    create_delta_dataset()
+
+    # remember old etags for later comparison to prove modified images are also
+    # taken into consideration on delta update
+    etags = {
+        r[0]: r[1].etag
+        for r in DataChain.from_dataset(ds_name, version=1).collect("index", "file")
+    }
+
+    # remove the last couple of images to simulate modification since we will
+    # re-create them
+    for img in images[5:10]:
+        os.remove(tmp_dir / img["name"])
+
+    # save the other half of the images and the ones that were removed above
+    for img in images[5:]:
+        img["data"].save(tmp_dir / img["name"])
+
+    # second version of the delta dataset
+    create_delta_dataset()
+
+    assert list(
+        DataChain.from_dataset(ds_name, version=1)
+        .order_by("file.path")
+        .collect("file.path")
+    ) == [
+        "images/img4.jpg",
+        "images/img6.jpg",
+        "images/img8.jpg",
+    ]
+
+    assert list(
+        DataChain.from_dataset(ds_name, version=2)
+        .order_by("file.path")
+        .collect("file.path")
+    ) == [
+        "images/img10.jpg",
+        "images/img12.jpg",
+        "images/img14.jpg",
+        "images/img16.jpg",
+        "images/img18.jpg",
+        "images/img4.jpg",
+        "images/img6.jpg",
+        "images/img8.jpg",
+    ]
+
+    # check that we have the newest versions for modified rows: since etags are
+    # mtimes, the etags of modified rows should be greater than the old ones
+    assert (
+        next(
+            DataChain.from_dataset(ds_name, version=2)
+            .filter(C("index") == 6)
+            .order_by("file.path", "file.etag")
+            .collect("file.etag")
+        )
+        > etags[6]
+    )
+
+
+def test_delta_update_no_diff(test_session, tmp_dir, tmp_path):
+    ds_name = "delta_ds"
+    path = tmp_dir.as_uri()
+    tmp_dir = tmp_dir / "images"
+    os.mkdir(tmp_dir)
+
+    images = [
+        {"name": f"img{i}.jpg", "data": Image.new(mode="RGB", size=(64, 128))}
+        for i in range(10)
+    ]
+
+    for img in images:
+        img["data"].save(tmp_dir / img["name"])
+
+    def create_delta_dataset():
+        def get_index(file: File) -> int:
+            r = r".+\/img(\d+)\.jpg"
+            return int(re.search(r, file.path).group(1))  # type: ignore[union-attr]
+
+        (
+            DataChain.from_storage(path, update=True, session=test_session)
+            .filter(C("file.path").glob("*.jpg"))
+            .map(index=get_index)
+            .filter(C("index") > 5)
+            .save(ds_name, delta=True)
+        )
+
+    create_delta_dataset()
+    create_delta_dataset()
+
+    assert (
+        list(
+            DataChain.from_dataset(ds_name, version=1)
+            .order_by("file.path")
+            .collect("file.path")
+        )
+        == list(
+            DataChain.from_dataset(ds_name, version=2)
+            .order_by("file.path")
+            .collect("file.path")
+        )
+        == [
+            "images/img6.jpg",
+            "images/img7.jpg",
+            "images/img8.jpg",
+            "images/img9.jpg",
+        ]
+    )
+
+
+def test_delta_update_no_file_signals(test_session):
+    starting_ds_name = "starting_ds"
+
+    DataChain.from_values(num=[10, 20], session=test_session).save(starting_ds_name)
+
+    with pytest.raises(ValueError) as excinfo:
+        DataChain.from_dataset(
+            starting_ds_name,
+            session=test_session,
+        ).save("delta_ds", delta=True)
+
+    assert (
+        str(excinfo.value) == "Datasets without file signals cannot have delta updates"
+    )
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index d5b442edd..3a10e8616 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -2998,3 +2998,21 @@
         ),
     ):
         dc.mutate(first=func.sum("col2").over(window))
+
+
+def test_append_steps(test_session):
+    keys = ["a", "b", "c", "d"]
+    values = [1, 2, 3, 4]
+
+    DataChain.from_values(key=keys, val=values, session=test_session).save("ds")
+
+    ds1 = (
+        DataChain.from_dataset("ds", session=test_session)
+        .filter(C("val") > 2)
+        .mutate(double=C("val") * 2)
+    )
+
DataChain.from_dataset("ds", session=test_session).append_steps(ds1) + + assert list(ds2.order_by("val").collect("val")) == [3, 4] + assert list(ds2.order_by("val").collect("double")) == [6, 8] diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 2890653b2..d03f52807 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -1132,3 +1132,14 @@ class Custom(DataModel): "f": "FilePartial1@v1", "custom": "CustomPartial1@v1", } + + +def test_get_file_signal(): + assert SignalSchema({"name": str, "f": File}).get_file_signal() == "f" + assert SignalSchema({"name": str}).get_file_signal() is None + + +def test_append(): + s1 = SignalSchema({"name": str, "f": File}) + s2 = SignalSchema({"name": str, "f": File, "age": int}) + assert s1.append(s2).values == {"name": str, "f": File, "age": int}