Skip to content

Commit bdeedd7

Browse files
ntjohnson1claude
and committed
test: move pickle/multiprocessing helpers into a non-test module
Spawn workers re-import the module that defined a pickled function by its ``__module__`` attribute. Pytest's ``--import-mode=importlib`` (used in CI) can load test modules under synthetic names that the worker's normal import machinery cannot resolve, which manifests as ``Pool.map`` hanging forever waiting on a worker that died during unpickling. Move the worker-side helpers (``_apply_builtin_expr``, ``_apply_udf_expr``, ``_register_udf_on_global_ctx``, ``_build_add_ten_udf``) to a regular ``tests._pickle_multiprocessing_helpers`` module. The leading underscore keeps pytest from collecting it; spawn workers import it under its real dotted name in both parent and child. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 106ea3c commit bdeedd7

2 files changed

Lines changed: 97 additions & 55 deletions

File tree

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Helpers for :mod:`test_pickle_multiprocessing`.
19+
20+
Spawn workers re-import the module that defines a pickled function by the
21+
function's ``__module__`` attribute. Pytest's ``--import-mode=importlib``
22+
loads test modules under synthetic names that the worker cannot resolve via
23+
the normal import machinery, which can cause ``Pool.map`` to hang waiting
24+
for a worker that died during unpickling.
25+
26+
Keeping the helpers in this regular (non-test) module side-steps that: it
27+
is importable under its real dotted name (``tests._pickle_multiprocessing_helpers``)
28+
in both parent and worker, and the leading underscore keeps pytest from
29+
collecting it as a test module.
30+
"""
31+
32+
from __future__ import annotations
33+
34+
import pyarrow as pa
35+
import pyarrow.compute as pc
36+
from datafusion import SessionContext, udf
37+
38+
UDF_NAME = "mp_pickle_add_ten"
39+
40+
41+
def add_ten_impl(array: pa.Array) -> pa.Array:
    """Return *array* with the constant 10 added to every element.

    Runs via the Arrow compute kernel, so the input array is not mutated.
    """
    shifted = pc.add(array, 10)
    return shifted
43+
44+
45+
def build_add_ten_udf():
    """Construct the scalar "add ten" UDF used by the multiprocessing tests.

    The UDF maps one ``int64`` input column to an ``int64`` output and is
    named :data:`UDF_NAME` so workers can later resolve it by name.
    """
    int64 = pa.int64()
    return udf(add_ten_impl, [int64], int64, volatility="immutable", name=UDF_NAME)
53+
54+
55+
def register_udf_on_global_ctx() -> None:
    """Pool initializer: make the worker's global context aware of the UDF.

    ``Expr.__setstate__`` resolves UDF references by name against the
    *global* context, so registration must precede the unpickling of any
    task argument — hence it runs in the Pool's ``initializer`` rather
    than inside the task body.
    """
    worker_ctx = SessionContext()
    worker_ctx.register_udf(build_add_ten_udf())
    worker_ctx.set_as_global()
65+
66+
67+
def apply_builtin_expr(args: tuple) -> list:
    """Worker task: evaluate a pickled built-in *expr* over *values*.

    *args* is an ``(expr, values)`` pair. The values become a single-column
    int64 batch named ``a``; the selected result is returned as a plain
    Python list.
    """
    expr, values = args
    ctx = SessionContext()
    column = pa.array(values, type=pa.int64())
    batch = pa.RecordBatch.from_arrays([column], names=["a"])
    frame = ctx.create_dataframe([[batch]], name="t")
    collected = frame.select(expr.alias("out")).collect()
    return collected[0].column(0).to_pylist()
73+
74+
75+
def apply_udf_expr(args: tuple) -> list:
    """Worker task: evaluate a pickled UDF-backed *expr* over *values*.

    Deliberately reuses the worker's *global* context — the one the Pool
    initializer registered the UDF on — so the function is visible during
    execution as well as during argument unpickling.
    """
    expr, values = args
    ctx = SessionContext.global_ctx()
    column = pa.array(values, type=pa.int64())
    batch = pa.RecordBatch.from_arrays([column], names=["a"])
    frame = ctx.create_dataframe([[batch]], name="t_udf")
    collected = frame.select(expr.alias("out")).collect()
    return collected[0].column(0).to_pylist()

python/tests/test_pickle_multiprocessing.py

Lines changed: 15 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -32,64 +32,24 @@
3232
2. A custom Python UDF — the worker must register the UDF on its global
3333
context *before* unpickling, since ``Expr.__setstate__`` resolves
3434
function references by name against the global context.
35+
36+
Worker-side helpers live in :mod:`tests._pickle_multiprocessing_helpers`
37+
rather than in this test module, so spawn workers can resolve them by their
38+
real dotted name regardless of how pytest imported the test module.
3539
"""
3640

3741
from __future__ import annotations
3842

3943
import multiprocessing as mp
4044

41-
import pyarrow as pa
42-
import pyarrow.compute as pc
43-
from datafusion import SessionContext, col, lit, udf
44-
45-
# Module-scope helpers — must be importable by name so the `spawn` workers
46-
# can resolve them after re-importing this module.
47-
48-
_UDF_NAME = "mp_pickle_add_ten"
49-
50-
51-
def _add_ten_impl(array: pa.Array) -> pa.Array:
52-
return pc.add(array, 10)
53-
54-
55-
def _build_add_ten_udf():
56-
return udf(
57-
_add_ten_impl,
58-
[pa.int64()],
59-
pa.int64(),
60-
volatility="immutable",
61-
name=_UDF_NAME,
62-
)
63-
64-
65-
def _register_udf_on_global_ctx() -> None:
66-
"""Pool initializer: install a global ctx in the worker that knows the UDF.
67-
68-
``Expr.__setstate__`` resolves UDF references by name against the
69-
*global* context, so the registration must happen before any task arg is
70-
unpickled — i.e. in the Pool's ``initializer``, not in the task body.
71-
"""
72-
ctx = SessionContext()
73-
ctx.register_udf(_build_add_ten_udf())
74-
ctx.set_as_global()
75-
76-
77-
def _apply_builtin_expr(args: tuple) -> list:
78-
expr, values = args
79-
ctx = SessionContext()
80-
batch = pa.RecordBatch.from_arrays([pa.array(values, type=pa.int64())], names=["a"])
81-
df = ctx.create_dataframe([[batch]], name="t")
82-
return df.select(expr.alias("out")).collect()[0].column(0).to_pylist()
83-
45+
from datafusion import col, lit
8446

85-
def _apply_udf_expr(args: tuple) -> list:
86-
expr, values = args
87-
# Reuse the worker's global ctx so the UDF registered by the initializer
88-
# is visible during execution as well as during arg unpickling.
89-
ctx = SessionContext.global_ctx()
90-
batch = pa.RecordBatch.from_arrays([pa.array(values, type=pa.int64())], names=["a"])
91-
df = ctx.create_dataframe([[batch]], name="t_udf")
92-
return df.select(expr.alias("out")).collect()[0].column(0).to_pylist()
47+
from tests._pickle_multiprocessing_helpers import (
48+
apply_builtin_expr,
49+
apply_udf_expr,
50+
build_add_ten_udf,
51+
register_udf_on_global_ctx,
52+
)
9353

9454

9555
def test_builtin_expr_through_multiprocessing_pool() -> None:
@@ -99,7 +59,7 @@ def test_builtin_expr_through_multiprocessing_pool() -> None:
9959
chunks = [[1, 2, 3], [10, 20, 30]]
10060

10161
with spawn_ctx.Pool(processes=2) as pool:
102-
results = pool.map(_apply_builtin_expr, [(expr, c) for c in chunks])
62+
results = pool.map(apply_builtin_expr, [(expr, c) for c in chunks])
10363

10464
assert results == [[3, 5, 7], [21, 41, 61]]
10565

@@ -108,11 +68,11 @@ def test_udf_expr_through_multiprocessing_pool() -> None:
10868
"""A UDF-backed ``Expr`` survives ``Pool.map`` when the worker registers
10969
the UDF on its global context via the Pool initializer."""
11070
spawn_ctx = mp.get_context("spawn")
111-
add_ten = _build_add_ten_udf()
71+
add_ten = build_add_ten_udf()
11272
expr = add_ten(col("a"))
11373
chunks = [[1, 2, 3], [10, 20, 30]]
11474

115-
with spawn_ctx.Pool(processes=2, initializer=_register_udf_on_global_ctx) as pool:
116-
results = pool.map(_apply_udf_expr, [(expr, c) for c in chunks])
75+
with spawn_ctx.Pool(processes=2, initializer=register_udf_on_global_ctx) as pool:
76+
results = pool.map(apply_udf_expr, [(expr, c) for c in chunks])
11777

11878
assert results == [[11, 12, 13], [20, 30, 40]]

0 commit comments

Comments
 (0)