diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a2de9e5c..0c6c590e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Storage minimum level parameter removed from cylindrical thermal storage [#1123](https://github.com/ie3-institute/PowerSystemDataModel/issues/1123) - Converted eval-rst to myst syntax in ReadTheDocs, fixed line wrapping and widths[#1137](https://github.com/ie3-institute/PowerSystemDataModel/issues/1137) +- Improving usage of streams on sql fetches [#827](https://github.com/ie3-institute/PowerSystemDataModel/issues/827) ## [5.1.0] - 2024-06-24 diff --git a/src/main/java/edu/ie3/datamodel/io/connectors/SqlConnector.java b/src/main/java/edu/ie3/datamodel/io/connectors/SqlConnector.java index b1e463180..eed7307ad 100644 --- a/src/main/java/edu/ie3/datamodel/io/connectors/SqlConnector.java +++ b/src/main/java/edu/ie3/datamodel/io/connectors/SqlConnector.java @@ -9,6 +9,8 @@ import edu.ie3.util.TimeUtil; import java.sql.*; import java.util.*; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,6 +57,9 @@ public ResultSet executeQuery(Statement stmt, String query) throws SQLException return stmt.executeQuery(query); } catch (SQLException e) { throw new SQLException(String.format("Error at execution of query \"%1.127s\": ", query), e); + } finally { + // commits any changes made and unlocks database + getConnection().commit(); } } @@ -70,6 +75,9 @@ public int executeUpdate(String query) throws SQLException { } catch (SQLException e) { throw new SQLException( String.format("Error at execution of query, SQLReason: '%s'", e.getMessage()), e); + } finally { + // commits any changes made and unlocks database + getConnection().commit(); } } @@ -85,7 +93,8 @@ public Connection getConnection() throws SQLException { } /** - * Establishes and returns a database connection + * Establishes and returns a database connection. The {@link Connection#getAutoCommit()} is set to + * {@code false}. * * @param reuseConnection should the connection be used again, if it is still valid? If not, a new * connection will be established @@ -98,6 +107,7 @@ public Connection getConnection(boolean reuseConnection) throws SQLException { if (connection != null) connection.close(); connection = DriverManager.getConnection(jdbcUrl, connectionProps); + connection.setAutoCommit(false); } catch (SQLException e) { throw new SQLException("Could not establish connection: ", e); } @@ -115,21 +125,82 @@ public void shutdown() { } /** - * Extracts all field to value maps from the ResultSet, one for each row + * Method to execute a {@link PreparedStatement} and return its result as a stream. * - * @param rs the ResultSet to use - * @return a list of field maps + * @param ps to execute + * @param fetchSize used for {@link PreparedStatement#setFetchSize(int)} + * @return a stream of maps + * @throws SQLException if an exception occurred while executing the query */ - public List> extractFieldMaps(ResultSet rs) { - List> fieldMaps = new ArrayList<>(); + public Stream> toStream(PreparedStatement ps, int fetchSize) + throws SQLException { try { - while (rs.next()) { - fieldMaps.add(extractFieldMap(rs)); + ps.setFetchSize(fetchSize); + ResultSet resultSet = ps.executeQuery(); + Iterator> sqlIterator = getSqlIterator(resultSet); + + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + sqlIterator, Spliterator.NONNULL | Spliterator.IMMUTABLE), + true) + .onClose(() -> closeResultSet(ps, resultSet)); + } catch (SQLException e) { + // catches the exception, closes the statement and re-throws the exception + closeResultSet(ps, null); + throw e; + } + } + + /** + * Returns an {@link Iterator} for the given {@link ResultSet}. + * + * @param rs given result set + * @return an iterator + */ + public Iterator> getSqlIterator(ResultSet rs) { + return new Iterator<>() { + @Override + public boolean hasNext() { + try { + return rs.next(); + } catch (SQLException e) { + log.error("Exception at extracting next ResultSet: ", e); + closeResultSet(null, rs); + return false; + } } + + @Override + public Map next() { + try { + boolean isEmpty = !rs.isBeforeFirst() && rs.getRow() == 0; + + if (isEmpty || rs.isAfterLast()) + throw new NoSuchElementException( + "There is no more element to iterate to in the ResultSet."); + + return extractFieldMap(rs); + } catch (SQLException e) { + log.error("Exception at extracting ResultSet: ", e); + closeResultSet(null, rs); + return Collections.emptyMap(); + } + } + }; + } + + /** + * Method for closing a {@link ResultSet}. + * + * @param rs to close + */ + private void closeResultSet(PreparedStatement ps, ResultSet rs) { + try (ps; + rs) { + log.debug("Resources successfully closed."); } catch (SQLException e) { - log.error("Exception at extracting ResultSet: ", e); + log.warn("Failed to properly close sources.", e); } - return fieldMaps; } /** @@ -138,26 +209,24 @@ public List> extractFieldMaps(ResultSet rs) { * @param rs the ResultSet to use * @return the field map for the current row */ - public Map extractFieldMap(ResultSet rs) { + public Map extractFieldMap(ResultSet rs) throws SQLException { TreeMap insensitiveFieldsToAttributes = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); - try { - ResultSetMetaData metaData = rs.getMetaData(); - int columnCount = metaData.getColumnCount(); - for (int i = 1; i <= columnCount; i++) { - String columnName = StringUtils.snakeCaseToCamelCase(metaData.getColumnName(i)); - String value; - Object result = rs.getObject(i); - if (result instanceof Timestamp) { - value = TimeUtil.withDefaults.toString(rs.getTimestamp(i).toInstant()); - } else { - value = String.valueOf(rs.getObject(i)); - } - insensitiveFieldsToAttributes.put(columnName, value); + + ResultSetMetaData metaData = rs.getMetaData(); + int columnCount = metaData.getColumnCount(); + for (int i = 1; i <= columnCount; i++) { + String columnName = StringUtils.snakeCaseToCamelCase(metaData.getColumnName(i)); + String value; + Object result = rs.getObject(i); + if (result instanceof Timestamp) { + value = TimeUtil.withDefaults.toString(rs.getTimestamp(i).toInstant()); + } else { + value = String.valueOf(rs.getObject(i)); } - } catch (SQLException e) { - log.error("Exception at extracting ResultSet: ", e); + insensitiveFieldsToAttributes.put(columnName, value); } + return insensitiveFieldsToAttributes; } } diff --git a/src/main/java/edu/ie3/datamodel/io/source/sql/SqlDataSource.java b/src/main/java/edu/ie3/datamodel/io/source/sql/SqlDataSource.java index 7fafca42a..2d214a260 100644 --- a/src/main/java/edu/ie3/datamodel/io/source/sql/SqlDataSource.java +++ b/src/main/java/edu/ie3/datamodel/io/source/sql/SqlDataSource.java @@ -177,10 +177,15 @@ protected Stream> buildStreamByTableName(String tableName) { * table name. */ protected Stream> executeQuery(String query, AddParams addParams) { - try (PreparedStatement ps = connector.getConnection().prepareStatement(query)) { + try { + PreparedStatement ps = connector.getConnection().prepareStatement(query); addParams.addParams(ps); - ResultSet resultSet = ps.executeQuery(); - return connector.extractFieldMaps(resultSet).stream(); + + // don't work with `try with resource`, therefore manual closing is necessary + // closes automatically after all dependent resultSets are closed + ps.closeOnCompletion(); + + return connector.toStream(ps, 1000); } catch (SQLException e) { log.error("Error during execution of query {}", query, e); } diff --git a/src/main/java/edu/ie3/datamodel/io/source/sql/SqlIdCoordinateSource.java b/src/main/java/edu/ie3/datamodel/io/source/sql/SqlIdCoordinateSource.java index a891dd2ee..83661604c 100644 --- a/src/main/java/edu/ie3/datamodel/io/source/sql/SqlIdCoordinateSource.java +++ b/src/main/java/edu/ie3/datamodel/io/source/sql/SqlIdCoordinateSource.java @@ -18,8 +18,8 @@ import edu.ie3.util.geo.CoordinateDistance; import edu.ie3.util.geo.GeoUtils; import java.sql.Array; -import java.sql.PreparedStatement; import java.util.*; +import java.util.stream.Stream; import javax.measure.quantity.Length; import org.locationtech.jts.geom.Envelope; import org.locationtech.jts.geom.Point; @@ -98,7 +98,8 @@ public Optional> getSourceFields() { @Override public Optional getCoordinate(int id) { - List values = executeQueryToList(queryForPoint, ps -> ps.setInt(1, id)); + List values = + executeQueryToStream(queryForPoint, ps -> ps.setInt(1, id)).toList(); if (values.isEmpty()) { return Optional.empty(); @@ -111,15 +112,14 @@ public Optional getCoordinate(int id) { public Collection getCoordinates(int... ids) { Object[] idSet = Arrays.stream(ids).boxed().distinct().toArray(); - List values = - executeQueryToList( + return executeQueryToStream( queryForPoints, ps -> { Array sqlArray = ps.getConnection().createArrayOf("int", idSet); ps.setArray(1, sqlArray); - }); - - return values.stream().map(value -> value.coordinate).toList(); + }) + .map(value -> value.coordinate) + .toList(); } @Override @@ -128,12 +128,13 @@ public Optional getId(Point coordinate) { double longitude = coordinate.getX(); List values = - executeQueryToList( - queryForId, - ps -> { - ps.setDouble(1, longitude); - ps.setDouble(2, latitude); - }); + executeQueryToStream( + queryForId, + ps -> { + ps.setDouble(1, longitude); + ps.setDouble(2, latitude); + }) + .toList(); if (values.isEmpty()) { return Optional.empty(); @@ -144,23 +145,21 @@ public Optional getId(Point coordinate) { @Override public Collection getAllCoordinates() { - List values = executeQueryToList(basicQuery + ";", PreparedStatement::execute); - - return values.stream().map(value -> value.coordinate).toList(); + return executeQueryToStream(basicQuery + ";").map(value -> value.coordinate).toList(); } @Override public List getNearestCoordinates(Point coordinate, int n) { - List values = - executeQueryToList( - queryForNearestPoints, - ps -> { - ps.setDouble(1, coordinate.getX()); - ps.setDouble(2, coordinate.getY()); - ps.setInt(3, n); - }); - - List points = values.stream().map(value -> value.coordinate).toList(); + List points = + executeQueryToStream( + queryForNearestPoints, + ps -> { + ps.setDouble(1, coordinate.getX()); + ps.setDouble(2, coordinate.getY()); + ps.setInt(3, n); + }) + .map(value -> value.coordinate) + .toList(); return calculateCoordinateDistances(coordinate, n, points); } @@ -185,7 +184,7 @@ private List getCoordinatesInBoundingBox( Point coordinate, ComparableQuantity distance) { Envelope envelope = GeoUtils.calculateBoundingBox(coordinate, distance); - return executeQueryToList( + return executeQueryToStream( queryForBoundingBox, ps -> { ps.setDouble(1, envelope.getMinX()); @@ -193,7 +192,6 @@ private List getCoordinatesInBoundingBox( ps.setDouble(3, envelope.getMaxX()); ps.setDouble(4, envelope.getMaxY()); }) - .stream() .map(value -> value.coordinate) .toList(); } @@ -208,9 +206,13 @@ private CoordinateValue createCoordinateValue(Map fieldToValues) return new CoordinateValue(idCoordinate.id(), idCoordinate.point()); } - private List executeQueryToList( + private Stream executeQueryToStream(String query) { + return dataSource.executeQuery(query).map(this::createCoordinateValue); + } + + private Stream executeQueryToStream( String query, SqlDataSource.AddParams addParams) { - return dataSource.executeQuery(query, addParams).map(this::createCoordinateValue).toList(); + return dataSource.executeQuery(query, addParams).map(this::createCoordinateValue); } /** diff --git a/src/test/groovy/edu/ie3/datamodel/io/connectors/SqlConnectorIT.groovy b/src/test/groovy/edu/ie3/datamodel/io/connectors/SqlConnectorIT.groovy index 66db0baca..047d26987 100644 --- a/src/test/groovy/edu/ie3/datamodel/io/connectors/SqlConnectorIT.groovy +++ b/src/test/groovy/edu/ie3/datamodel/io/connectors/SqlConnectorIT.groovy @@ -121,16 +121,15 @@ class SqlConnectorIT extends Specification implements TestContainerHelper { def "A SQL connector is able to extract all field to value maps from result set"() { given: def preparedStatement = connector.getConnection(false).prepareStatement("SELECT * FROM public.test;") - def resultSet = preparedStatement.executeQuery() when: - def actual = connector.extractFieldMaps(resultSet) + def actual = connector.toStream(preparedStatement, 1).toList() then: actual.size() == 2 cleanup: - resultSet.close() + preparedStatement.close() } def "A SQL connector shuts down correctly, if no connection was opened"() { diff --git a/src/test/groovy/edu/ie3/datamodel/io/sink/SqlSinkTest.groovy b/src/test/groovy/edu/ie3/datamodel/io/sink/SqlSinkTest.groovy index 9e45cf059..67e5065e4 100644 --- a/src/test/groovy/edu/ie3/datamodel/io/sink/SqlSinkTest.groovy +++ b/src/test/groovy/edu/ie3/datamodel/io/sink/SqlSinkTest.groovy @@ -13,46 +13,23 @@ import edu.ie3.datamodel.io.DbGridMetadata import edu.ie3.datamodel.io.connectors.SqlConnector import edu.ie3.datamodel.io.naming.DatabaseNamingStrategy import edu.ie3.datamodel.io.processor.ProcessorProvider -import edu.ie3.datamodel.io.processor.input.InputEntityProcessor -import edu.ie3.datamodel.io.processor.result.ResultEntityProcessor import edu.ie3.datamodel.io.processor.timeseries.TimeSeriesProcessor import edu.ie3.datamodel.io.processor.timeseries.TimeSeriesProcessorKey import edu.ie3.datamodel.io.source.sql.SqlDataSource import edu.ie3.datamodel.models.OperationTime import edu.ie3.datamodel.models.StandardUnits -import edu.ie3.datamodel.models.input.EmInput import edu.ie3.datamodel.models.input.NodeInput import edu.ie3.datamodel.models.input.OperatorInput -import edu.ie3.datamodel.models.input.connector.LineInput -import edu.ie3.datamodel.models.input.connector.Transformer2WInput -import edu.ie3.datamodel.models.input.connector.type.LineTypeInput -import edu.ie3.datamodel.models.input.connector.type.Transformer2WTypeInput -import edu.ie3.datamodel.models.input.graphics.LineGraphicInput -import edu.ie3.datamodel.models.input.graphics.NodeGraphicInput -import edu.ie3.datamodel.models.input.system.EvcsInput -import edu.ie3.datamodel.models.input.system.LoadInput import edu.ie3.datamodel.models.input.system.PvInput import edu.ie3.datamodel.models.input.system.characteristic.CosPhiFixed -import edu.ie3.datamodel.models.input.thermal.CylindricalStorageInput -import edu.ie3.datamodel.models.input.thermal.ThermalBusInput -import edu.ie3.datamodel.models.input.thermal.ThermalHouseInput -import edu.ie3.datamodel.models.result.system.EmResult -import edu.ie3.datamodel.models.result.system.EvResult -import edu.ie3.datamodel.models.result.system.EvcsResult -import edu.ie3.datamodel.models.result.system.FlexOptionsResult -import edu.ie3.datamodel.models.result.system.PvResult -import edu.ie3.datamodel.models.result.system.WecResult +import edu.ie3.datamodel.models.result.system.* import edu.ie3.datamodel.models.timeseries.TimeSeries import edu.ie3.datamodel.models.timeseries.TimeSeriesEntry import edu.ie3.datamodel.models.timeseries.individual.IndividualTimeSeries import edu.ie3.datamodel.models.timeseries.individual.TimeBasedValue import edu.ie3.datamodel.models.value.EnergyPriceValue import edu.ie3.datamodel.models.value.Value -import edu.ie3.test.common.GridTestData -import edu.ie3.test.common.SampleJointGrid -import edu.ie3.test.common.SystemParticipantTestData -import edu.ie3.test.common.ThermalUnitInputTestData -import edu.ie3.test.common.TimeSeriesTestData +import edu.ie3.test.common.* import edu.ie3.test.helper.TestContainerHelper import edu.ie3.util.TimeUtil import org.testcontainers.containers.Container