Merge 'bindings/java: add batching support to JDBC4PreparedStatement' from

# Changes
Support batching multiple DML queries in a single PreparedStatement.
### Java
- the setters of JDBC4PreparedStatement no longer bind to the underlying
native statement directly, but only store the parameter values locally
- On execution the correct set of parameters is bound to the native
statement
### Rust
- Added a helper method to retrieve the parameter count of a statement
# Reference
#615

Reviewed-by: Kim Seon Woo (@seonWKim)

Closes #3971
This commit is contained in:
Pekka Enberg 2025-11-23 09:45:08 +02:00 committed by GitHub
commit 94cd61fb69
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 341 additions and 71 deletions

View file

@ -301,6 +301,41 @@ pub extern "system" fn Java_tech_turso_core_TursoStatement_changes<'local>(
stmt.connection.conn.changes()
}
#[no_mangle]
pub extern "system" fn Java_tech_turso_core_TursoStatement_parameterCount<'local>(
mut env: JNIEnv<'local>,
obj: JObject<'local>,
stmt_ptr: jlong,
) -> jint {
let stmt = match to_turso_statement(stmt_ptr) {
Ok(stmt) => stmt,
Err(e) => {
set_err_msg_and_throw_exception(&mut env, obj, SQLITE_ERROR, e.to_string());
return -1;
}
};
stmt.stmt.parameters_count() as jint
}
#[no_mangle]
pub extern "system" fn Java_tech_turso_core_TursoStatement_reset<'local>(
mut env: JNIEnv<'local>,
obj: JObject<'local>,
stmt_ptr: jlong,
) -> jint {
let stmt = match to_turso_statement(stmt_ptr) {
Ok(stmt) => stmt,
Err(e) => {
set_err_msg_and_throw_exception(&mut env, obj, SQLITE_ERROR, e.to_string());
return -1;
}
};
stmt.stmt.reset();
0
}
/// Converts an optional `JObject` into Java's `TursoStepResult`.
///
/// This function takes an optional `JObject` and converts it into a Java object

View file

@ -20,7 +20,7 @@ public final class TursoStatement {
private final String sql;
private final long statementPointer;
private final TursoResultSet resultSet;
private TursoResultSet resultSet;
private boolean closed;
@ -215,6 +215,32 @@ public final class TursoStatement {
private native int bindBlob(long statementPointer, int position, byte[] value)
throws SQLException;
public void bindObject(int parameterIndex, Object x) throws SQLException {
if (x == null) {
this.bindNull(parameterIndex);
return;
}
if (x instanceof Byte) {
this.bindInt(parameterIndex, (Byte) x);
} else if (x instanceof Short) {
this.bindInt(parameterIndex, (Short) x);
} else if (x instanceof Integer) {
this.bindInt(parameterIndex, (Integer) x);
} else if (x instanceof Long) {
this.bindLong(parameterIndex, (Long) x);
} else if (x instanceof String) {
bindText(parameterIndex, (String) x);
} else if (x instanceof Float) {
bindDouble(parameterIndex, (Float) x);
} else if (x instanceof Double) {
bindDouble(parameterIndex, (Double) x);
} else if (x instanceof byte[]) {
bindBlob(parameterIndex, (byte[]) x);
} else {
throw new SQLException("Unsupported object type in bindObject: " + x.getClass().getName());
}
}
/**
* Returns total number of changes.
*
@ -247,6 +273,34 @@ public final class TursoStatement {
private native long changes(long statementPointer) throws SQLException;
/**
* Returns the number of parameters in this statement. Parameters are the `?`'s that get replaced
* by the provided arguments.
*
* @throws SQLException If a database access error occurs
*/
public int parameterCount() throws SQLException {
final int result = parameterCount(statementPointer);
if (result == -1) {
throw new SQLException("Exception while retrieving parameter count");
}
return result;
}
private native int parameterCount(long statementPointer) throws SQLException;
/** Resets this statement so it's ready for re-execution */
public void reset() throws SQLException {
final int result = reset(statementPointer);
if (result == -1) {
throw new SQLException("Exception while resetting statement");
}
this.resultSet = TursoResultSet.of(this);
}
private native int reset(long statementPointer) throws SQLException;
/**
* Checks if the statement is closed.
*

View file

@ -10,23 +10,11 @@ import java.math.BigDecimal;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.sql.Array;
import java.sql.Blob;
import java.sql.Clob;
import java.sql.Date;
import java.sql.NClob;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.Ref;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.RowId;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.SQLXML;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import tech.turso.annotations.Nullable;
import tech.turso.annotations.SkipNullableCheck;
import tech.turso.core.TursoResultSet;
@ -36,6 +24,10 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
private final String sql;
private final JDBC4ResultSet resultSet;
private final int paramCount;
private Object[] currentBatchParams;
private final ArrayList<Object[]> batchQueryParams = new ArrayList<>();
/**
* Creates a new JDBC4PreparedStatement.
*
@ -48,97 +40,110 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
this.sql = sql;
this.statement = connection.prepare(sql);
this.resultSet = new JDBC4ResultSet(this.statement.getResultSet());
this.paramCount = statement.parameterCount();
this.currentBatchParams = new Object[paramCount];
}
@Override
public ResultSet executeQuery() throws SQLException {
// TODO: check bindings etc
bindParams(currentBatchParams);
return this.resultSet;
}
@Override
public int executeUpdate() throws SQLException {
requireNonNull(this.statement);
bindParams(currentBatchParams);
final TursoResultSet resultSet = statement.getResultSet();
resultSet.consumeAll();
return Math.toIntExact(statement.changes());
}
/**
* This helper method saves a parameter locally without binding it to the underlying native
* statement. We have to do this so we are able to switch between different sets of parameters
* when batching queries.
*/
private void setParam(int parameterIndex, @Nullable Object object) {
requireNonNull(this.statement);
currentBatchParams[parameterIndex - 1] = object;
}
@Override
public void setNull(int parameterIndex, int sqlType) throws SQLException {
requireNonNull(this.statement);
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
}
@Override
public void setBoolean(int parameterIndex, boolean x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindInt(parameterIndex, x ? 1 : 0);
setParam(parameterIndex, x ? 1 : 0);
}
@Override
public void setByte(int parameterIndex, byte x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindInt(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setShort(int parameterIndex, short x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindInt(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setInt(int parameterIndex, int x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindInt(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setLong(int parameterIndex, long x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindLong(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setFloat(int parameterIndex, float x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindDouble(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setDouble(int parameterIndex, double x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindDouble(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindText(parameterIndex, x.toString());
setParam(parameterIndex, x.toString());
}
@Override
public void setString(int parameterIndex, String x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindText(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setBytes(int parameterIndex, byte[] x) throws SQLException {
requireNonNull(this.statement);
this.statement.bindBlob(parameterIndex, x);
setParam(parameterIndex, x);
}
@Override
public void setDate(int parameterIndex, Date x) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
} else {
long time = x.getTime();
this.statement.bindBlob(
parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
setParam(parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
}
}
@ -146,11 +151,10 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setTime(int parameterIndex, Time x) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
} else {
long time = x.getTime();
this.statement.bindBlob(
parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
setParam(parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
}
}
@ -158,11 +162,10 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
} else {
long time = x.getTime();
this.statement.bindBlob(
parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
setParam(parameterIndex, ByteBuffer.allocate(Long.BYTES).putLong(time).array());
}
}
@ -170,14 +173,14 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
return;
}
if (length < 0) {
throw new SQLException("setAsciiStream length must be non-negative");
}
if (length == 0) {
this.statement.bindText(parameterIndex, "");
setParam(parameterIndex, "");
return;
}
try {
@ -188,7 +191,7 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
offset += read;
}
String ascii = new String(buffer, 0, offset, StandardCharsets.US_ASCII);
this.statement.bindText(parameterIndex, ascii);
setParam(parameterIndex, ascii);
} catch (IOException e) {
throw new SQLException("Error reading ASCII stream", e);
}
@ -198,14 +201,14 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
return;
}
if (length < 0) {
throw new SQLException("setUnicodeStream length must be non-negative");
}
if (length == 0) {
this.statement.bindText(parameterIndex, "");
setParam(parameterIndex, "");
return;
}
try {
@ -216,7 +219,7 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
offset += read;
}
String text = new String(buffer, 0, offset, StandardCharsets.UTF_8);
this.statement.bindText(parameterIndex, text);
setParam(parameterIndex, text);
} catch (IOException e) {
throw new SQLException("Error reading Unicode stream", e);
}
@ -226,14 +229,14 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
return;
}
if (length < 0) {
throw new SQLException("setBinaryStream length must be non-negative");
}
if (length == 0) {
this.statement.bindBlob(parameterIndex, new byte[0]);
setParam(parameterIndex, new byte[0]);
return;
}
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
@ -246,15 +249,21 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
totalRead += bytesRead;
}
byte[] data = baos.toByteArray();
this.statement.bindBlob(parameterIndex, data);
setParam(parameterIndex, data);
} catch (IOException e) {
throw new SQLException("Error reading binary stream", e);
}
}
@Override
public void clearParameters() throws SQLException {
// TODO
public void clearParameters() {
this.currentBatchParams = new Object[paramCount];
}
@Override
public void clearBatch() throws SQLException {
this.batchQueryParams.clear();
this.currentBatchParams = new Object[paramCount];
}
@Override
@ -266,7 +275,7 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setObject(int parameterIndex, Object x) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
return;
}
if (x instanceof String) {
@ -309,14 +318,77 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
@Override
public boolean execute() throws SQLException {
return execute(currentBatchParams);
}
/** This helper method runs the statement using the provided parameter values. */
private boolean execute(Object[] params) throws SQLException {
// TODO: check whether this is sufficient
requireNonNull(this.statement);
return statement.execute();
requireNonNull(statement);
bindParams(params);
boolean result = statement.execute();
updateCount = statement.changes();
return result;
}
@Override
public void addBatch() throws SQLException {
// TODO
public int[] executeBatch() throws SQLException {
return Arrays.stream(executeLargeBatch()).mapToInt(l -> (int) l).toArray();
}
@Override
public long[] executeLargeBatch() throws SQLException {
requireNonNull(this.statement);
if (batchQueryParams.isEmpty()) {
return new long[0];
}
long[] updateCounts = new long[batchQueryParams.size()];
if (!isBatchCompatibleStatement(sql)) {
updateCounts[0] = EXECUTE_FAILED;
BatchUpdateException bue =
new BatchUpdateException(
"Batch commands cannot return result sets.",
"HY000", // General error SQL state
0,
Arrays.stream(updateCounts).mapToInt(l -> (int) l).toArray());
// Clear the batch after failure
clearBatch();
throw bue;
}
for (int i = 0; i < batchQueryParams.size(); i++) {
try {
statement.reset();
execute(batchQueryParams.get(i));
updateCounts[i] = getUpdateCount();
} catch (SQLException e) {
BatchUpdateException bue =
new BatchUpdateException(
"Batch entry " + i + " (" + sql + ") failed: " + e.getMessage(),
e.getSQLState(),
e.getErrorCode(),
updateCounts,
e.getCause());
// Clear the batch after failure
clearBatch();
throw bue;
}
}
clearBatch();
return updateCounts;
}
/** Takes the given set of parameters and binds it to the underlying statement. */
private void bindParams(Object[] params) throws SQLException {
requireNonNull(statement);
for (int paramIndex = 1; paramIndex <= params.length; paramIndex++) {
statement.bindObject(paramIndex, params[paramIndex - 1]);
}
}
@Override
public void addBatch() {
batchQueryParams.add(currentBatchParams);
currentBatchParams = new Object[paramCount];
}
@Override
@ -463,23 +535,23 @@ public final class JDBC4PreparedStatement extends JDBC4Statement implements Prep
public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
return;
}
byte[] data = readBytes(x);
String ascii = new String(data, StandardCharsets.US_ASCII);
this.statement.bindText(parameterIndex, ascii);
setParam(parameterIndex, ascii);
}
@Override
public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException {
requireNonNull(this.statement);
if (x == null) {
this.statement.bindNull(parameterIndex);
setParam(parameterIndex, null);
return;
}
byte[] data = readBytes(x);
this.statement.bindBlob(parameterIndex, data);
setParam(parameterIndex, data);
}
/**

View file

@ -34,7 +34,7 @@ public class JDBC4Statement implements Statement {
+ ")\\b",
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);
private final JDBC4Connection connection;
protected final JDBC4Connection connection;
/** The underlying Turso statement. */
@Nullable protected TursoStatement statement = null;
@ -330,7 +330,7 @@ public class JDBC4Statement implements Statement {
return updateCounts;
}
boolean isBatchCompatibleStatement(String sql) {
protected boolean isBatchCompatibleStatement(String sql) {
if (sql == null || sql.trim().isEmpty()) {
return false;
}

View file

@ -1,10 +1,6 @@
package tech.turso.core;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
import java.util.Properties;
import org.junit.jupiter.api.BeforeEach;
@ -119,6 +115,14 @@ class TursoStatementTest {
selectStmt.close();
}
@Test
void test_parameterCount() throws Exception {
runSql("CREATE TABLE test (col1 INT);");
assertEquals(0, connection.prepare("INSERT INTO test VALUES (1)").parameterCount());
assertEquals(1, connection.prepare("INSERT INTO test VALUES (?)").parameterCount());
assertEquals(2, connection.prepare("INSERT INTO test VALUES (?), (?)").parameterCount());
}
private void runSql(String sql) throws Exception {
TursoStatement stmt = connection.prepare(sql);
while (stmt.execute()) {

View file

@ -1,12 +1,6 @@
package tech.turso.jdbc4;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
@ -871,4 +865,114 @@ class JDBC4PreparedStatementTest {
preparedStatement.setInt(1, 1);
assertEquals(preparedStatement.executeUpdate(), 1);
}
@Test
void testBatchInsert() throws Exception {
connection.prepareStatement("CREATE TABLE test (col1 INTEGER, col2 INTEGER)").execute();
PreparedStatement preparedStatement =
connection.prepareStatement("INSERT INTO test (col1, col2) VALUES (?, ?)");
preparedStatement.setInt(1, 1);
preparedStatement.setInt(2, 2);
preparedStatement.addBatch();
preparedStatement.setInt(1, 3);
preparedStatement.setInt(2, 4);
preparedStatement.addBatch();
assertArrayEquals(new int[] {1, 1}, preparedStatement.executeBatch());
ResultSet rs = connection.prepareStatement("SELECT * FROM test").executeQuery();
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertEquals(2, rs.getInt(2));
assertTrue(rs.next());
assertEquals(3, rs.getInt(1));
assertEquals(4, rs.getInt(2));
assertFalse(rs.next());
}
@Test
void testBatchUpdate() throws Exception {
connection.prepareStatement("CREATE TABLE test (col1 INTEGER, col2 INTEGER)").execute();
connection.prepareStatement("INSERT INTO test (col1, col2) VALUES (1, 1), (2, 2)").execute();
PreparedStatement preparedStatement =
connection.prepareStatement("UPDATE test SET col2=? WHERE col1=?");
preparedStatement.setInt(1, 5);
preparedStatement.setInt(2, 1);
preparedStatement.addBatch();
preparedStatement.setInt(1, 6);
preparedStatement.setInt(2, 2);
preparedStatement.addBatch();
preparedStatement.setInt(1, 7);
preparedStatement.setInt(2, 3);
preparedStatement.addBatch();
assertArrayEquals(new int[] {1, 1, 0}, preparedStatement.executeBatch());
ResultSet rs = connection.prepareStatement("SELECT * FROM test").executeQuery();
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertEquals(5, rs.getInt(2));
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertEquals(6, rs.getInt(2));
assertFalse(rs.next());
}
@Test
void testBatchDelete() throws Exception {
connection.prepareStatement("CREATE TABLE test (col1 INTEGER, col2 INTEGER)").execute();
connection.prepareStatement("INSERT INTO test (col1, col2) VALUES (1, 1), (2, 2)").execute();
PreparedStatement preparedStatement =
connection.prepareStatement("DELETE FROM test WHERE col1=?");
preparedStatement.setInt(1, 1);
preparedStatement.addBatch();
preparedStatement.setInt(1, 4);
preparedStatement.addBatch();
assertArrayEquals(new int[] {1, 0}, preparedStatement.executeBatch());
ResultSet rs = connection.prepareStatement("SELECT * FROM test").executeQuery();
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertEquals(2, rs.getInt(2));
assertFalse(rs.next());
}
@Test
void testBatch_implicitAddBatch_shouldIgnore() throws Exception {
connection.prepareStatement("CREATE TABLE test (col1 INTEGER, col2 INTEGER)").execute();
PreparedStatement preparedStatement =
connection.prepareStatement("INSERT INTO test (col1, col2) VALUES (?, ?)");
preparedStatement.setInt(1, 1);
preparedStatement.setInt(2, 2);
preparedStatement.addBatch();
// we set parameters but don't call addBatch afterward
// we should only get a result for the first insert statement to match sqlite-jdbc behavior
preparedStatement.setInt(1, 3);
preparedStatement.setInt(2, 4);
assertArrayEquals(new int[] {1}, preparedStatement.executeBatch());
ResultSet rs = connection.prepareStatement("SELECT * FROM test").executeQuery();
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertEquals(2, rs.getInt(2));
assertFalse(rs.next());
}
@Test
void testBatch_select_shouldFail() throws Exception {
connection.prepareStatement("CREATE TABLE test (col1 INTEGER, col2 INTEGER)").execute();
PreparedStatement preparedStatement =
connection.prepareStatement("SELECT * FROM test WHERE col1=?");
preparedStatement.setInt(1, 1);
preparedStatement.addBatch();
assertThrows(BatchUpdateException.class, preparedStatement::executeBatch);
}
}