From b3bcec6b101fbb4a1a8841f35647e03125585e18 Mon Sep 17 00:00:00 2001 From: Simon Van Casteren Date: Tue, 23 Mar 2021 13:51:41 +0100 Subject: [PATCH 1/3] First attempt at supporting functionCall in FROM --- src/ast.ts | 17 +++++++ src/infer.ts | 112 +++++++++++++++++++++++++++++++------------- src/parser/index.ts | 18 +++++++ src/schema.ts | 33 +++++++++++-- 4 files changed, 143 insertions(+), 37 deletions(-) diff --git a/src/ast.ts b/src/ast.ts index c747ad72..94efb086 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -461,6 +461,7 @@ export type TableExpression = | TableExpression.SubQuery | TableExpression.CrossJoin | TableExpression.QualifiedJoin + | TableExpression.FunctionCall export namespace TableExpression { export type Table = { @@ -473,6 +474,19 @@ export namespace TableExpression { return { kind: 'Table', table, as } } + export type FunctionCall = { + kind: 'FunctionCall' + func: Expression.FunctionCall + as: string + } + + export function createFunctionCall( + func: Expression.FunctionCall, + as: string + ): FunctionCall { + return { kind: 'FunctionCall', func, as } + } + export type SubQuery = { kind: 'SubQuery' query: AST @@ -537,6 +551,7 @@ export namespace TableExpression { tableExpr: TableExpression, handlers: { table: (node: Table) => T + functionCall: (node: FunctionCall) => T subQuery: (node: SubQuery) => T crossJoin: (node: CrossJoin) => T qualifiedJoin: (node: QualifiedJoin) => T @@ -545,6 +560,8 @@ export namespace TableExpression { switch (tableExpr.kind) { case 'Table': return handlers.table(tableExpr) + case 'FunctionCall': + return handlers.functionCall(tableExpr) case 'SubQuery': return handlers.subQuery(tableExpr) case 'CrossJoin': diff --git a/src/infer.ts b/src/infer.ts index c1009565..3976d3a9 100644 --- a/src/infer.ts +++ b/src/infer.ts @@ -317,7 +317,7 @@ function inferSelectBodyOutput( function inferSelectListOutput( client: SchemaClient, outsideCTEs: VirtualTable[], - sourceColumns: SourceColumn[], + sourceColumns: (SourceColumn | VirtualField)[], paramNullability: ParamNullability[], conditions: Array, selectList: ast.SelectListItem[] @@ -350,7 +350,7 @@ function inferSelectListOutput( function inferSelectListItemOutput( client: SchemaClient, outsideCTEs: VirtualTable[], - sourceColumns: SourceColumn[], + sourceColumns: (SourceColumn | VirtualField)[], paramNullability: ParamNullability[], nonNullExpressions: ast.Expression[], selectListItem: ast.SelectListItem @@ -368,7 +368,7 @@ function inferSelectListItemOutput( ), Either.map((columns) => columns.map((column) => ({ - name: column.columnName, + name: isSourceColumn(column) ? column.columnName : column.name, nullability: column.nullability, })) ) @@ -384,7 +384,7 @@ function inferSelectListItemOutput( ), Either.map((columns) => columns.map((column) => ({ - name: column.columnName, + name: isSourceColumn(column) ? column.columnName : column.name, nullability: column.nullability, })) ) @@ -416,19 +416,21 @@ type NonNullableColumn = { tableName: string | null; columnName: string } function isColumnNonNullable( nonNullableColumns: NonNullableColumn[], - sourceColumn: SourceColumn + sourceColumn: SourceColumn | VirtualField ): boolean { - return nonNullableColumns.some((nonNull) => - nonNull.tableName - ? sourceColumn.tableAlias === nonNull.tableName - : true && sourceColumn.columnName === nonNull.columnName - ) + return isSourceColumn(sourceColumn) + ? nonNullableColumns.some((nonNull) => + nonNull.tableName + ? sourceColumn.tableAlias === nonNull.tableName + : true && sourceColumn.columnName === nonNull.columnName + ) + : !sourceColumn.nullability.nullable // TODO correct? } function applyExpressionNonNullability( nonNullableExpressions: ast.Expression[], - sourceColumns: SourceColumn[] -): SourceColumn[] { + sourceColumns: (SourceColumn | VirtualField)[] +): (SourceColumn | VirtualField)[] { const nonNullableColumns = pipe( nonNullableExpressions, R.map((expr) => @@ -464,7 +466,7 @@ function inferExpressionName(expression: ast.Expression): string { function inferExpressionNullability( client: SchemaClient, outsideCTEs: VirtualTable[], - sourceColumns: SourceColumn[], + sourceColumns: (SourceColumn | VirtualField)[], paramNullability: ParamNullability[], nonNullExprs: ast.Expression[], expression: ast.Expression @@ -485,7 +487,13 @@ function inferExpressionNullability( // have a NOT NULL constraint tableColumnRef: ({ table, column }) => pipe( - InferM.fromEither(findSourceTableColumn(table, column, sourceColumns)), + InferM.fromEither( + findSourceTableColumn( + table, + column, + sourceColumns.filter(isSourceColumn) + ) + ), InferM.map((column) => column.nullability) ), @@ -1173,7 +1181,7 @@ function getSourceColumnsForTableExpr( paramNullability: ParamNullability[], tableExpr: ast.TableExpression | null, setNullable = false -): InferM.InferM { +): InferM.InferM<(SourceColumn | VirtualField)[]> { if (!tableExpr) { return InferM.right([]) } @@ -1182,6 +1190,22 @@ function getSourceColumnsForTableExpr( ast.TableExpression.walk(tableExpr, { table: ({ table, as }) => getSourceColumnsForTable(client, ctes, table, as), + functionCall: ({ func, as }): InferM.InferM => + InferM.map((nullability: FieldNullability) => [ + { + name: as, + nullability, + }, + ])( + inferExpressionNullability( + client, + ctes, + [], // TODO? + paramNullability, + [], // TODO? + func + ) + ), subQuery: ({ query, as }) => getSourceColumnsForSubQuery(client, ctes, paramNullability, query, as), crossJoin: ({ left, right }) => @@ -1248,21 +1272,31 @@ function getSourceColumnsForSubQuery( } function setSourceColumnsAsNullable( - sourceColumns: SourceColumn[] -): SourceColumn[] { - return sourceColumns.map((col) => ({ - ...col, - nullability: { ...col.nullability, nullable: true }, - })) + sourceColumns: (SourceColumn | VirtualField)[] +): (SourceColumn | VirtualField)[] { + return sourceColumns.map((col) => + isSourceColumn(col) + ? { + ...col, + nullability: { ...col.nullability, nullable: true }, + } + : { + ...col, + nullability: { ...col.nullability, nullable: true }, + } + ) } function combineSourceColumns( - ...sourceColumns: Array> -): InferM.InferM { + ...sourceColumns: Array> +): InferM.InferM<(SourceColumn | VirtualField)[]> { return pipe( sourceColumns, sequenceAIM, - InferM.map(R.flatten) + InferM.map< + (SourceColumn | VirtualField)[][], + (SourceColumn | VirtualField)[] + >(R.flatten) ) } @@ -1272,11 +1306,19 @@ function isConstantExprOf(expectedValue: string, expr: ast.Expression) { }) } +function isSourceColumn(c: SourceColumn | VirtualField): c is SourceColumn { + return !!(c as SourceColumn).tableAlias +} + function findNonHiddenSourceColumns( - sourceColumns: SourceColumn[] -): Either.Either { + sourceColumns: (SourceColumn | VirtualField)[] +): Either.Either { return pipe( - sourceColumns.filter((col) => !col.hidden), + sourceColumns.filter( + (col) => + (isSourceColumn(col) && !col.hidden) || + true /* VirtualField's are never hidden */ + ), Either.fromPredicate( (result) => result.length > 0, () => `No columns` @@ -1286,12 +1328,14 @@ function findNonHiddenSourceColumns( function findNonHiddenSourceTableColumns( tableName: string, - sourceColumns: SourceColumn[] + sourceColumns: (SourceColumn | VirtualField)[] ): Either.Either { return pipe( findNonHiddenSourceColumns(sourceColumns), Either.map((sourceColumns) => - sourceColumns.filter((col) => col.tableAlias === tableName) + sourceColumns + .filter(isSourceColumn) + .filter((col) => col.tableAlias === tableName) ), Either.chain((result) => result.length > 0 @@ -1317,10 +1361,14 @@ function findSourceTableColumn( function findSourceColumn( columnName: string, - sourceColumns: SourceColumn[] -): Either.Either { + sourceColumns: (SourceColumn | VirtualField)[] +): Either.Either { return pipe( - sourceColumns.find((col) => col.columnName === columnName), + sourceColumns.find( + (col) => + (isSourceColumn(col) && col.columnName === columnName) || + (!isSourceColumn(col) && col.name === columnName) + ), Either.fromNullable(`Unknown column ${columnName}`) ) } diff --git a/src/parser/index.ts b/src/parser/index.ts index d6234431..3abec677 100644 --- a/src/parser/index.ts +++ b/src/parser/index.ts @@ -564,6 +564,24 @@ const tableExpression: Parser = seq( as )((stmt, as) => TableExpression.createSubQuery(stmt, as)) ), + attempt( + seq( + seq( + identifier, + optional(seq2(symbol('.'), identifier)), + functionArguments + )((ident, ident2, argList) => + Expression.createFunctionCall( + ident2 ? ident : null, + ident2 ? ident2 : ident, + argList, + null, + null + ) + ), + as + )((fnCall, as) => TableExpression.createFunctionCall(fnCall, as)) + ), table ), many(oneOf(crossJoin, qualifiedJoin, naturalJoin)) diff --git a/src/schema.ts b/src/schema.ts index 71abb65e..36acb06f 100644 --- a/src/schema.ts +++ b/src/schema.ts @@ -38,6 +38,10 @@ export interface SchemaClient { ): TaskEither.TaskEither getEnums(): Promise getArrayTypes(): Promise + getFunction( + schemaName: string | null, + functionName: string + ): TaskEither.TaskEither functionNullSafety( schemaName: string | null, functionName: string @@ -54,9 +58,7 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient { tableName, }) if (result.length === 0) { - return Either.left( - `No such table: ${fullTableName(schemaName, tableName)}` - ) + return Either.left(`No such table: ${fullName(schemaName, tableName)}`) } return Either.right({ name: tableName, @@ -86,6 +88,27 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient { })) ) + // TODO: handle overloaded functions + const getFunction = ( + schemaName: string | null, + functionName: string + ): TaskEither.TaskEither => async () => { + const allFunctions = await getFunctions() + const res = allFunctions.find( + (f) => + schemaName !== null && + f.schema === schemaName && + f.name === functionName + ) + if (res) { + return Either.right(res) + } else { + return Either.left( + `No such function: ${fullName(schemaName, functionName)}` + ) + } + } + const getFunctions = asyncCached( async (): Promise => (await sql.functions(postgresClient)).map((row) => ({ @@ -117,9 +140,9 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient { ) } - return { getTable, getEnums, getArrayTypes, functionNullSafety } + return { getTable, getEnums, getArrayTypes, getFunction, functionNullSafety } } -function fullTableName(schemaName: string | null, tableName: string): string { +function fullName(schemaName: string | null, tableName: string): string { return (schemaName ? schemaName + '.' : '') + tableName } From 3b6f517a5d7cec72cfcde81f51bb06d2f54a847b Mon Sep 17 00:00:00 2001 From: Simon Van Casteren Date: Tue, 23 Mar 2021 14:28:19 +0100 Subject: [PATCH 2/3] FROM : AS is optional --- src/ast.ts | 4 ++-- src/infer.ts | 2 +- src/parser/index.ts | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ast.ts b/src/ast.ts index 94efb086..e27e9579 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -477,12 +477,12 @@ export namespace TableExpression { export type FunctionCall = { kind: 'FunctionCall' func: Expression.FunctionCall - as: string + as: string | null } export function createFunctionCall( func: Expression.FunctionCall, - as: string + as: string | null ): FunctionCall { return { kind: 'FunctionCall', func, as } } diff --git a/src/infer.ts b/src/infer.ts index 3976d3a9..6ffb42c8 100644 --- a/src/infer.ts +++ b/src/infer.ts @@ -1193,7 +1193,7 @@ function getSourceColumnsForTableExpr( functionCall: ({ func, as }): InferM.InferM => InferM.map((nullability: FieldNullability) => [ { - name: as, + name: as || inferExpressionName(func), nullability, }, ])( diff --git a/src/parser/index.ts b/src/parser/index.ts index 3abec677..d3e505e0 100644 --- a/src/parser/index.ts +++ b/src/parser/index.ts @@ -579,7 +579,7 @@ const tableExpression: Parser = seq( null ) ), - as + optional(as) )((fnCall, as) => TableExpression.createFunctionCall(fnCall, as)) ), table From 7a4768fec8f714560fc90e38afa4a8072bef96cd Mon Sep 17 00:00:00 2001 From: Simon Van Casteren Date: Tue, 23 Mar 2021 15:02:52 +0100 Subject: [PATCH 3/3] Fix regression in findNonHiddenSourceColumns --- src/infer.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/infer.ts b/src/infer.ts index 6ffb42c8..7d590e73 100644 --- a/src/infer.ts +++ b/src/infer.ts @@ -1317,7 +1317,7 @@ function findNonHiddenSourceColumns( sourceColumns.filter( (col) => (isSourceColumn(col) && !col.hidden) || - true /* VirtualField's are never hidden */ + !isSourceColumn(col) /* VirtualField's are never hidden */ ), Either.fromPredicate( (result) => result.length > 0,