Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lavender_data/server/background_worker/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _worker_process(
result = _tasks[work_item.func](**work_item.kwargs)
result_item = ResultItem(work_id=work_item.work_id, result=result)
except Exception as e:
logger.exception(f"Error processing work {work_item.work_id}: {e}")
# logger.exception(f"Error processing work {work_item.work_id}: {e}")
result_item = ResultItem(
work_id=work_item.work_id,
exception="".join(
Expand Down
1 change: 1 addition & 0 deletions lavender_data/server/cache/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def lrem(self, name: str, count: int, value: str) -> int: ...


class CacheInterface(CacheOperations):
@contextmanager
@abstractmethod
def lock(self, key: str, timeout: Optional[int] = None) -> Iterator[None]: ...

Expand Down
22 changes: 14 additions & 8 deletions lavender_data/server/dataset/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from lavender_data.server.shardset import get_main_shardset, span
from lavender_data.storage import get_url
from lavender_data.serialize import serialize_list
from lavender_data.serialize import serialize_list, deserialize_item
from lavender_data.logging import get_logger

try:
Expand Down Expand Up @@ -150,14 +150,20 @@ def _set_file(content: bytes):

def refine_value_previewable(value: Any):
if type(value) == bytes:
if len(value) > 0:
try:
local_path = _set_file(value)
return f"file://{local_path}"
except ValueError:
return f"<bytes>"
else:
if len(value) == 0:
return ""

try:
return f"file://{_set_file(value)}"
except ValueError:
pass

try:
return refine_value_previewable(deserialize_item(value))
except Exception:
pass

return f"<bytes>"
elif type(value) == dict:
if value.get("bytes"):
try:
Expand Down
32 changes: 27 additions & 5 deletions lavender_data/server/iteration/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from lavender_data.server.reader import (
get_reader_instance,
GlobalSampleIndex,
JoinMethod,
InnerJoinSampleInsufficient,
)
from lavender_data.server.registries import (
PreprocessorRegistry,
Expand Down Expand Up @@ -107,7 +109,13 @@ def _decollate(batch: dict) -> dict:
return _batch


def _process_next_samples(params: ProcessNextSamplesParams) -> dict:
class NoSamplesFound(Exception):
pass


def _process_next_samples(
params: ProcessNextSamplesParams, join_method: JoinMethod = "left"
) -> dict:
reader = get_reader_instance()

current = params.current
Expand All @@ -118,7 +126,15 @@ def _process_next_samples(params: ProcessNextSamplesParams) -> dict:
batch_size = params.batch_size

if samples is None:
samples = [reader.get_sample(i, join="left") for i in global_sample_indices]
samples = []
for i in global_sample_indices:
try:
samples.append(reader.get_sample(i, join_method))
except InnerJoinSampleInsufficient:
pass

if len(samples) == 0:
raise NoSamplesFound()

batch = (
CollaterRegistry.get(collater["name"]).collate(samples)
Expand Down Expand Up @@ -146,18 +162,24 @@ def _process_next_samples(params: ProcessNextSamplesParams) -> dict:
def process_next_samples(
params: ProcessNextSamplesParams,
max_retry_count: int,
join_method: JoinMethod = "left",
) -> dict:
logger = get_logger(__name__)

for i in range(max_retry_count + 1):
try:
return _process_next_samples(params)
return _process_next_samples(params, join_method)
except NoSamplesFound as e:
raise ProcessNextSamplesException(
e=e,
current=params.current,
global_sample_indices=params.global_sample_indices,
)
except Exception as e:
error = ProcessNextSamplesException(
e=e,
current=params.current,
global_sample_indices=params.global_sample_indices,
)
logger = get_logger(__name__)
if i < max_retry_count:
logger.warning(f"{str(error)}, retrying... ({i+1}/{max_retry_count})")
else:
Expand Down
13 changes: 8 additions & 5 deletions lavender_data/server/routes/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel

from lavender_data.logging import get_logger
from lavender_data.shard.readers import Reader
from lavender_data.server.db import DbSession
from lavender_data.server.db.models import (
Dataset,
Expand Down Expand Up @@ -42,7 +43,6 @@
preprocess_shardset,
)
from lavender_data.server.auth import AppAuth
from lavender_data.storage import list_files
from lavender_data.shard import inspect_shard
from lavender_data.serialize import deserialize_list

Expand Down Expand Up @@ -139,7 +139,11 @@ def get_dataset_preview(
raise HTTPException(status_code=404, detail="Dataset not found")

if cache.exists(f"preview:{dataset_id}:{preview_id}:error"):
error = cache.get(f"preview:{dataset_id}:{preview_id}:error").decode()
error = cache.get(f"preview:{dataset_id}:{preview_id}:error")
if error is None:
error = "Unknown error"
else:
error = error.decode()
cache.delete(f"preview:{dataset_id}:{preview_id}:error")
raise HTTPException(
status_code=500,
Expand Down Expand Up @@ -235,8 +239,7 @@ def create_dataset(
cluster=cluster,
)
except:
if cluster:
cluster.sync_changes([dataset], delete=True)
delete_dataset(dataset.id, session, cluster)
raise

return dataset
Expand Down Expand Up @@ -390,7 +393,7 @@ def create_shardset(
uid_column = None

try:
shard_basenames = sorted(list_files(params.location, limit=1))
shard_basenames = Reader.list_readables(params.location)
except Exception as e:
shard_basenames = []
logger.warning(f"Failed to list shardset location: {e}")
Expand Down
22 changes: 14 additions & 8 deletions lavender_data/server/shardset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from sqlmodel import select
from sqlalchemy.orm import selectinload

from lavender_data.shard.readers import Reader
from lavender_data.logging import get_logger
from lavender_data.storage import list_files, upload_file
from lavender_data.storage import upload_file
from lavender_data.server.reader import GlobalSampleIndex, MainShardInfo, ShardInfo
from lavender_data.server.iteration import (
process_next_samples,
Expand Down Expand Up @@ -102,17 +103,16 @@ def preprocess_shardset(
existing_shard_basenames = []
if not overwrite:
try:
existing_shard_basenames = [
basename
for basename in sorted(list_files(shardset_location))
if basename.endswith(".parquet") or basename.endswith(".csv")
]
existing_shard_basenames = Reader.list_readables(shardset_location)
except Exception as e:
pass

for main_shard in main_shardset.shards:
shard_basename = f"shard.{main_shard.index:05d}.parquet"
for main_shard in sorted(main_shardset.shards, key=lambda x: x.index):
shard_basename = os.path.basename(main_shard.location)
filename, extension = os.path.splitext(shard_basename)
shard_basename = f"{filename}.parquet"
location = os.path.join(shardset_location, shard_basename)


if shard_basename in existing_shard_basenames:
logger.info(
Expand Down Expand Up @@ -187,6 +187,7 @@ def preprocess_shardset(
process_next_samples,
params=params,
max_retry_count=max_retry_count,
join_method="inner",
)
)

Expand All @@ -196,9 +197,14 @@ def preprocess_shardset(
try:
batch = process_pool.result(work_id)
except Exception as e:
if "NoSamplesFound" in str(e):
continue
logger.error(e)
continue

if batch is None:
continue

keys = list(batch.keys())
for key in keys:
if key not in _export_columns:
Expand Down
4 changes: 2 additions & 2 deletions lavender_data/server/shardset/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from sqlmodel import update, insert, select, delete

from lavender_data.logging import get_logger
from lavender_data.storage import list_files
from lavender_data.shard.inspect import OrphanShardInfo, inspect_shard
from lavender_data.shard.readers import Reader
from lavender_data.shard.readers.exceptions import ReaderException
from lavender_data.server.background_worker import (
TaskStatus,
Expand Down Expand Up @@ -43,7 +43,7 @@ def inspect_shardset_location(
try:
shard_index = 0

shard_basenames = sorted(list_files(shardset_location))
shard_basenames = Reader.list_readables(os.path.join(shardset_location))

shard_locations: list[str] = []
for shard_basename in shard_basenames:
Expand Down
18 changes: 17 additions & 1 deletion lavender_data/shard/readers/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Iterator, Optional, Union
from typing_extensions import Self

from lavender_data.storage import download_file
from lavender_data.storage import download_file, list_files
from lavender_data.logging import get_logger

from .exceptions import (
Expand All @@ -19,6 +19,22 @@
class Reader(ABC):
format: str = ""

@classmethod
def is_readable(cls, location: str) -> bool:
shard_format = os.path.splitext(location)[1].lstrip(".")
for subcls in cls._reader_classes():
if shard_format == subcls.format:
return True
return False

@classmethod
def list_readables(cls, location: str) -> list[str]:
return [
basename
for basename in sorted(list_files(location))
if cls.is_readable(os.path.join(location, basename))
]

@classmethod
def get(
cls,
Expand Down
14 changes: 7 additions & 7 deletions lavender_data/shard/readers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,37 @@ class CsvReader(UntypedReader):
format = "csv"
typed_columns = False

def resolve_type(self, value: Any, typestr: str) -> type:
def resolve_type(self, value: Any, typestr: str) -> Any:
if typestr in ["int", "int32", "int64"]:
if value == "":
if value == "" or value is None:
return np.nan
return int(value)
elif typestr in ["float", "double"]:
if value == "":
if value == "" or value is None:
return np.nan
return float(value)
elif typestr in ["string", "text", "str"]:
return str(value)
elif typestr in ["bool", "boolean"]:
return value.lower() in ["true", "t", "yes", "y", "1"]
elif typestr in ["list"]:
if value == "":
if value == "" or value is None:
return []
return ast.literal_eval(value)
elif typestr in ["map"]:
if value == "":
if value == "" or value is None:
return {}
return ast.literal_eval(value)
elif typestr in ["binary"]:
if value == "":
if value == "" or value is None:
return b""
return ast.literal_eval(value)
return value

def read_columns(self) -> dict[str, str]:
with open(self.filepath, "r") as f:
reader = csv.DictReader(f)
return {name: "string" for name in reader.fieldnames}
return {name: "string" for name in reader.fieldnames or []}

def read_samples(self) -> list[dict[str, Any]]:
samples = []
Expand Down
7 changes: 6 additions & 1 deletion lavender_data/storage/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def __init__(self):
def _parse_remote_path(self, remote_path: str) -> tuple[str, str]:
parsed = urllib.parse.urlparse(remote_path)
org = parsed.netloc
repo, path = parsed.path.lstrip("/").split("/", 1)
splitted = parsed.path.lstrip("/").split("/", 1)
repo = splitted[0]
if len(splitted) > 1:
path = splitted[1]
else:
path = ""
repo_id = f"{org}/{repo}"
return repo_id, path

Expand Down
2 changes: 1 addition & 1 deletion lavender_data/ui/.next/BUILD_ID
Original file line number Diff line number Diff line change
@@ -1 +1 @@
vKMYDews__Kl5CqTBPgsk
ASKelLLlcph1PGS2X65FV
Loading