Incremental (delta) update #928

Open · wants to merge 23 commits into main · Changes from 15 commits
39 changes: 39 additions & 0 deletions src/datachain/delta.py
@@ -0,0 +1,39 @@
from typing import TYPE_CHECKING, Optional

from datachain.error import DatasetNotFoundError

if TYPE_CHECKING:
    from datachain.lib.dc import DataChain


def delta_update(dc: "DataChain", name: str) -> Optional["DataChain"]:
"""
Creates new chain that consists of the last version of current delta dataset
plus diff from the source with all needed modifications.
This way we don't need to re-calculate the whole chain from the source again(
apply all the DataChain methods like filters, mappers, generators etc.)
but just the diff part which is very important for performance.
"""
from datachain.lib.dc import DataChain

file_signal = dc.signals_schema.get_file_signal()
if not file_signal:
raise ValueError("Datasets without file signals cannot have delta updates")
try:
latest_version = dc.session.catalog.get_dataset(name).latest_version
except DatasetNotFoundError:
# first creation of delta update dataset
return None

source_ds_name = dc._query.starting_step.dataset_name
source_ds_version = dc._query.starting_step.dataset_version
diff = DataChain.from_dataset(source_ds_name, version=source_ds_version).diff(
DataChain.from_dataset(name, version=latest_version), on=file_signal
)
# we append all the steps from the original chain to diff,
# e.g filters, mappers, generators etc. With this we make sure we add all
# needed modifications to diff part as well
diff._query.steps += dc._query.steps

# merging diff and the latest version of our dataset
return diff.union(DataChain.from_dataset(name, latest_version))
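
For readers following the diff: below is a hedged sketch of the chain that delta_update effectively builds when a previous version of the delta dataset already exists. The dataset names ("source_ds", "delta_ds"), the version numbers, and the "file" signal are illustrative assumptions borrowed from the tests further down, not part of this module.

from datachain.lib.dc import DataChain

# Rows that are new or modified in the source since "delta_ds" v1 was created,
# matched on the top-level file signal.
diff = DataChain.from_dataset("source_ds", version=2).diff(
    DataChain.from_dataset("delta_ds", version=1), on="file"
)

# delta_update then re-attaches the original chain's steps (filters, mappers,
# generators, ...) to this diff via `diff._query.steps += dc._query.steps`,
# so only the diff is re-processed.

# Finally, the processed diff is unioned with the previous version to form the
# data for the next version of "delta_ds".
new_version_data = diff.union(DataChain.from_dataset("delta_ds", version=1))
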
20 changes: 19 additions & 1 deletion src/datachain/lib/dc.py
@@ -25,6 +25,7 @@
from sqlalchemy.sql.sqltypes import NullType

from datachain.dataset import DatasetRecord
from datachain.delta import delta_update
from datachain.func import literal
from datachain.func.base import Function
from datachain.func.func import Func
@@ -753,16 +754,33 @@ def listings(
        )

    def save(  # type: ignore[override]
-        self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
+        self,
+        name: Optional[str] = None,
+        version: Optional[int] = None,
+        delta: Optional[bool] = False,
+        **kwargs,
    ) -> "Self":
        """Save to a Dataset. It returns the chain itself.

        Parameters:
            name : dataset name. Empty name saves to a temporary dataset that will be
                removed after the process ends. Temp datasets are useful for
                optimization.
            version : version of a dataset. Default - the latest version that exists.
+            delta : If True, creation of new dataset versions is optimized by
+                calculating the diff between the source and the last version and
+                applying all needed modifications (mappers, filters etc.) only to
+                that diff. At the end, the modified diff is merged with the last
+                version of the dataset to create the new version.
        """
        schema = self.signals_schema.clone_without_sys_signals().serialize()
+        if delta and name:
+            delta_ds = delta_update(self, name)
+            if delta_ds:
+                return self._evolve(
+                    query=delta_ds._query.save(
+                        name=name, version=version, feature_schema=schema, **kwargs
+                    )
+                )
        return self._evolve(
            query=self._query.save(
                name=name, version=version, feature_schema=schema, **kwargs
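
A minimal end-to-end sketch of the new delta flag, assuming a local directory of JPEG files and a dataset name "my_ds" (both hypothetical; the call pattern mirrors the functional tests added below):

from datachain.lib.dc import C, DataChain


def build(path: str) -> None:
    (
        DataChain.from_storage(path, update=True)
        .filter(C("file.path").glob("*.jpg"))
        .save("my_ds", delta=True)
    )


# First run: no previous version of "my_ds" exists yet, so the whole chain is
# computed and saved as version 1.
build("data/images")

# Later runs: only files added or modified since version 1 go through the
# filter again; the processed diff is unioned with version 1 to produce version 2.
build("data/images")
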
11 changes: 5 additions & 6 deletions src/datachain/lib/signal_schema.py
@@ -410,14 +410,13 @@ def row_to_objs(self, row: Sequence[Any]) -> list[DataValue]:
            pos += 1
        return objs

-    def contains_file(self) -> bool:
-        for type_ in self.values.values():
-            if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass(
+    def get_file_signal(self) -> Optional[str]:
+        for signal_name, signal_type in self.values.items():
+            if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
                fr, File
            ):
-                return True
-
-        return False
+                return signal_name
+        return None

    def slice(
        self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
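
A small usage sketch (hypothetical schema; it mirrors the unit test added at the end of this diff). The returned signal name is what delta_update passes as the on column to diff:

from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

# Returns the name of the first top-level File signal, here "file".
assert SignalSchema({"name": str, "file": File}).get_file_signal() == "file"

# No top-level File signal -> None; save(..., delta=True) then raises ValueError.
assert SignalSchema({"name": str}).get_file_signal() is None
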
15 changes: 5 additions & 10 deletions src/datachain/query/dataset.py
@@ -153,13 +153,6 @@ def step_result(
    )


-class StartingStep(ABC):
-    """An initial query processing step, referencing a data source."""
-
-    @abstractmethod
-    def apply(self) -> "StepResult": ...
-
-
@frozen
class Step(ABC):
    """A query processing step (filtering, mutation, etc.)"""
@@ -172,12 +165,14 @@ def apply(


@frozen
-class QueryStep(StartingStep):
+class QueryStep:
+    """A query that returns all rows from specific dataset version"""
+
    catalog: "Catalog"
    dataset_name: str
    dataset_version: int

-    def apply(self):
+    def apply(self) -> "StepResult":
        def q(*columns):
            return sqlalchemy.select(*columns)

@@ -1095,7 +1090,7 @@ def __init__(
        self.temp_table_names: list[str] = []
        self.dependencies: set[DatasetDependencyType] = set()
        self.table = self.get_table()
-        self.starting_step: StartingStep
+        self.starting_step: QueryStep
        self.name: Optional[str] = None
        self.version: Optional[int] = None
        self.feature_schema: Optional[dict] = None
226 changes: 226 additions & 0 deletions tests/func/test_delta.py
@@ -0,0 +1,226 @@
import os

import pytest
import regex as re
from PIL import Image

from datachain import func
from datachain.lib.dc import C, DataChain
from datachain.lib.file import File, ImageFile


def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path):
    starting_ds_name = "starting_ds"
    ds_name = "delta_ds"

    images = [
        {"name": "img1.jpg", "data": Image.new(mode="RGB", size=(64, 64))},
        {"name": "img2.jpg", "data": Image.new(mode="RGB", size=(128, 128))},
        {"name": "img3.jpg", "data": Image.new(mode="RGB", size=(64, 64))},
        {"name": "img4.jpg", "data": Image.new(mode="RGB", size=(128, 128))},
    ]

    def create_image_dataset(ds_name, images):
        DataChain.from_values(
            file=[
                ImageFile(path=img["name"], source=f"file://{tmp_path}")
                for img in images
            ],
            session=test_session,
        ).save(ds_name)

    def create_delta_dataset(ds_name):
        DataChain.from_dataset(
            starting_ds_name,
            session=test_session,
        ).save(ds_name, delta=True)

    # first version of starting dataset
    create_image_dataset(starting_ds_name, images[:2])
    # first version of delta dataset
    create_delta_dataset(ds_name)
    # second version of starting dataset
    create_image_dataset(starting_ds_name, images[2:])
    # second version of delta dataset
    create_delta_dataset(ds_name)

    assert list(
        DataChain.from_dataset(ds_name, version=1)
        .order_by("file.path")
        .collect("file.path")
    ) == [
        "img1.jpg",
        "img2.jpg",
    ]

    assert list(
        DataChain.from_dataset(ds_name, version=2)
        .order_by("file.path")
        .collect("file.path")
    ) == [
        "img1.jpg",
        "img2.jpg",
        "img3.jpg",
        "img4.jpg",
    ]


def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
    ds_name = "delta_ds"
    path = tmp_dir.as_uri()
    tmp_dir = tmp_dir / "images"
    os.mkdir(tmp_dir)

    images = [
        {
            "name": f"img{i}.{'jpg' if i % 2 == 0 else 'png'}",
            "data": Image.new(mode="RGB", size=((i + 1) * 10, (i + 1) * 10)),
        }
        for i in range(20)
    ]

    # save only half of the images for now
    for img in images[:10]:
        img["data"].save(tmp_dir / img["name"])

    def create_delta_dataset():
        def my_embedding(file: File) -> list[float]:
            return [0.5, 0.5]

        def get_index(file: File) -> int:
            r = r".+\/img(\d+)\.jpg"
            return int(re.search(r, file.path).group(1))  # type: ignore[union-attr]

        (
            DataChain.from_storage(path, update=True, session=test_session)
            .filter(C("file.path").glob("*.jpg"))
            .map(emb=my_embedding)
            .mutate(dist=func.cosine_distance("emb", (0.1, 0.2)))
            .map(index=get_index)
            .filter(C("index") > 3)
            .save(ds_name, delta=True)
        )

    # first version of delta dataset
    create_delta_dataset()

    # remember old etags for later comparison, to prove that modified images are
    # also taken into consideration on delta update
    etags = {
        r[0]: r[1].etag
        for r in DataChain.from_dataset(ds_name, version=1).collect("index", "file")
    }

    # remove the last couple of images to simulate modification, since we will
    # re-create them below
    for img in images[5:10]:
        os.remove(tmp_dir / img["name"])

    # save the other half of the images and the ones that were removed above
    for img in images[5:]:
        img["data"].save(tmp_dir / img["name"])

    # second version of delta dataset
    create_delta_dataset()

    assert list(
        DataChain.from_dataset(ds_name, version=1)
        .order_by("file.path")
        .collect("file.path")
    ) == [
        "images/img4.jpg",
        "images/img6.jpg",
        "images/img8.jpg",
    ]

    assert list(
        DataChain.from_dataset(ds_name, version=2)
        .order_by("file.path")
        .collect("file.path")
    ) == [
        "images/img10.jpg",
        "images/img12.jpg",
        "images/img14.jpg",
        "images/img16.jpg",
        "images/img18.jpg",
        "images/img4.jpg",
        "images/img6.jpg",
        "images/img6.jpg",
        "images/img8.jpg",
        "images/img8.jpg",
    ]

    # check that we have both the old and the new versions of the files that
    # were modified
    rows = list(
        DataChain.from_dataset(ds_name, version=2)
        .filter(C("index") == 6)
        .order_by("file.path", "file.etag")
        .collect("file")
    )
    assert rows[0].etag == etags[6]
    assert rows[1].etag > etags[6]  # the new etag is bigger as it's derived from mtime


def test_delta_update_no_diff(test_session, tmp_dir, tmp_path):
    ds_name = "delta_ds"
    path = tmp_dir.as_uri()
    tmp_dir = tmp_dir / "images"
    os.mkdir(tmp_dir)

    images = [
        {"name": f"img{i}.jpg", "data": Image.new(mode="RGB", size=(64, 128))}
        for i in range(10)
    ]

    for img in images:
        img["data"].save(tmp_dir / img["name"])

    def create_delta_dataset():
        def get_index(file: File) -> int:
            r = r".+\/img(\d+)\.jpg"
            return int(re.search(r, file.path).group(1))  # type: ignore[union-attr]

        (
            DataChain.from_storage(path, update=True, session=test_session)
            .filter(C("file.path").glob("*.jpg"))
            .map(index=get_index)
            .filter(C("index") > 5)
            .save(ds_name, delta=True)
        )

    create_delta_dataset()
    create_delta_dataset()

    assert (
        list(
            DataChain.from_dataset(ds_name, version=1)
            .order_by("file.path")
            .collect("file.path")
        )
        == list(
            DataChain.from_dataset(ds_name, version=2)
            .order_by("file.path")
            .collect("file.path")
        )
        == [
            "images/img6.jpg",
            "images/img7.jpg",
            "images/img8.jpg",
            "images/img9.jpg",
        ]
    )


def test_delta_update_no_file_signals(test_session):
    starting_ds_name = "starting_ds"

    DataChain.from_values(num=[10, 20], session=test_session).save(starting_ds_name)

    with pytest.raises(ValueError) as excinfo:
        DataChain.from_dataset(
            starting_ds_name,
            session=test_session,
        ).save("delta_ds", delta=True)

    assert (
        str(excinfo.value) == "Datasets without file signals cannot have delta updates"
    )
5 changes: 5 additions & 0 deletions tests/unit/lib/test_signal_schema.py
@@ -992,3 +992,8 @@ def test_column_types(column_type, signal_type):

    assert len(signals) == 1
    assert signals["val"] is signal_type


def test_get_file_signal():
    assert SignalSchema({"name": str, "f": File}).get_file_signal() == "f"
Contributor:
Should we also test for a nested file signal, or is that out of the scope of this method?

E.g.:

class CustomModel(DataModel):
    file: File
    foo: str
    bar: float

assert SignalSchema({"name": str, "custom": CustomModel}).get_file_signal() == "custom.file"

Contributor Author:
For now it only works for top-level file objects. In the future we can add nested support as well if needed.

    assert SignalSchema({"name": str}).get_file_signal() is None