Skip to content
Merged
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
torch
torchvision
lightning-utilities
filelock
filelock <3.24 # v3.24.0 removed lock file auto-delete on Windows, breaking our cleanup logic
numpy
boto3
requests
Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

from pkg_resources import parse_requirements
from setuptools import find_packages, setup

_PATH_ROOT = os.path.dirname(__file__)
Expand All @@ -19,8 +18,12 @@ def _load_py_module(fname, pkg="litdata"):
return py


about = _load_py_module("__about__.py")
requirements_module = _load_py_module("requirements.py")


def _load_requirements(path_dir: str = _PATH_ROOT, file_name: str = "requirements.txt") -> list:
reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
reqs = requirements_module._parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
return list(map(str, reqs))


Expand Down
292 changes: 267 additions & 25 deletions src/litdata/imports.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import importlib
import os
import warnings
from collections.abc import Callable
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, distribution
from importlib.metadata import version as _version
from importlib.util import find_spec
from typing import TypeVar
from types import ModuleType
from typing import Any, TypeVar

import pkg_resources
from packaging.requirements import Requirement
from packaging.version import InvalidVersion, Version
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")

try:
from importlib import metadata
except ImportError:
# Python < 3.8
import importlib_metadata as metadata # type: ignore


@lru_cache
def package_available(package_name: str) -> bool:
Expand Down Expand Up @@ -61,6 +66,30 @@ def module_available(module_path: str) -> bool:
return True


def compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
"""Compare package version with some requirements.

>>> compare_version("torch", operator.ge, "0.1")
True
>>> compare_version("does_not_exist", operator.ge, "0.0")
False

"""
try:
pkg = importlib.import_module(package)
except (ImportError, RuntimeError):
return False
try:
# Use importlib.metadata to infer version
pkg_version = Version(pkg.__version__) if hasattr(pkg, "__version__") else Version(_version(package))
except (TypeError, PackageNotFoundError):
# this is mocked by Sphinx, so it should return True to generate all summaries
return True
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))


class RequirementCache:
"""Boolean-like class to check for requirement and module availability.

Expand All @@ -80,42 +109,255 @@ class RequirementCache:
True
>>> bool(RequirementCache("unknown_package"))
False
>>> bool(RequirementCache(module="torch.utils"))
True
>>> bool(RequirementCache(module="unknown_package"))
False
>>> bool(RequirementCache(module="unknown.module.path"))
False

"""

def __init__(self, requirement: str, module: str | None = None) -> None:
def __init__(self, requirement: str | None = None, module: str | None = None) -> None:
if not (requirement or module):
raise ValueError("At least one arguments need to be set.")
self.requirement = requirement
self.module = module

def _check_requirement(self) -> None:
if hasattr(self, "available"):
return
if not self.requirement:
raise ValueError("Requirement name is required.")
try:
# first try the pkg_resources requirement
pkg_resources.require(self.requirement)
self.available = True
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
req = Requirement(self.requirement)
pkg_version = Version(_version(req.name))
self.available = req.specifier.contains(pkg_version, prereleases=True) and (
not req.extras or self._check_extras_available(req)
)
except (PackageNotFoundError, InvalidVersion) as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}.\n HINT: Try running `pip install -U {self.requirement!r}`"
requirement_contains_version_specifier = any(c in self.requirement for c in "=<>")
if not requirement_contains_version_specifier or self.module is not None:
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"

if self.available:
self.message = f"Requirement {self.requirement!r} met"
else:
req_include_version = any(c in self.requirement for c in "=<>")
if not req_include_version or self.module is not None:
module = self.requirement if self.module is None else self.module
# sometimes `pkg_resources.require()` fails but the module is importable
# Sometimes `importlib.metadata.version` fails but the module is importable
self.available = module_available(module)
if self.available:
self.message = f"Module {module!r} available"
self.message = (
f"Requirement {self.requirement!r} not met. HINT: Try running `pip install -U {self.requirement!r}`"
)

def _check_module(self) -> None:
if not self.module:
raise ValueError("Module name is required.")
self.available = module_available(self.module)
if self.available:
self.message = f"Module {self.module!r} available"
else:
self.message = f"Module not found: {self.module!r}. HINT: Try running `pip install -U {self.module}`"

def _check_available(self) -> None:
if hasattr(self, "available"):
return
if self.requirement:
self._check_requirement()
if getattr(self, "available", True) and self.module:
self._check_module()

def _check_extras_available(self, requirement: Requirement) -> bool:
if not requirement.extras:
return True

extra_requirements = self._get_extra_requirements(requirement)

if not extra_requirements:
# The specified extra is not found in the package metadata
return False

# Verify each extra requirement is installed
for extra_req in extra_requirements:
try:
extra_dist = distribution(extra_req.name)
extra_installed_version = Version(extra_dist.version)
if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version, prereleases=True):
return False
except importlib.metadata.PackageNotFoundError:
return False

return True

def _get_extra_requirements(self, requirement: Requirement) -> list[Requirement]:
dist = distribution(requirement.name)
# Get the required dependencies for the specified extras
extra_requirements = dist.metadata.get_all("Requires-Dist") or []
return [Requirement(r) for r in extra_requirements if any(extra in r for extra in requirement.extras)]

def __bool__(self) -> bool:
"""Format as bool."""
self._check_requirement()
self._check_available()
return self.available

def __str__(self) -> str:
"""Format as string."""
self._check_requirement()
self._check_available()
return self.message

def __repr__(self) -> str:
"""Format as string."""
return self.__str__()


class ModuleAvailableCache(RequirementCache):
"""Boolean-like class for check of module availability.

>>> ModuleAvailableCache("torch")
Module 'torch' available
>>> bool(ModuleAvailableCache("torch.utils"))
True
>>> bool(ModuleAvailableCache("unknown_package"))
False
>>> bool(ModuleAvailableCache("unknown.module.path"))
False

"""

def __init__(self, module: str) -> None:
warnings.warn(
"`ModuleAvailableCache` is a special case of `RequirementCache`."
" Please use `RequirementCache(module=...)` instead.",
DeprecationWarning,
stacklevel=4,
)
super().__init__(module=module)


def get_dependency_min_version_spec(package_name: str, dependency_name: str) -> str:
"""Return the minimum version specifier of a dependency of a package.

>>> get_dependency_min_version_spec("pytorch-lightning==1.8.0", "jsonargparse")
'>=4.12.0'

"""
dependencies = metadata.requires(package_name) or []
for dep in dependencies:
dependency = Requirement(dep)
if dependency.name == dependency_name:
spec = [str(s) for s in dependency.specifier if str(s)[0] == ">"]
return spec[0] if spec else ""
raise ValueError(
"This is an internal error. Please file a GitHub issue with the error message. Dependency "
f"{dependency_name!r} not found in package {package_name!r}."
)


class LazyModule(ModuleType):
"""Proxy module that lazily imports the underlying module the first time it is actually used.

Args:
module_name: the fully-qualified module name to import
callback: a callback function to call before importing the module

"""

def __init__(self, module_name: str, callback: Callable | None = None) -> None:
super().__init__(module_name)
self._module: Any = None
self._callback = callback

def __getattr__(self, item: str) -> Any:
"""Lazily import the underlying module and delegate attribute access to it."""
if self._module is None:
self._import_module()

return getattr(self._module, item)

def __dir__(self) -> list[str]:
"""Lazily import the underlying module and return its attributes for introspection (dir())."""
if self._module is None:
self._import_module()

return dir(self._module)

def _import_module(self) -> None:
# Execute callback, if any
if self._callback is not None:
self._callback()

# Actually import the module
self._module = importlib.import_module(self.__name__)

# Update this object's dict so that attribute references are efficient
# (__getattr__ is only called on lookups that fail)
self.__dict__.update(self._module.__dict__)


def lazy_import(module_name: str, callback: Callable | None = None) -> LazyModule:
"""Return a proxy module object that will lazily import the given module the first time it is used.

Example usage:

# Lazy version of `import tensorflow as tf`
tf = lazy_import("tensorflow")
# Other commands
# Now the module is loaded
tf.__version__

Args:
module_name: the fully-qualified module name to import
callback: a callback function to call before importing the module

Returns:
a proxy module object that will be lazily imported when first used

"""
return LazyModule(module_name, callback=callback)


def requires(*module_path_version: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator to check optional dependencies at call time with a clear error/warning message.

Args:
module_path_version: Python module paths (e.g., ``"torch.cuda"``) and/or pip-style requirements
(e.g., ``"torch>=2.0.0"``) to verify.
raise_exception: If ``True``, raise ``ModuleNotFoundError`` when requirements are not satisfied;
otherwise emit a warning and proceed to call the function.

Example:
>>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
... def my_cwd():
... from pathlib import Path
... return Path(__file__).parent

>>> class MyRndPower:
... @requires("math", "random")
... def __init__(self):
... from math import pow
... from random import randint
... self._rnd = pow(randint(1, 9), 2)

"""

def decorator(func: Callable[P, T]) -> Callable[P, T]:
reqs = [
ModuleAvailableCache(mod_ver) if "." in mod_ver else RequirementCache(mod_ver)
for mod_ver in module_path_version
]
available = all(map(bool, reqs))

@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not available:
missing = os.linesep.join([repr(r) for r in reqs if not bool(r)])
msg = f"Required dependencies not available: \n{missing}"
if raise_exception:
raise ModuleNotFoundError(msg)
warnings.warn(msg, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator
Loading
Loading