Source code for kgdata.splitter

"""Functions to split a big file into smaller files.
"""

from bz2 import BZ2File
from gzip import GzipFile
from io import TextIOWrapper
import shutil
from pathlib import Path
from typing import (
    BinaryIO,
    Callable,
    ContextManager,
    Iterable,
    List,
    Tuple,
    Union,
)
from serde.helper import get_open_fn

from sm.misc.prelude import (
    datasize,
    identity_func,
    import_func,
)
from tqdm import tqdm
from multiprocessing import Process, Queue
import serde.byteline


def default_currentbyte_constructor(
    file_object: Union[BZ2File, GzipFile, BinaryIO, TextIOWrapper]
) -> Callable[[], int]:
    """Build a zero-argument callable reporting the reader's current byte offset.

    Args:
        file_object: an opened file. ``BZ2File`` and ``GzipFile`` are
            special-cased; anything else is used directly.

    Returns:
        The bound ``tell`` method of the selected stream:

        * ``BZ2File``: ``file_object._buffer.raw._fp`` (private CPython
          attributes -- presumably the wrapped compressed stream, so the
          offset is comparable to the on-disk size; verify when upgrading
          Python).
        * ``GzipFile``: ``file_object.fileobj``.
        * otherwise: ``file_object`` itself.
    """
    if isinstance(file_object, BZ2File):
        underlying = file_object._buffer.raw._fp  # type: ignore
    elif isinstance(file_object, GzipFile):
        underlying = file_object.fileobj  # type: ignore
    else:
        underlying = file_object
    return underlying.tell
def split_a_file(
    infile: Union[str, Path, Callable[[], Tuple[int, ContextManager[BinaryIO]]]],
    outfile: Union[str, Path],
    record_iter: Callable[
        [Union[BZ2File, GzipFile, BinaryIO]], Iterable[bytes]
    ] = identity_func,
    record_postprocess: str = "kgdata.splitter.strip_newline",
    currentbyte_constructor: Callable[
        [Union[BZ2File, GzipFile, BinaryIO]], Callable[[], int]
    ] = default_currentbyte_constructor,
    override: bool = False,
    n_writers: int = 8,
    n_records_per_file: int = 64000,
):
    r"""Split a file containing a list of records into smaller files stored in a directory.

    Records are read sequentially in this process and handed out round-robin
    to ``n_writers`` writer processes, each of which writes its own series of
    output files in parallel.

    WARNING: the parent directory of ``outfile`` is treated as owned by this
    job. If it already exists and the job has to run, the whole directory is
    removed with ``shutil.rmtree`` before any output is written.

    Args:
        infile: path of input file (e.g., '/data/input/bigfile.json.gz') or a
            function that returns a tuple ``(size_in_bytes, file_object)``
            where the file object is opened in binary mode and usable as a
            context manager.
        outfile: template of path of output file
            (e.g., '/data/outputs/smallfile.json.gz'). From the template, this
            function writes to files in the parent folder
            (e.g., '/data/outputs') named
            'smallfile-<writer:02d><number:05d>.json.gz', plus an extra file
            named '_SUCCESS' to indicate that the job is done.
        record_iter: a function that returns an iterator of records given a
            file object; by default it returns the file object itself.
        record_postprocess: name/path to import the function that
            post-processes a record. By default the newline is stripped from
            the end of the record. When the function returns None, the record
            is skipped.
        currentbyte_constructor: a function that returns a function that
            returns the current byte position of a file object (used only to
            drive the progress bar).
        override: whether to redo the job even if '_SUCCESS' already exists.
        n_writers: number of parallel writers (must be >= 1).
        n_records_per_file: number of records per file (must be >= 1).

    Raises:
        ValueError: if ``n_writers`` or ``n_records_per_file`` is < 1.
        TypeError: if ``infile`` is neither a path nor a callable.
    """
    # Validate before touching the filesystem: a zero value would otherwise
    # surface as ZeroDivisionError only after the output folder was wiped.
    if n_writers < 1:
        raise ValueError(f"n_writers must be >= 1, got {n_writers}")
    if n_records_per_file < 1:
        raise ValueError(
            f"n_records_per_file must be >= 1, got {n_records_per_file}"
        )

    outfile = Path(outfile)
    outdir = outfile.parent

    if outdir.exists():
        if not override and (outdir / "_SUCCESS").exists():
            # A previous run completed; nothing to do.
            return
        shutil.rmtree(outdir)
    outdir.mkdir(parents=True)

    name_parts = outfile.name.split(".", 1)

    queues = []
    writers = []

    # `success` only becomes True when the whole input was read without an
    # exception, so any failure below prevents the _SUCCESS marker.
    success = False
    try:
        for i in range(n_writers):
            # "<stem>-<writer index><file counter placeholder>[.<ext>]"; the
            # "{auto:05d}" placeholder is filled in by the writer process.
            writer_name = ".".join(
                [name_parts[0] + "-%02d{auto:05d}" % i] + name_parts[1:]
            )
            writer_file = str(outdir / writer_name)

            queue = Queue()
            writer = Process(
                target=write_to_file,
                args=(writer_file, n_records_per_file, record_postprocess, queue),
            )
            writer.start()
            # Register only started writers so that cleanup never joins a
            # process that was not started.
            queues.append(queue)
            writers.append(writer)

        # Open the input inside the try-block: if this fails the writers are
        # still told to stop instead of blocking forever on their queues.
        if isinstance(infile, (str, Path)):
            file_size = Path(infile).stat().st_size
            file_object = get_open_fn(infile)(infile, "rb")
        elif callable(infile):
            file_size, file_object = infile()
        else:
            raise TypeError(
                "infile must be a path or a callable returning "
                f"(size, file_object), got {type(infile).__name__}"
            )

        if file_size == 0:
            # Presumably avoids a zero total in the progress bar.
            file_size = 1

        with file_object as f, tqdm(
            total=file_size,
            desc="splitting",
            unit="B",
            unit_scale=True,
        ) as pbar:
            tell = currentbyte_constructor(f)
            last_bytes = 0
            for i, line in enumerate(record_iter(f)):
                queues[i % n_writers].put(line)
                current_bytes = tell()
                pbar.update(current_bytes - last_bytes)
                last_bytes = current_bytes

        success = True
    finally:
        print(">>> Finish! Waiting to exit...")
        # None is the sentinel that tells a writer to stop.
        for q in queues:
            q.put(None)
        for p in writers:
            p.join()
            success = success and (p.exitcode == 0)
        # NOTE(review): if a writer dies early its queue is never drained;
        # the records buffered for it stay in this process's memory.
        if success:
            (outdir / "_SUCCESS").touch()
def write_to_file(
    outfile_template: str,
    n_records_per_file: int,
    record_postprocessing: str,
    queue: Queue,
):
    """Write records from a queue to a series of files.

    The first output file is always created, even if no record arrives.
    A new file is started only when the current one already holds
    ``n_records_per_file`` written records and another record needs to be
    written, so every file except possibly the last holds exactly
    ``n_records_per_file`` records. Records dropped by the post-processing
    function do not count towards that limit.

    Args:
        outfile_template: template of path of output file; it must contain an
            ``{auto}`` placeholder that is formatted with the file counter
            (0, 1, 2, ...).
        n_records_per_file: number of records per file (must be >= 1).
        record_postprocessing: name/path to import the function that
            post-processes a record. The function can return None to skip
            the record.
        queue: a queue that yields records to be written to a file; when it
            yields None, the writer stops.

    Raises:
        ValueError: if ``n_records_per_file`` is < 1.
    """
    if n_records_per_file < 1:
        raise ValueError(
            f"n_records_per_file must be >= 1, got {n_records_per_file}"
        )

    # Resolve the post-processor before opening any file so that a bad import
    # path cannot leave an open file handle behind.
    postprocess_fn = import_func(record_postprocessing)

    file_counter = 0
    outfile = outfile_template.format(auto=file_counter)
    writer = get_open_fn(outfile)(outfile, "wb")
    n_written = 0  # records written to the currently open file

    try:
        while True:
            record = queue.get()
            if record is None:
                # Sentinel from the producer: no more records.
                break

            record = postprocess_fn(record)
            if record is None:
                # The post-processor asked to drop this record.
                continue

            if n_written >= n_records_per_file:
                # Current file is full; rotate lazily so that no empty
                # trailing file is created.
                writer.close()
                file_counter += 1
                outfile = outfile_template.format(auto=file_counter)
                writer = get_open_fn(outfile)(outfile, "wb")
                n_written = 0

            writer.write(record)
            writer.write(b"\n")
            n_written += 1
    finally:
        # Always release the handle, including when post-processing or
        # writing raises.
        writer.close()
def strip_newline(line: bytes) -> bytes:
    """Return ``line`` without its trailing line-feed bytes.

    Every line-feed byte at the end is removed; line feeds elsewhere and any
    other trailing bytes (such as a carriage return) are left untouched.
    """
    line_feed = b"\n"
    return line.rstrip(line_feed)
def split_a_list(
    lst: List[bytes], outfile: Union[str, Path], n_records_per_file: int = 64000
):
    """Split an in-memory list of records into consecutive chunk files.

    Args:
        lst: the records to split.
        outfile: template of path of output file; for a name such as
            'smallfile.json.gz' the chunks are stored next to it as
            'smallfile-<number:05d>.json.gz'. The parent folder is created
            when missing.
        n_records_per_file: maximum number of records per chunk.

    Each chunk is handed to ``serde.byteline.ser`` together with its
    destination path.
    """
    target = Path(outfile)
    folder = target.parent
    folder.mkdir(exist_ok=True, parents=True)

    # Insert the counter placeholder right after the part of the file name
    # that precedes the first dot.
    stem, *extension = target.name.split(".", 1)
    template = str(folder / ".".join([stem + "-{auto:05d}", *extension]))

    chunk_starts = range(0, len(lst), n_records_per_file)
    for file_no, start in enumerate(tqdm(chunk_starts, desc="splitting")):
        serde.byteline.ser(
            lst[start : start + n_records_per_file],
            template.format(auto=file_no),
        )