# Source code for reforge.utils

"""Utility wrappers and functions

Description:
    This module provides utility functions and decorators for the reForge workflow.
    It includes decorators for timing and memory profiling functions, a context manager
    for changing the working directory, and helper functions for cleaning directories and
    detecting CUDA availability.

Requirements:
    - Python 3.x

Author: DY
Date: YYYY-MM-DD
"""

import logging
import os
import time
import tracemalloc
import warnings
from contextlib import contextmanager
from functools import wraps
from pathlib import Path

def get_logger(name="reforge"):
    """Return the logger registered under ``name``.

    Handlers, formatters and levels are set up by the package itself
    (in ``reforge/__init__.py``), so nothing is configured here; this is a
    thin lookup into the ``logging`` registry.

    Parameters
    ----------
    name : str, optional
        Name of the logger to fetch (default: ``"reforge"``).

    Returns
    -------
    logging.Logger
        The logger object for ``name``; repeated calls with the same name
        return the same object.
    """
    configured = logging.getLogger(name)
    return configured


# Module-level logger kept for backward compatibility with callers that do
# ``from reforge.utils import logger``. New code should instead use
# ``from reforge.utils import get_logger; logger = get_logger()``.
logger: logging.Logger = get_logger()
def timeit(*args, **kwargs):
    """Backwards-compatible timing decorator.

    Supported call forms::

        @timeit                          # defaults: level=DEBUG, unit='s'
        @timeit()
        @timeit(logging.INFO)            # positional level
        @timeit(logging.INFO, unit='ms') # positional level + keyword options
        @timeit(level=logging.INFO, unit='auto')

    All arguments are forwarded unchanged to ``_timeit``. Earlier versions
    special-cased a lone positional ``int`` level and discarded any keyword
    arguments passed alongside it (e.g. ``unit`` was silently ignored).
    """
    # Bare ``@timeit``: Python passes the decorated function as the only
    # positional argument, so apply a decorator built from any kwargs to it.
    if len(args) == 1 and callable(args[0]):
        return _timeit(**kwargs)(args[0])
    # Every other form maps directly onto _timeit's (level, unit) parameters.
    return _timeit(*args, **kwargs)


def _timeit(level=logging.DEBUG, unit='s'):
    """Build a decorator that logs how long each call of a function takes.

    Args:
        level (int): Logging level used for the timing message
            (default: ``logging.DEBUG``).
        unit (str): Unit for the reported duration:
            - ``'ms'``: milliseconds
            - ``'s'``: seconds (default)
            - ``'m'``: minutes
            - ``'auto'``: milliseconds below 1 s, minutes above 60 s,
              seconds otherwise
            Any other value falls back to seconds.

    Returns:
        Callable: a decorator that wraps a function. The wrapper returns the
        function's result unchanged. A message is logged only when the call
        returns normally; if the function raises, the exception propagates
        and nothing is logged.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start

            if unit == 'ms' or (unit == 'auto' and elapsed < 1):
                display_time, unit_str = elapsed * 1000, 'milliseconds'
            elif unit == 'm' or (unit == 'auto' and elapsed > 60):
                display_time, unit_str = elapsed / 60, 'minutes'
            else:
                # 's', 'auto' between 1 and 60 s, and unknown units.
                display_time, unit_str = elapsed, 'seconds'

            logger.log(
                level,
                "Function '%s.%s' executed in %.6f %s",
                func.__module__,
                func.__name__,
                display_time,
                unit_str,
            )
            return result
        return wrapper
    return decorator
def memprofit(*args, **kwargs):
    """Backwards-compatible memory-profiling decorator.

    Supported call forms::

        @memprofit                    # default level=DEBUG
        @memprofit()
        @memprofit(logging.INFO)      # positional level
        @memprofit(level=logging.INFO)

    All arguments are forwarded unchanged to ``_memprofit``.
    """
    # Bare ``@memprofit``: the decorated function is the only positional arg.
    if len(args) == 1 and callable(args[0]):
        return _memprofit(**kwargs)(args[0])
    return _memprofit(*args, **kwargs)


def _memprofit(level=logging.DEBUG):
    """Build a decorator that logs ``tracemalloc`` memory usage after a call.

    After the wrapped function returns, the current and peak traced memory
    reported by ``tracemalloc.get_traced_memory()`` are logged in MB
    (1024**2 bytes) at ``level``.

    If tracing is not already active, the wrapper starts it before the call
    and always stops it afterwards, including when the function raises. If
    tracing was already active (an enclosing ``@memprofit`` or a
    user-started trace), the wrapper leaves it running. In that case the
    reported figures cover everything traced since that outer start, not
    just this call.

    Args:
        level (int): Logging level for the memory message
            (default: ``logging.DEBUG``).

    Returns:
        Callable: a decorator whose wrapper returns the function's result
        unchanged. No message is logged if the function raises.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Only own (start/stop) the tracer if nobody else is using it;
            # stopping an outer tracer would discard its collected traces.
            started_here = not tracemalloc.is_tracing()
            if started_here:
                tracemalloc.start()
            try:
                result = func(*args, **kwargs)
                current, peak = tracemalloc.get_traced_memory()
                logger.log(
                    level,
                    "Memory usage after executing '%s.%s': %.2f MB, Peak: %.2f MB",
                    func.__module__,
                    func.__name__,
                    current / 1024**2,
                    peak / 1024**2,
                )
                return result
            finally:
                # Ensure tracing never leaks past the call, even on error.
                if started_here:
                    tracemalloc.stop()
        return wrapper
    return decorator
@contextmanager
def cd(newdir):
    """
    Temporarily switch the process's current working directory.

    Parameters:
        newdir (str or Path): Directory to enter for the duration of the
            ``with`` block.

    Yields:
        None. When the ``with`` block exits, whether normally or through an
        exception, the working directory that was active on entry is
        restored.
    """
    saved_dir = os.getcwd()
    os.chdir(newdir)
    logger.info("Changed working directory to: %s", newdir)
    try:
        yield
    finally:
        # Restore unconditionally so an error inside the block cannot leave
        # the process stranded in ``newdir``.
        os.chdir(saved_dir)
def clean_dir(directory=".", pattern="#*"):
    """
    Delete the files in a directory whose names match a glob pattern.

    Parameters:
        directory (str or Path, optional): Directory whose entries are
            matched (default: the current directory, ".").
        pattern (str, optional): ``pathlib`` glob pattern selecting what to
            delete (default: "#*", i.e. names starting with '#').

    Only entries for which ``Path.is_file()`` is true are unlinked, so
    matching directories are left alone. Each entry is checked immediately
    before it is removed.
    """
    # Generator keeps evaluation lazy: each is_file() check happens right
    # before the corresponding unlink, not all up front.
    candidates = (entry for entry in Path(directory).glob(pattern) if entry.is_file())
    for entry in candidates:
        entry.unlink()