diff --git a/src/ast.ts b/src/ast.ts index c747ad72..e27e9579 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -461,6 +461,7 @@ export type TableExpression = | TableExpression.SubQuery | TableExpression.CrossJoin | TableExpression.QualifiedJoin + | TableExpression.FunctionCall export namespace TableExpression { export type Table = { @@ -473,6 +474,19 @@ export namespace TableExpression { return { kind: 'Table', table, as } } + export type FunctionCall = { + kind: 'FunctionCall' + func: Expression.FunctionCall + as: string | null + } + + export function createFunctionCall( + func: Expression.FunctionCall, + as: string | null + ): FunctionCall { + return { kind: 'FunctionCall', func, as } + } + export type SubQuery = { kind: 'SubQuery' query: AST @@ -537,6 +551,7 @@ export namespace TableExpression { tableExpr: TableExpression, handlers: { table: (node: Table) => T + functionCall: (node: FunctionCall) => T subQuery: (node: SubQuery) => T crossJoin: (node: CrossJoin) => T qualifiedJoin: (node: QualifiedJoin) => T @@ -545,6 +560,8 @@ export namespace TableExpression { switch (tableExpr.kind) { case 'Table': return handlers.table(tableExpr) + case 'FunctionCall': + return handlers.functionCall(tableExpr) case 'SubQuery': return handlers.subQuery(tableExpr) case 'CrossJoin': diff --git a/src/infer.ts b/src/infer.ts index c1009565..7d590e73 100644 --- a/src/infer.ts +++ b/src/infer.ts @@ -317,7 +317,7 @@ function inferSelectBodyOutput( function inferSelectListOutput( client: SchemaClient, outsideCTEs: VirtualTable[], - sourceColumns: SourceColumn[], + sourceColumns: (SourceColumn | VirtualField)[], paramNullability: ParamNullability[], conditions: Array, selectList: ast.SelectListItem[] @@ -350,7 +350,7 @@ function inferSelectListOutput( function inferSelectListItemOutput( client: SchemaClient, outsideCTEs: VirtualTable[], - sourceColumns: SourceColumn[], + sourceColumns: (SourceColumn | VirtualField)[], paramNullability: ParamNullability[], nonNullExpressions: ast.Expression[], selectListItem: ast.SelectListItem @@ -368,7 +368,7 @@ function inferSelectListItemOutput( ), Either.map((columns) => columns.map((column) => ({ - name: column.columnName, + name: isSourceColumn(column) ? column.columnName : column.name, nullability: column.nullability, })) ) @@ -384,7 +384,7 @@ function inferSelectListItemOutput( ), Either.map((columns) => columns.map((column) => ({ - name: column.columnName, + name: isSourceColumn(column) ? column.columnName : column.name, nullability: column.nullability, })) ) @@ -416,19 +416,21 @@ type NonNullableColumn = { tableName: string | null; columnName: string } function isColumnNonNullable( nonNullableColumns: NonNullableColumn[], - sourceColumn: SourceColumn + sourceColumn: SourceColumn | VirtualField ): boolean { - return nonNullableColumns.some((nonNull) => - nonNull.tableName - ? sourceColumn.tableAlias === nonNull.tableName - : true && sourceColumn.columnName === nonNull.columnName - ) + return isSourceColumn(sourceColumn) + ? nonNullableColumns.some((nonNull) => + nonNull.tableName + ? sourceColumn.tableAlias === nonNull.tableName + : true && sourceColumn.columnName === nonNull.columnName + ) + : !sourceColumn.nullability.nullable // TODO correct? } function applyExpressionNonNullability( nonNullableExpressions: ast.Expression[], - sourceColumns: SourceColumn[] -): SourceColumn[] { + sourceColumns: (SourceColumn | VirtualField)[] +): (SourceColumn | VirtualField)[] { const nonNullableColumns = pipe( nonNullableExpressions, R.map((expr) => @@ -464,7 +466,7 @@ function inferExpressionName(expression: ast.Expression): string { function inferExpressionNullability( client: SchemaClient, outsideCTEs: VirtualTable[], - sourceColumns: SourceColumn[], + sourceColumns: (SourceColumn | VirtualField)[], paramNullability: ParamNullability[], nonNullExprs: ast.Expression[], expression: ast.Expression @@ -485,7 +487,13 @@ function inferExpressionNullability( // have a NOT NULL constraint tableColumnRef: ({ table, column }) => pipe( - InferM.fromEither(findSourceTableColumn(table, column, sourceColumns)), + InferM.fromEither( + findSourceTableColumn( + table, + column, + sourceColumns.filter(isSourceColumn) + ) + ), InferM.map((column) => column.nullability) ), @@ -1173,7 +1181,7 @@ function getSourceColumnsForTableExpr( paramNullability: ParamNullability[], tableExpr: ast.TableExpression | null, setNullable = false -): InferM.InferM { +): InferM.InferM<(SourceColumn | VirtualField)[]> { if (!tableExpr) { return InferM.right([]) } @@ -1182,6 +1190,22 @@ function getSourceColumnsForTableExpr( ast.TableExpression.walk(tableExpr, { table: ({ table, as }) => getSourceColumnsForTable(client, ctes, table, as), + functionCall: ({ func, as }): InferM.InferM => + InferM.map((nullability: FieldNullability) => [ + { + name: as || inferExpressionName(func), + nullability, + }, + ])( + inferExpressionNullability( + client, + ctes, + [], // TODO? + paramNullability, + [], // TODO? + func + ) + ), subQuery: ({ query, as }) => getSourceColumnsForSubQuery(client, ctes, paramNullability, query, as), crossJoin: ({ left, right }) => @@ -1248,21 +1272,31 @@ function getSourceColumnsForSubQuery( } function setSourceColumnsAsNullable( - sourceColumns: SourceColumn[] -): SourceColumn[] { - return sourceColumns.map((col) => ({ - ...col, - nullability: { ...col.nullability, nullable: true }, - })) + sourceColumns: (SourceColumn | VirtualField)[] +): (SourceColumn | VirtualField)[] { + return sourceColumns.map((col) => + isSourceColumn(col) + ? { + ...col, + nullability: { ...col.nullability, nullable: true }, + } + : { + ...col, + nullability: { ...col.nullability, nullable: true }, + } + ) } function combineSourceColumns( - ...sourceColumns: Array> -): InferM.InferM { + ...sourceColumns: Array> +): InferM.InferM<(SourceColumn | VirtualField)[]> { return pipe( sourceColumns, sequenceAIM, - InferM.map(R.flatten) + InferM.map< + (SourceColumn | VirtualField)[][], + (SourceColumn | VirtualField)[] + >(R.flatten) ) } @@ -1272,11 +1306,19 @@ function isConstantExprOf(expectedValue: string, expr: ast.Expression) { }) } +function isSourceColumn(c: SourceColumn | VirtualField): c is SourceColumn { + return !!(c as SourceColumn).tableAlias +} + function findNonHiddenSourceColumns( - sourceColumns: SourceColumn[] -): Either.Either { + sourceColumns: (SourceColumn | VirtualField)[] +): Either.Either { return pipe( - sourceColumns.filter((col) => !col.hidden), + sourceColumns.filter( + (col) => + (isSourceColumn(col) && !col.hidden) || + !isSourceColumn(col) /* VirtualField's are never hidden */ + ), Either.fromPredicate( (result) => result.length > 0, () => `No columns` @@ -1286,12 +1328,14 @@ function findNonHiddenSourceColumns( function findNonHiddenSourceTableColumns( tableName: string, - sourceColumns: SourceColumn[] + sourceColumns: (SourceColumn | VirtualField)[] ): Either.Either { return pipe( findNonHiddenSourceColumns(sourceColumns), Either.map((sourceColumns) => - sourceColumns.filter((col) => col.tableAlias === tableName) + sourceColumns + .filter(isSourceColumn) + .filter((col) => col.tableAlias === tableName) ), Either.chain((result) => result.length > 0 @@ -1317,10 +1361,14 @@ function findSourceTableColumn( function findSourceColumn( columnName: string, - sourceColumns: SourceColumn[] -): Either.Either { + sourceColumns: (SourceColumn | VirtualField)[] +): Either.Either { return pipe( - sourceColumns.find((col) => col.columnName === columnName), + sourceColumns.find( + (col) => + (isSourceColumn(col) && col.columnName === columnName) || + (!isSourceColumn(col) && col.name === columnName) + ), Either.fromNullable(`Unknown column ${columnName}`) ) } diff --git a/src/parser/index.ts b/src/parser/index.ts index d6234431..d3e505e0 100644 --- a/src/parser/index.ts +++ b/src/parser/index.ts @@ -564,6 +564,24 @@ const tableExpression: Parser = seq( as )((stmt, as) => TableExpression.createSubQuery(stmt, as)) ), + attempt( + seq( + seq( + identifier, + optional(seq2(symbol('.'), identifier)), + functionArguments + )((ident, ident2, argList) => + Expression.createFunctionCall( + ident2 ? ident : null, + ident2 ? ident2 : ident, + argList, + null, + null + ) + ), + optional(as) + )((fnCall, as) => TableExpression.createFunctionCall(fnCall, as)) + ), table ), many(oneOf(crossJoin, qualifiedJoin, naturalJoin)) diff --git a/src/schema.ts b/src/schema.ts index 71abb65e..36acb06f 100644 --- a/src/schema.ts +++ b/src/schema.ts @@ -38,6 +38,10 @@ export interface SchemaClient { ): TaskEither.TaskEither getEnums(): Promise getArrayTypes(): Promise + getFunction( + schemaName: string | null, + functionName: string + ): TaskEither.TaskEither functionNullSafety( schemaName: string | null, functionName: string @@ -54,9 +58,7 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient { tableName, }) if (result.length === 0) { - return Either.left( - `No such table: ${fullTableName(schemaName, tableName)}` - ) + return Either.left(`No such table: ${fullName(schemaName, tableName)}`) } return Either.right({ name: tableName, @@ -86,6 +88,27 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient { })) ) + // TODO: handle overloaded functions + const getFunction = ( + schemaName: string | null, + functionName: string + ): TaskEither.TaskEither => async () => { + const allFunctions = await getFunctions() + const res = allFunctions.find( + (f) => + schemaName !== null && + f.schema === schemaName && + f.name === functionName + ) + if (res) { + return Either.right(res) + } else { + return Either.left( + `No such function: ${fullName(schemaName, functionName)}` + ) + } + } + const getFunctions = asyncCached( async (): Promise => (await sql.functions(postgresClient)).map((row) => ({ @@ -117,9 +140,9 @@ export function schemaClient(postgresClient: postgres.Sql<{}>): SchemaClient { ) } - return { getTable, getEnums, getArrayTypes, functionNullSafety } + return { getTable, getEnums, getArrayTypes, getFunction, functionNullSafety } } -function fullTableName(schemaName: string | null, tableName: string): string { +function fullName(schemaName: string | null, tableName: string): string { return (schemaName ? schemaName + '.' : '') + tableName }