stream: accept QueryStream options (e.g. batchSize) as a third argument

StreamFunction and InternalStreamFunction gain an optional
`config?: QueryStreamConfig` parameter. QueryStreamConfig is Node's
ReadableOptions plus an optional `batchSize`. bindPool and
bindPoolConnection forward `config` to connectionMethods/stream. That
function passes it to the QueryStream constructor, whose `batchSize`
option is now optional to match QueryStreamConfig.

An integration test streams three rows with `batchSize: 1` and checks
that every row arrives.

NOTE(review): QueryStream's constructor still spells the option type
inline instead of reusing the QueryStreamConfig alias from src/types.ts.
NOTE(review): the new test asserts row order from a SELECT without
ORDER BY; confirm this is acceptable for the fixture table.

diff --git a/src/QueryStream.ts b/src/QueryStream.ts
index 6c701449..01fa37ce 100644
--- a/src/QueryStream.ts
+++ b/src/QueryStream.ts
@@ -38,7 +38,7 @@ export class QueryStream extends Readable {
 
   public handleError: Function;
 
-  public constructor (text: unknown, values: unknown, options?: ReadableOptions & {batchSize: number, }) {
+  public constructor (text: unknown, values: unknown, options?: ReadableOptions & {batchSize?: number, }) {
     super({
       objectMode: true,
       ...options,
diff --git a/src/binders/bindPool.ts b/src/binders/bindPool.ts
index f400b882..50089725 100644
--- a/src/binders/bindPool.ts
+++ b/src/binders/bindPool.ts
@@ -302,7 +302,7 @@ export const bindPool = (
         query,
       );
     },
-    stream: (streamQuery, streamHandler) => {
+    stream: (streamQuery, streamHandler, config) => {
       assertSqlSqlToken(streamQuery);
 
       return createConnection(
@@ -315,10 +315,10 @@ export const bindPool = (
           connection,
           boundConnection,
         ) => {
-          return boundConnection.stream(streamQuery, streamHandler);
+          return boundConnection.stream(streamQuery, streamHandler, config);
         },
         (newPool) => {
-          return newPool.stream(streamQuery, streamHandler);
+          return newPool.stream(streamQuery, streamHandler, config);
         },
         streamQuery,
       );
diff --git a/src/binders/bindPoolConnection.ts b/src/binders/bindPoolConnection.ts
index 64891312..6ef103de 100644
--- a/src/binders/bindPoolConnection.ts
+++ b/src/binders/bindPoolConnection.ts
@@ -154,7 +154,7 @@ export const bindPoolConnection = (
         query.values,
       );
     },
-    stream: (query, streamHandler) => {
+    stream: (query, streamHandler, config) => {
       assertSqlSqlToken(query);
 
       return stream(
@@ -164,6 +164,8 @@ export const bindPoolConnection = (
         query.sql,
         query.values,
         streamHandler,
+        undefined,
+        config,
       );
     },
     transaction: (handler, transactionRetryLimit) => {
diff --git a/src/connectionMethods/stream.ts b/src/connectionMethods/stream.ts
index a20a5b75..df376908 100644
--- a/src/connectionMethods/stream.ts
+++ b/src/connectionMethods/stream.ts
@@ -11,7 +11,7 @@ import type {
   InternalStreamFunction,
 } from '../types';
 
-export const stream: InternalStreamFunction = async (connectionLogger, connection, clientConfiguration, rawSql, values, streamHandler) => {
+export const stream: InternalStreamFunction = async (connectionLogger, connection, clientConfiguration, rawSql, values, streamHandler, uid, config) => {
   return await executeQuery(
     connectionLogger,
     connection,
@@ -20,7 +20,7 @@ export const stream: InternalStreamFunction = async (connectionLogger, connectio
     values,
     undefined,
     (finalConnection, finalSql, finalValues, executionContext, actualQuery) => {
-      const query = new QueryStream(finalSql, finalValues);
+      const query = new QueryStream(finalSql, finalValues, config);
 
       const queryStream: Stream = finalConnection.query(query);
 
diff --git a/src/types.ts b/src/types.ts
index 81c30c9e..91a1439f 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -1,5 +1,6 @@
 import type {
   Readable,
+  ReadableOptions,
 } from 'stream';
 import type {
   ConnectionOptions as TlsConnectionOptions,
@@ -129,9 +130,12 @@ export type ClientConfiguration = {
 
 export type ClientConfigurationInput = Partial<ClientConfiguration>;
 
+export type QueryStreamConfig = ReadableOptions & {batchSize?: number, };
+
 export type StreamFunction = (
   sql: TaggedTemplateLiteralInvocation,
   streamHandler: StreamHandler,
+  config?: QueryStreamConfig,
 ) => Promise<Record<string, unknown> | null>;
 
 export type QueryCopyFromBinaryFunction = (
@@ -380,6 +384,7 @@ export type InternalStreamFunction = (
   values: readonly PrimitiveValueExpression[],
   streamHandler: StreamHandler,
   uid?: QueryId,
+  config?: QueryStreamConfig,
 ) => Promise<Record<string, unknown>>;
 
 export type InternalTransactionFunction = (
diff --git a/test/slonik/integration/pg.ts b/test/slonik/integration/pg.ts
index 095c0290..9d69a0ba 100644
--- a/test/slonik/integration/pg.ts
+++ b/test/slonik/integration/pg.ts
@@ -109,6 +109,65 @@ test('streams rows', async (t) => {
   await pool.end();
 });
 
+test('streams rows with different batchSize', async (t) => {
+  const pool = createPool(t.context.dsn);
+
+  await pool.query(sql`
+    INSERT INTO person (name) VALUES ('foo'), ('bar'), ('baz')
+  `);
+
+  const messages: Array<Record<string, unknown>> = [];
+
+  await pool.stream(sql`
+    SELECT name
+    FROM person
+  `, (stream) => {
+    stream.on('data', (datum) => {
+      messages.push(datum);
+    });
+  }, {
+    batchSize: 1,
+  });
+
+  t.deepEqual(messages, [
+    {
+      fields: [
+        {
+          dataTypeId: 25,
+          name: 'name',
+        },
+      ],
+      row: {
+        name: 'foo',
+      },
+    },
+    {
+      fields: [
+        {
+          dataTypeId: 25,
+          name: 'name',
+        },
+      ],
+      row: {
+        name: 'bar',
+      },
+    },
+    {
+      fields: [
+        {
+          dataTypeId: 25,
+          name: 'name',
+        },
+      ],
+      row: {
+        name: 'baz',
+      },
+    },
+  ]);
+
+  await pool.end();
+});
+
 test('applies type parsers to streamed rows', async (t) => {
   const pool = createPool(t.context.dsn, {
     typeParsers: [