Skip to content

Commit 5ac31a4

Browse files
committed
Add node to convert record batches into csp.Structs
Signed-off-by: Arham Chopra <arham.chopra@cubistsystematic.com>
1 parent f70df59 commit 5ac31a4

File tree

6 files changed

+229
-102
lines changed

6 files changed

+229
-102
lines changed

cpp/csp/adapters/parquet/ParquetReader.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,13 @@ void SingleFileParquetReader::clear()
382382
}
383383

384384
InMemoryTableParquetReader::InMemoryTableParquetReader( GeneratorPtr generatorPtr, std::vector<std::string> columns,
385-
bool allowMissingColumns, std::optional<std::string> symbolColumnName )
385+
bool allowMissingColumns, std::optional<std::string> symbolColumnName, bool call_init )
386386
: SingleTableParquetReader( columns, true, allowMissingColumns, symbolColumnName ), m_generatorPtr( generatorPtr )
387387
{
388-
init();
388+
if( call_init )
389+
{
390+
init();
391+
}
389392
}
390393

391394
bool InMemoryTableParquetReader::openNextFile()

cpp/csp/adapters/parquet/ParquetReader.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,21 +424,28 @@ class SingleFileParquetReader final : public SingleTableParquetReader
424424
bool m_allowMissingFiles;
425425
};
426426

427-
class InMemoryTableParquetReader final : public SingleTableParquetReader
427+
class InMemoryTableParquetReader : public SingleTableParquetReader
428428
{
429429
public:
430430
using GeneratorPtr = csp::Generator<std::shared_ptr<arrow::Table>, csp::DateTime, csp::DateTime>::Ptr;
431431

432432
InMemoryTableParquetReader( GeneratorPtr generatorPtr, std::vector<std::string> columns,
433433
bool allowMissingColumns,
434-
std::optional<std::string> symbolColumnName = {} );
434+
std::optional<std::string> symbolColumnName = {},
435+
bool call_init = true);
435436
std::string getCurFileOrTableName() const override{ return "IN_MEMORY_TABLE"; }
436437

437438
protected:
438-
bool openNextFile() override;
439-
bool readNextRowGroup() override;
439+
virtual bool openNextFile() override;
440+
virtual bool readNextRowGroup() override;
441+
void setTable( std::shared_ptr<arrow::Table> table )
442+
{
443+
m_fullTable = table;
444+
m_nextChunkIndex = 0;
445+
m_curTable = nullptr;
446+
}
440447

441-
void clear() override;
448+
virtual void clear() override;
442449

443450
private:
444451
GeneratorPtr m_generatorPtr;

cpp/csp/python/CMakeLists.txt

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,35 @@ target_compile_definitions(cspimpl PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERS
9090
target_compile_definitions(cspimpl PRIVATE CSPIMPL_EXPORTS=1)
9191

9292

93+
find_package(Arrow REQUIRED)
94+
find_package(Parquet REQUIRED)
95+
96+
if(WIN32)
97+
if(CSP_USE_VCPKG)
98+
set(ARROW_PACKAGES_TO_LINK Arrow::arrow_static Parquet::parquet_static )
99+
target_compile_definitions(csp_parquet_adapter PUBLIC ARROW_STATIC)
100+
target_compile_definitions(csp_parquet_adapter PUBLIC PARQUET_STATIC)
101+
else()
102+
# use dynamic variants
103+
# Until we manage to get the fix for ws2_32.dll in arrow-16 into conda, manually fix the error here
104+
get_target_property(LINK_LIBS Arrow::arrow_shared INTERFACE_LINK_LIBRARIES)
105+
string(REPLACE "ws2_32.dll" "ws2_32" FIXED_LINK_LIBS "${LINK_LIBS}")
106+
set_target_properties(Arrow::arrow_shared PROPERTIES INTERFACE_LINK_LIBRARIES "${FIXED_LINK_LIBS}")
107+
set(ARROW_PACKAGES_TO_LINK parquet_shared arrow_shared)
108+
endif()
109+
else()
110+
if(CSP_USE_VCPKG)
111+
# use static variants
112+
set(ARROW_PACKAGES_TO_LINK parquet_static arrow_static)
113+
else()
114+
# use dynamic variants
115+
set(ARROW_PACKAGES_TO_LINK parquet arrow)
116+
endif()
117+
endif()
118+
93119
## Baselib c++ module
94120
add_library(cspbaselibimpl SHARED cspbaselibimpl.cpp)
95-
target_link_libraries(cspbaselibimpl cspimpl baselibimpl)
121+
target_link_libraries(cspbaselibimpl cspimpl baselibimpl csp_parquet_adapter ${ARROW_PACKAGES_TO_LINK})
96122

97123
# Include exprtk include directory for exprtk node
98124
target_include_directories(cspbaselibimpl PRIVATE ${EXPRTK_INCLUDE_DIRS})

cpp/csp/python/cspbaselibimpl.cpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
#include <exprtk.hpp>
66
#include <numpy/ndarrayobject.h>
77

8+
#include <arrow/type.h>
9+
#include <arrow/table.h>
10+
#include <arrow/c/abi.h>
11+
#include <arrow/c/bridge.h>
12+
13+
#include <csp/adapters/parquet/ParquetReader.h>
14+
#include <csp/adapters/utils/StructAdapterInfo.h>
15+
#include <csp/adapters/utils/ValueDispatcher.h>
16+
817
static void * init_nparray()
918
{
1019
csp::python::AcquireGIL gil;
@@ -325,6 +334,137 @@ DECLARE_CPPNODE( exprtk_impl )
325334

326335
EXPORT_CPPNODE( exprtk_impl );
327336

337+
DECLARE_CPPNODE( record_batches_to_struct )
338+
{
339+
using InMemoryTableParquetReader = csp::adapters::parquet::InMemoryTableParquetReader;
340+
using SingleTableParquetReader = csp::adapters::parquet::SingleTableParquetReader;
341+
class MyTableReader : public InMemoryTableParquetReader
342+
{
343+
public:
344+
MyTableReader( std::vector<std::string> columns, std::shared_ptr<arrow::Schema> schema ):
345+
InMemoryTableParquetReader( nullptr, columns, false, {}, false )
346+
{
347+
m_schema = schema;
348+
}
349+
std::string getCurFileOrTableName() const override{ return "IN_RECORD_BATCH"; }
350+
void initialize() { setColumnAdaptersFromCurrentTable(); }
351+
void parseBatches( std::vector<std::shared_ptr<arrow::RecordBatch>> record_batches )
352+
{
353+
// TODO: Check that the schema has not changed
354+
auto table_result = arrow::Table::FromRecordBatches(record_batches);
355+
if( !table_result.ok() )
356+
CSP_THROW( NotImplemented, "Unable to make table from record batches" );
357+
358+
setTable( table_result.ValueUnsafe() );
359+
360+
if( !readNextRowGroup() )
361+
CSP_THROW( NotImplemented, "Unable to read row group from table" );
362+
363+
while( readNextRow() )
364+
{
365+
for( auto& adapter: getStructAdapters() )
366+
{
367+
adapter -> dispatchValue( nullptr );
368+
}
369+
}
370+
}
371+
void stop()
372+
{
373+
InMemoryTableParquetReader::clear();
374+
}
375+
protected:
376+
bool openNextFile() override { return false; }
377+
void clear() override { setTable( nullptr ); }
378+
};
379+
380+
SCALAR_INPUT( DialectGenericType, schema_ptr );
381+
SCALAR_INPUT( StructMetaPtr, cls );
382+
SCALAR_INPUT( DictionaryPtr, properties );
383+
TS_INPUT( Generic, data );
384+
385+
TS_OUTPUT( Generic );
386+
387+
std::shared_ptr<MyTableReader> reader;
388+
CspTypePtr outType;
389+
std::vector<StructPtr>* m_structsVecPtr;
390+
391+
using StructAdapterInfo = csp::adapters::utils::StructAdapterInfo;
392+
using ValueDispatcher = csp::adapters::utils::ValueDispatcher<StructPtr &>;
393+
394+
INIT_CPPNODE( record_batches_to_struct )
395+
{
396+
auto & input_def = tsinputDef( "data" );
397+
if( input_def.type -> type() != CspType::Type::ARRAY )
398+
CSP_THROW( TypeError, "record_batches_to_struct expected ts array type, got " << input_def.type -> type() );
399+
400+
auto * aType = static_cast<const CspArrayType *>( input_def.type.get() );
401+
CspTypePtr elemType = aType -> elemType();
402+
if( elemType -> type() != CspType::Type::DIALECT_GENERIC )
403+
CSP_THROW( TypeError, "record_batches_to_struct expected ts array of DIALECT_GENERIC type, got " << elemType -> type() );
404+
405+
auto & output_def = tsoutputDef( "" );
406+
if( output_def.type -> type() != CspType::Type::ARRAY )
407+
CSP_THROW( NotImplemented, "record_batches_to_struct expected ts array type, got " << output_def.type -> type() );
408+
}
409+
410+
START()
411+
{
412+
// Create Adapters for Schema
413+
PyObject* capsule = csp::python::toPythonBorrowed(schema_ptr);
414+
struct ArrowSchema* c_schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer(capsule, "arrow_schema") );
415+
auto result = arrow::ImportSchema(c_schema);
416+
if( !result.ok() )
417+
CSP_THROW( NotImplemented, "Unable to import schema" );
418+
std::shared_ptr<arrow::Schema> schema = result.ValueUnsafe();
419+
std::vector<std::string> columns;
420+
for( int idx = 0; idx < schema -> num_fields(); idx++ )
421+
{
422+
auto& field = schema -> field( idx );
423+
columns.push_back(field -> name());
424+
}
425+
reader = std::make_shared<MyTableReader>( columns, schema );
426+
reader -> initialize();
427+
428+
outType = std::make_shared<csp::CspStructType>( cls.value() );
429+
auto field_map = properties.value() -> get<DictionaryPtr>( "field_map" );
430+
StructAdapterInfo key{ outType, field_map };
431+
auto& struct_adapter = reader -> getStructAdapter( key );
432+
struct_adapter.addSubscriber( [this]( StructPtr * s )
433+
{
434+
if( s ) this -> m_structsVecPtr -> push_back( *s );
435+
else CSP_THROW( NotImplemented, "StructPtr was null" );
436+
}, {} );
437+
}
438+
439+
INVOKE()
440+
{
441+
if( csp.ticked( data ) )
442+
{
443+
auto & py_batches = data.lastValue<std::vector<DialectGenericType>>();
444+
std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
445+
for( auto& py_batch: py_batches )
446+
{
447+
PyObject* py_tuple = csp::python::toPythonBorrowed( py_batch );
448+
PyObject* py_schema = PyTuple_GET_ITEM( py_tuple, 0 );
449+
PyObject* py_array = PyTuple_GET_ITEM( py_tuple, 1 );
450+
struct ArrowSchema* c_schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer( py_schema, "arrow_schema" ) );
451+
struct ArrowArray* c_array = reinterpret_cast<struct ArrowArray*>( PyCapsule_GetPointer( py_array, "arrow_array" ) );
452+
auto result = arrow::ImportRecordBatch(c_array, c_schema);
453+
if( !result.ok() )
454+
CSP_THROW( NotImplemented, "Unable to import record batch from c interface" );
455+
batches.emplace_back(result.ValueUnsafe());
456+
}
457+
std::vector<StructPtr> & out = unnamed_output().reserveSpace<std::vector<StructPtr>>();
458+
out.clear();
459+
m_structsVecPtr = &out;
460+
reader -> parseBatches( batches );
461+
m_structsVecPtr = nullptr;
462+
}
463+
}
464+
};
465+
466+
EXPORT_CPPNODE( record_batches_to_struct );
467+
328468
}
329469

330470
// Base nodes
@@ -350,6 +490,7 @@ REGISTER_CPPNODE( csp::cppnodes, struct_fromts );
350490
REGISTER_CPPNODE( csp::cppnodes, struct_collectts );
351491

352492
REGISTER_CPPNODE( csp::cppnodes, exprtk_impl );
493+
REGISTER_CPPNODE( csp::cppnodes, record_batches_to_struct );
353494

354495
static PyModuleDef _cspbaselibimpl_module = {
355496
PyModuleDef_HEAD_INIT,

0 commit comments

Comments
 (0)