diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 9b2f8e8db..422c207b9 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -129,6 +129,15 @@ def __init__( # Guarded by ``_async_handles_lock`` so concurrent cursors on the # same connection don't race on submit / close / close-session. self._async_handles: Dict[str, Any] = {} + # Parent ``Statement`` objects kept alive alongside async handles. + # On the kernel, ``Statement.close()`` flips the validity flag on + # the produced executed handle (see kernel + # ``statement::mutable::close``), so we cannot close the + # Statement immediately after ``submit()`` as we do for sync + # ``execute()``. Instead retain it here and close it in + # ``close_command`` / ``close_session`` after the async handle + # has finished its work. + self._async_statements: Dict[str, Any] = {} # CommandId.guids of async commands that have already been # closed (via ``close_command`` or ``close_session``). Lets # ``get_query_state`` report ``CLOSED`` for them rather than @@ -167,6 +176,16 @@ def open_session( schema=schema or self._schema, session_conf=session_conf, complex_types_as_json=not self._use_arrow_native_complex_types, + # Pyarrow's Python bindings cannot decode Arrow's + # ``month_interval`` type at all (id 21 — raises + # ``KeyError`` from ``.as_py``, ``to_pylist``, + # ``cast(string)``, and ``to_pandas``). Ask the kernel + # to stringify INTERVAL / DURATION columns server-side + # so result sets containing interval columns are + # decodable on the Python side. Matches the Thrift + # backend's surface (interval columns arrive as + # strings). + intervals_as_string=True, **auth_kwargs, ) except Exception as exc: @@ -197,7 +216,9 @@ def close_session(self, session_id: SessionId) -> None: # server-side CloseStatement before the session goes away. with self._async_handles_lock: tracked = list(self._async_handles.items()) + tracked_stmts = list(self._async_statements.items()) self._async_handles.clear() + self._async_statements.clear() for guid, _ in tracked: self._closed_commands.add(guid) for _, handle in tracked: @@ -211,6 +232,16 @@ def close_session(self, session_id: SessionId) -> None: logger.warning( "Error closing async handle during session close: %s", exc ) + # Now drop the parent Statements that were keeping those handles + # alive. Same non-fatal close semantics — close errors are not + # actionable at session-close time. + for _, stmt in tracked_stmts: + try: + stmt.close() + except Exception as exc: + logger.warning( + "Error closing async statement during session close: %s", exc + ) try: self._kernel_session.close() except Exception as exc: @@ -249,6 +280,11 @@ def execute_command( stmt = self._kernel_session.statement() except Exception as exc: raise _wrap_kernel_exception("execute_command", exc) from exc + # ``async_op`` keeps ``stmt`` alive (tracked in + # ``_async_statements`` and closed by ``close_command``); the sync + # path drops it in finally. ``close_stmt`` is the post-success + # decision flag — it stays True on sync, flips to False on async. + close_stmt = True try: try: stmt.set_sql(operation) @@ -262,21 +298,26 @@ def execute_command( cursor.active_command_id = command_id with self._async_handles_lock: self._async_handles[command_id.guid] = async_exec + # Closing the kernel ``Statement`` invalidates the + # async handle (see kernel validity flag). Retain + # the Statement here and close it on + # ``close_command`` / ``close_session``. + self._async_statements[command_id.guid] = stmt + close_stmt = False return None executed = stmt.execute() except Exception as exc: raise _wrap_kernel_exception("execute_command", exc) from exc finally: - # ``Statement`` is a lifecycle owner separate from the - # executed handle it produces. Drop it here so the - # parent doesn't keep the handle alive longer than the - # caller expects. Swallow all close errors (including - # PyO3 native exceptions) — a failed stmt.close() is - # not actionable for the caller. - try: - stmt.close() - except Exception: - pass + if close_stmt: + # Sync path: ``Statement`` is a lifecycle owner separate + # from the executed handle. Drop it here so the parent + # doesn't outlive its caller. Swallow close errors — + # they're not actionable. + try: + stmt.close() + except Exception: + pass command_id = CommandId.from_sea_statement_id(executed.statement_id) cursor.active_command_id = command_id @@ -307,17 +348,34 @@ def cancel_command(self, command_id: CommandId) -> None: def close_command(self, command_id: CommandId) -> None: with self._async_handles_lock: handle = self._async_handles.pop(command_id.guid, None) + stmt = self._async_statements.pop(command_id.guid, None) if handle is not None: # Record the close so ``get_query_state`` can report # ``CLOSED`` (not ``SUCCEEDED``) for this command. self._closed_commands.add(command_id.guid) if handle is None: logger.debug("close_command: no tracked handle for %s", command_id) + # Still drop the parent Statement if somehow tracked without + # the handle — keeps the invariant clean even on bookkeeping + # races. + if stmt is not None: + try: + stmt.close() + except Exception: + pass return try: handle.close() except Exception as exc: raise _wrap_kernel_exception("close_command", exc) from exc + finally: + # Now safe to close the parent Statement — the executed + # handle has finished its lifecycle. + if stmt is not None: + try: + stmt.close() + except Exception: + pass def get_query_state(self, command_id: CommandId) -> CommandState: with self._async_handles_lock: @@ -378,6 +436,7 @@ def get_execution_result( # it wraps. Drop tracking and fire-and-forget the close. with self._async_handles_lock: self._async_handles.pop(command_id.guid, None) + stmt = self._async_statements.pop(command_id.guid, None) self._closed_commands.add(command_id.guid) try: async_exec.close() @@ -387,6 +446,18 @@ def get_execution_result( command_id, exc, ) + # The parent Statement is no longer needed once the async handle + # has produced its ResultStream. Close to release server-side + # tracking; matches the sync path's eager Statement close. + if stmt is not None: + try: + stmt.close() + except Exception as exc: + logger.warning( + "Error closing async statement after await_result for %s: %s", + command_id, + exc, + ) # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which # can raise — map that to PEP 249 too. try: diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index fc1a338cd..e5753035b 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import pyarrow @@ -102,6 +102,14 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: backend's behaviour; other precise types (``INTERVAL_*``, ``GEOMETRY``, ``GEOGRAPHY``) collapse to their Arrow shape on both backends and don't need a remap. + + ``precision`` / ``scale`` are extracted from ``Decimal128Type`` / + ``Decimal256Type`` so DECIMAL columns expose the same + ``(precision, scale)`` pair the Thrift backend reports. The Arrow + schema carries these on the type itself; without this extraction + the kernel-backend description would silently drop them, breaking + parity for any consumer (SQLAlchemy, pandas-read-sql, etc.) that + reads slots 4/5 to know how to display or round decimal values. """ return [ ( @@ -109,14 +117,32 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: _databricks_type_for_field(field), None, None, - None, - None, + *_precision_scale_for_arrow_type(field.type), None, ) for field in schema ] +def _precision_scale_for_arrow_type( + arrow_type: pyarrow.DataType, +) -> Tuple[Optional[int], Optional[int]]: + """Extract PEP 249 ``(precision, scale)`` from an Arrow type. + + Only Arrow's decimal types carry both; every other type collapses + to ``(None, None)`` to match the Thrift backend's behaviour. Future + extensions (e.g. fractional-second precision from + ``Time64Type`` / ``Timestamp``) can land here without touching the + description builder above. + """ + if pyarrow.types.is_decimal(arrow_type): + # Decimal128Type / Decimal256Type both expose `.precision` and + # `.scale`. The cast is for the type checker — pyarrow's + # `DataType` base type doesn't declare them. + return arrow_type.precision, arrow_type.scale # type: ignore[attr-defined] + return None, None + + def _databricks_type_for_field(field: pyarrow.Field) -> str: """Pick the PEP 249 type code for a single field. @@ -173,12 +199,10 @@ def _tspark_param_value_str(param: ttypes.TSparkParameter) -> Any: def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> None: """Bind a list of ``TSparkParameter`` onto a kernel ``Statement``. - The kernel expects positional bindings only (SEA v0 doesn't - accept named bindings on the wire). The connector's + Both positional and named bindings are supported. The connector's ``TSparkParameter`` has an ``ordinal: bool`` flag; ``True`` means - "treat as positional in source-list order". Named-binding - parameters surface as ``NotSupportedError`` so the user gets a - clear message instead of a server-side rejection. + "treat as positional in source-list order", otherwise the + parameter is bound by name via ``Statement.bind_named_param``. Compound types (``ARRAY`` / ``MAP`` / ``STRUCT``) build a ``TSparkParameter`` with the payload on ``arguments`` and @@ -186,19 +210,8 @@ def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> NULL. Reject up front with ``NotSupportedError`` so callers get a clear message instead of silent data loss. """ - for i, param in enumerate(parameters, start=1): - # ``ordinal`` on connector-native params is a bool (True for - # positional, False for named). Thrift defaults to ``None``; - # treat any non-True value with a name as a named binding so - # a future caller that forgets to set ordinal=True still gets - # rejected instead of silently dropping the name. - name = getattr(param, "name", None) - if name and getattr(param, "ordinal", None) is not True: - raise NotSupportedError( - f"Named parameter binding (got name={name!r}) is not yet " - "supported on the kernel backend; pass parameters positionally." - ) - + positional_index = 0 + for param in parameters: sql_type = param.type or "STRING" # Compound types put their payload on ``arguments``, not # ``value``. The kernel parser doesn't accept them yet, and @@ -214,7 +227,12 @@ def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> ) value_str = _tspark_param_value_str(param) - # The kernel takes 1-based ordinals; `i` is already that. - # Errors from the kernel side (bad literal, unsupported type, - # etc.) come up as KernelError and bubble through normally. - kernel_stmt.bind_param(i, value_str, sql_type) + # ``ordinal`` on connector-native params is a bool. ``True`` + # → positional (assign the next 1-based ordinal). Anything + # else with a name → named binding. + name = getattr(param, "name", None) + if name and getattr(param, "ordinal", None) is not True: + kernel_stmt.bind_named_param(name, value_str, sql_type) + else: + positional_index += 1 + kernel_stmt.bind_param(positional_index, value_str, sql_type) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 7ea0d7f5c..e7b5337f6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -19,6 +19,7 @@ pyarrow = None from databricks.sql import OperationalError +from databricks.sql.exc import ProgrammingError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( TRowSet, @@ -548,9 +549,7 @@ def escape_args(self, parameters): elif isinstance(parameters, (list, tuple)): return tuple(self.escape_item(x) for x in parameters) else: - raise exc.ProgrammingError( - "Unsupported param format: {}".format(parameters) - ) + raise ProgrammingError("Unsupported param format: {}".format(parameters)) def escape_number(self, item): return item @@ -606,7 +605,7 @@ def escape_item(self, item): elif isinstance(item, Mapping): return self.escape_mapping(item) else: - raise exc.ProgrammingError("Unsupported object {}".format(item)) + raise ProgrammingError("Unsupported object {}".format(item)) def inject_parameters(operation: str, parameters: Dict[str, str]): diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py index d2c0c9b9c..a6ea851b8 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -241,6 +241,34 @@ def test_parameterized_query_with_null(conn): assert rows[0][0] is True +def test_parameterized_query_named_params(conn): + """Named parameter binding via the kernel backend. The + connector passes ``parameters={name: value}`` dicts (DB-API + style); the kernel forwards them through ``bind_named_param`` + so the SEA wire payload sets ``StatementParameter.name`` (the + spec-required public form per canonical proto). + """ + with conn.cursor() as cur: + cur.execute( + "SELECT :n AS n, :s AS s, :b AS b", + {"n": 42, "s": "alice", "b": True}, + ) + rows = cur.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 42 + assert rows[0][1] == "alice" + assert rows[0][2] is True + + +def test_parameterized_query_named_param_with_null(conn): + """``None`` value in a named binding flows through as + VoidParameter → kernel ``TypedValue::Null``.""" + with conn.cursor() as cur: + cur.execute("SELECT :x IS NULL AS is_null", {"x": None}) + rows = cur.fetchall() + assert rows[0][0] is True + + def test_parameterized_query_decimal(conn): """DECIMAL parameters carry precision/scale in the SQL type string ('DECIMAL(p,s)') — the kernel parser extracts them so diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py index c859ffca1..bfc3087d7 100644 --- a/tests/unit/test_kernel_type_mapping.py +++ b/tests/unit/test_kernel_type_mapping.py @@ -145,6 +145,63 @@ def test_description_recovers_complex_type_name_from_metadata(metadata_value, ex assert desc[0][1] == expected +@pytest.mark.parametrize( + "arrow_type, expected_precision, expected_scale", + [ + (pa.decimal128(10, 2), 10, 2), + (pa.decimal128(38, 0), 38, 0), + (pa.decimal128(38, 18), 38, 18), + # Decimal256 — kernel doesn't emit it today (server uses + # `Decimal128` exclusively), but the extraction helper handles + # any pyarrow decimal type via `is_decimal`. Locking in the + # contract. + (pa.decimal256(76, 38), 76, 38), + ], +) +def test_description_extracts_decimal_precision_scale( + arrow_type, expected_precision, expected_scale +): + """PEP 249 description slots 4/5 (precision, scale) must be + populated for DECIMAL columns. The Thrift backend reports them; + kernel must match. Without extraction, SQLAlchemy / pandas-read-sql + can't tell ``DECIMAL(10,2)`` from ``DECIMAL(38,18)``.""" + schema = pa.schema([("amount", arrow_type)]) + desc = description_from_arrow_schema(schema) + assert len(desc) == 1 + d = desc[0] + assert d[0] == "amount" + assert d[1] == "decimal" + # Slots 2/3 (display_size, internal_size) stay None; the Thrift + # backend doesn't populate them either, and matching is more + # valuable than introducing new info. + assert d[2] is None + assert d[3] is None + # Slots 4/5 are the precision/scale this test exists to lock in. + assert d[4] == expected_precision + assert d[5] == expected_scale + # Slot 6 (null_ok) stays None — see the parity rationale in + # `test_description_null_ok_always_none_regardless_of_field_nullable`. + assert d[6] is None + + +def test_description_non_decimal_columns_have_none_precision_scale(): + """Companion to the decimal test: non-decimal columns must report + ``(None, None)`` in slots 4/5. Catches a regression where the + helper accidentally extracts precision from non-decimal Arrow + types (e.g. ``Time64`` fractional-second precision).""" + schema = pa.schema( + [ + ("i", pa.int64()), + ("s", pa.string()), + ("ts", pa.timestamp("us")), + ] + ) + desc = description_from_arrow_schema(schema) + for d in desc: + assert d[4] is None, f"precision must be None for {d[1]}, got {d[4]}" + assert d[5] is None, f"scale must be None for {d[1]}, got {d[5]}" + + def test_description_passes_through_unknown_databricks_type_name(): """Server-reported names other than the handful we explicitly recognise (VARIANT / ARRAY / MAP / STRUCT) defer to the Arrow @@ -182,15 +239,22 @@ def _mk_param(*, type, value, ordinal=True, name=None): class _RecordingStmt: """Stand-in for the kernel `Statement` pyclass — records every - `bind_param` call so tests can assert the (ordinal, value, type) - triples the mapper forwarded.""" + `bind_param` / `bind_named_param` call so tests can assert the + triples the mapper forwarded. + + Positional calls land in `calls` as `(ordinal, value, type)`; + named calls land in `named_calls` as `(name, value, type)`.""" def __init__(self): self.calls = [] + self.named_calls = [] def bind_param(self, ordinal, value_str, sql_type): self.calls.append((ordinal, value_str, sql_type)) + def bind_named_param(self, name, value_str, sql_type): + self.named_calls.append((name, value_str, sql_type)) + def test_bind_tspark_params_forwards_each_param_positionally(): from databricks.sql.backend.kernel.type_mapping import bind_tspark_params @@ -233,21 +297,39 @@ def test_bind_tspark_params_void_passes_through(): assert stmt.calls == [(1, None, "VOID")] -def test_bind_tspark_params_named_param_rejected(): - """The kernel doesn't accept named bindings on the SEA wire; - surface that at the connector layer so the user gets a pointed - error instead of a server-side rejection.""" +def test_bind_tspark_params_named_param_forwarded(): + """Named bindings route through `bind_named_param` so the SEA + wire payload sets `StatementParameter.name` (the spec-required + public form per canonical proto).""" from databricks.sql.backend.kernel.type_mapping import bind_tspark_params - from databricks.sql.exc import NotSupportedError p = _mk_param(type="INT", value="42", ordinal=False, name="my_param") stmt = _RecordingStmt() - with pytest.raises(NotSupportedError, match="(?i)named"): - bind_tspark_params(stmt, [p]) - # Nothing should have been forwarded before the rejection. + bind_tspark_params(stmt, [p]) + assert stmt.named_calls == [("my_param", "42", "INT")] + # Positional path untouched — no ordinal consumed. assert stmt.calls == [] +def test_bind_tspark_params_named_does_not_consume_positional_ordinal(): + """When the list mixes positional and named params, the positional + ordinal counter must skip past named bindings — the named entry + doesn't take ordinal slot 2.""" + from databricks.sql.backend.kernel.type_mapping import bind_tspark_params + + params = [ + _mk_param(type="INT", value="1", ordinal=True), + _mk_param(type="INT", value="2", ordinal=False, name="n"), + _mk_param(type="INT", value="3", ordinal=True), + ] + stmt = _RecordingStmt() + bind_tspark_params(stmt, params) + # Positional indices are 1 and 2 (not 1 and 3) — named binding + # doesn't claim an ordinal slot. + assert stmt.calls == [(1, "1", "INT"), (2, "3", "INT")] + assert stmt.named_calls == [("n", "2", "INT")] + + def test_bind_tspark_params_missing_type_defaults_to_string(): """Defensive: a TSparkParameter with no `type` shouldn't crash the mapper — fall back to STRING and let the kernel parse.""" @@ -304,17 +386,16 @@ def test_bind_tspark_params_arguments_field_rejected(): assert stmt.calls == [] -def test_bind_tspark_params_named_with_ordinal_none_rejected(): +def test_bind_tspark_params_named_with_ordinal_none_routes_named(): """Defensive: a TSparkParameter with a name and ordinal=None - (Thrift default) should also be rejected as a named binding — - not silently routed positionally with the name dropped.""" + (Thrift default) still routes via `bind_named_param` — `ordinal` + being `not True` is enough to flag the binding as named.""" from databricks.sql.backend.kernel.type_mapping import bind_tspark_params - from databricks.sql.exc import NotSupportedError from databricks.sql.thrift_api.TCLIService import ttypes p = ttypes.TSparkParameter(ordinal=None, name="my_param", type="INT") p.value = ttypes.TSparkParameterValue(stringValue="42") stmt = _RecordingStmt() - with pytest.raises(NotSupportedError, match="(?i)named"): - bind_tspark_params(stmt, [p]) + bind_tspark_params(stmt, [p]) + assert stmt.named_calls == [("my_param", "42", "INT")] assert stmt.calls == []