52 changes: 33 additions & 19 deletions src/datasets/packaged_modules/webdataset/webdataset.py
@@ -1,5 +1,6 @@
 import io
 import json
+import re
 from itertools import islice
 from typing import Any, Callable, Dict, List

@@ -28,25 +29,26 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
         fs: fsspec.AbstractFileSystem = fsspec.filesystem("memory")
         streaming_download_manager = datasets.StreamingDownloadManager()
         for filename, f in tar_iterator:
-            if "." in filename:
-                example_key, field_name = filename.split(".", 1)
-                if current_example and current_example["__key__"] != example_key:
-                    yield current_example
-                    current_example = {}
-                current_example["__key__"] = example_key
-                current_example["__url__"] = tar_path
-                current_example[field_name.lower()] = f.read()
-                if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
-                    fs.write_bytes(filename, current_example[field_name.lower()])
-                    extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
-                    with fsspec.open(extracted_file_path) as f:
-                        current_example[field_name.lower()] = f.read()
-                    fs.delete(filename)
-                    data_extension = xbasename(extracted_file_path).split(".")[-1]
-                else:
-                    data_extension = field_name.split(".")[-1]
-                if data_extension in cls.DECODERS:
-                    current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
+            example_key, field_name = base_plus_ext(filename)
+            if example_key is None:
+                continue
+            if current_example and current_example["__key__"] != example_key:
+                yield current_example
+                current_example = {}
+            current_example["__key__"] = example_key
+            current_example["__url__"] = tar_path
+            current_example[field_name.lower()] = f.read()
+            if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
+                fs.write_bytes(filename, current_example[field_name.lower()])
+                extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
+                with fsspec.open(extracted_file_path) as f:
+                    current_example[field_name.lower()] = f.read()
+                fs.delete(filename)
+                data_extension = xbasename(extracted_file_path).split(".")[-1]
+            else:
+                data_extension = field_name.split(".")[-1]
+            if data_extension in cls.DECODERS:
+                current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
         if current_example:
             yield current_example
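
Note (not part of the diff): the rewritten loop above groups consecutive tar members that share a key into a single example dict, yields it as soon as a member with a different key appears, and skips members that `base_plus_ext` cannot split. A minimal standalone sketch of that behavior, with a made-up `group` helper and sample member list, assuming the same regex as the helper vendored further down:

```python
import re


def base_plus_ext(path):
    # Same regex as the helper added further down in this PR.
    m = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    return (m.group(1), m.group(2)) if m else (None, None)


def group(members):
    """Group (filename, bytes) pairs from one tar shard into example dicts."""
    current = {}
    for name, data in members:
        key, field = base_plus_ext(name)
        if key is None:  # no extension in the basename -> skip the member
            continue
        if current and current["__key__"] != key:
            yield current  # key changed: the previous example is complete
            current = {}
        current["__key__"] = key
        current[field.lower()] = data
    if current:
        yield current


members = [
    ("0001.jpg", b"<jpeg bytes>"),
    ("0001.json", b'{"label": 0}'),
    ("0002.jpg", b"<jpeg bytes>"),
    ("0002.json", b'{"label": 1}'),
]
print(list(group(members)))
# -> two dicts, keyed "0001" and "0002", each with "jpg" and "json" fields
```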

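The branch for compressed members (unchanged by this PR except for its indentation) round-trips the bytes through the in-memory fsspec filesystem so that `StreamingDownloadManager.extract` can decompress them before decoding. Ignoring that plumbing, the effect for a gzip-compressed field is roughly the following sketch (standard-library gzip only, member name made up):

```python
import gzip
import json

# Hypothetical tar member: a gzip-compressed JSON field for example key "0001".
filename = "0001.json.gz"
data = gzip.compress(b'{"label": 0}')

field_name = "json.gz"
if field_name.split(".")[-1] == "gz":  # single-file compression extension
    data = gzip.decompress(data)       # what the extract() round-trip achieves
    data_extension = "json"            # extension of the decompressed payload
else:
    data_extension = field_name.split(".")[-1]
decoded = json.loads(data) if data_extension == "json" else data
print(decoded)  # {'label': 0}
```
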
@@ -121,6 +123,18 @@ def _generate_examples(self, tar_paths, tar_iterators):
                 yield f"{tar_idx}_{example_idx}", example


+# Source: https://github.com/webdataset/webdataset/blob/87bd5aa41602d57f070f65a670893ee625702f2f/webdataset/tariterators.py#L25
+def base_plus_ext(path):
+    """Split off all file extensions.
+
+    Returns base, allext.
+    """
+    match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
+    if not match:
+        return None, None
+    return match.group(1), match.group(2)
+
+
 # Obtained with:
 # ```
 # import PIL.Image
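
Usage note (illustrative, not part of the diff): `base_plus_ext` keeps any directory prefix plus the dot-free start of the basename as the example key and returns everything after the basename's first dot as the extension; members it cannot split are skipped by the `example_key is None` check above. With the previous `filename.split(".", 1)`, a member such as `shard.v2/0001.jpg` would have been split at the first dot of the whole path, producing the key `shard` instead.

```python
>>> base_plus_ext("images/0001.cls.txt")
('images/0001', 'cls.txt')
>>> base_plus_ext("shard.v2/0001.jpg")  # dots in directory names no longer break the key
('shard.v2/0001', 'jpg')
>>> base_plus_ext("images/metadata")    # no extension in the basename
(None, None)
```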