Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use pyo3::prelude::*;
use external_sort::{ExternalSorter, ExternallySortable};
use itertools::Itertools;
use std::io::{BufRead, BufReader, Read, Write};
use pyo3::prelude::*;
use std::io::{BufRead, BufReader, BufWriter, Read, SeekFrom, Write};

// Define a string structure that can be sorted externally
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)]
Expand Down Expand Up @@ -38,6 +37,34 @@ fn sort_lines(lines: Vec<String>) -> Vec<String> {
lines
}

fn streaming_sort_until<'a, IN, OUT, STR>(
input: IN,
mut output: OUT,
end: &str,
) -> std::io::Result<()>
where
// can take &str or &String
STR: AsRef<str>,
IN: Iterator<Item = STR>,
OUT: Write,
{
let input = input
.take_while(|l| l.as_ref() != end)
.map(|l| TsvLine::new(l.as_ref()));
// Do the external sort
let sorted = ExternalSorter::new(1000000, None)
.sort_by(input, |a, b| {
tsv_cmp(a.the_line.as_str(), b.the_line.as_str())
})
.unwrap();
// Write the sorted lines to the output file
for line in sorted {
writeln!(&mut output, "{}", line.unwrap().the_line)?;
}
writeln!(&mut output, "{end}")?;
Ok(())
}

/// Merge sort a range of lines from an input file and write the result to another file.
///
/// The function `sort_file_lines` seeks to the given start position in the input file, reads
Expand Down Expand Up @@ -76,31 +103,22 @@ fn sort_lines(lines: Vec<String>) -> Vec<String> {
///
#[pyfunction]
fn sort_file_lines(input: &str, output: &str, start: u64, end: &str) -> PyResult<u64> {
// Open the input file and seek to the start position
let mut input_file = std::fs::File::open(input)?;
input_file.seek(std::io::SeekFrom::Start(start))?;
// Wrap the input file in a buffered reader
let mut input = BufReader::new(&mut input_file);
// Create an iterator which reads lines until the end marker and doesn't consume the end marker
let mut binding = input.by_ref().lines().peekable();
let lines = binding
.peeking_take_while(|line| line.as_ref().map(|l| l != end).unwrap_or(false))
.map(|line| TsvLine::new(&line.unwrap()));
// Do the external sort
let iter = ExternalSorter::new(1000000, None).sort_by(
lines,
|a, b| tsv_cmp(a.the_line.as_str(), b.the_line.as_str()),
).unwrap();
// Write the sorted lines to the output file
let output_file = std::fs::File::create(output)?;
let mut output = std::io::BufWriter::new(output_file);
for line in iter {
writeln!(output, "{}", line.unwrap().the_line)?;
}
// Write the end marker (which was not consumed by peeking_take_while)
writeln!(output, "{}", binding.next().unwrap().unwrap())?;
// Open the input file with a buffer
let mut input = BufReader::new(File::open(input)?);

// Seek to the start position
input.seek(SeekFrom::Start(start))?;

// Open the output file
let mut output = BufWriter::new(File::create(output)?);

// Create the lines iterator
let lines = input.by_ref().lines().map(|l| l.unwrap());
streaming_sort_until(lines, output, end)?;
// sort_until(input.as_ref().lines(), &mut output)?;

// return the stream position from the counting reader object
Ok(input.stream_position().unwrap())
Ok(input.stream_position().unwrap() - (end.bytes().len() as u64))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a 50% chance that there is an off-by-one error here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also it is weird that we write this to the result but still seek backwards as if it is not read from the source. Is this right?

}

/// A Python module implemented in Rust.
Expand Down Expand Up @@ -413,4 +431,10 @@ mod tests {
expected as i8,
);
}
#[test]
fn streaming_sort_smoke() {
let mut res = Vec::new();
streaming_sort_until("1\n3\n2\nEND".lines(), &mut res, "END").unwrap();
assert_eq!(std::str::from_utf8(&res).unwrap(), "1\n2\n3\nEND\n");
}
}