#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined function related classes and functions
"""
import functools
import inspect
import sys
import warnings
from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union
from pyspark.util import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.types import (
    DataType,
    StringType,
    StructType,
    _parse_datatype_string,
)
from pyspark.sql.utils import get_active_spark_context
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError
if TYPE_CHECKING:
    from py4j.java_gateway import JavaObject
    from pyspark.core.context import SparkContext
    from pyspark.sql._typing import DataTypeOrString, ColumnOrName, UserDefinedFunctionLike
    from pyspark.sql.session import SparkSession
# Public API of this module: only UDFRegistration is exported by ``from pyspark.sql.udf import *``.
__all__ = ["UDFRegistration"]
def _wrap_function(
    sc: "SparkContext", func: Callable[..., Any], returnType: Optional[DataType] = None
) -> "JavaObject":
    """
    Pickle ``func`` and wrap it in a JVM ``SimplePythonFunction``.

    When ``returnType`` is given, the pickled payload is the tuple ``(func, returnType)``;
    otherwise it is ``func`` alone. The payload is prepared with
    ``_prepare_for_python_RDD``, and its broadcast variables, environment and includes are
    forwarded to the JVM constructor together with the context's Python executable,
    Python version and Java accumulator.
    """
    from pyspark.core.rdd import _prepare_for_python_RDD

    # Bundle the return type into the payload only when the caller supplied one.
    command: Any = func if returnType is None else (func, returnType)
    pickled, broadcasts, environment, include_list = _prepare_for_python_RDD(sc, command)

    jvm = sc._jvm
    assert jvm is not None
    return jvm.SimplePythonFunction(
        bytearray(pickled),
        environment,
        include_list,
        sc.pythonExec,
        sc.pythonVer,
        broadcasts,
        sc._javaAccumulator,
    )
def _create_udf(
    f: Callable[..., Any],
    returnType: "DataTypeOrString",
    evalType: int,
    name: Optional[str] = None,
    deterministic: bool = True,
) -> "UserDefinedFunctionLike":
    """
    Build a :class:`UserDefinedFunction` for ``f`` and return its wrapped, user-facing form.

    ``evalType`` is passed through unchanged, so this serves both plain Python UDFs and the
    Arrow/pandas variants chosen by the callers. When ``name`` is ``None``, the
    UserDefinedFunction derives its name from ``f``.
    """
    return UserDefinedFunction(
        f,
        returnType=returnType,
        name=name,
        evalType=evalType,
        deterministic=deterministic,
    )._wrapped()
def _create_py_udf(
    f: Callable[..., Any],
    returnType: "DataTypeOrString",
    useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
    """Create a regular/Arrow-optimized Python UDF.

    ``useArrow`` explicitly selects Arrow optimization. When it is ``None``, the
    ``spark.sql.execution.pythonUDF.arrow.enabled`` setting of the already-instantiated
    SparkSession decides, and no instantiated session means no Arrow. If Arrow is
    selected but the pandas/PyArrow requirement checks raise ``ImportError``, a
    ``RuntimeWarning`` is emitted and a non-Arrow UDF is created instead.
    """
    # The following table shows the results when the type coercion in Arrow is needed, that is,
    # when the user-specified return type(SQL Type) of the UDF and the actual instance(Python
    # Value(Type)) that the UDF returns are different.
    # Arrow and Pickle have different type coercion rules, so a UDF might have a different result
    # with/without Arrow optimization. That's the main reason the Arrow optimization for Python
    # UDFs is disabled by default.
    # +-----------------------------+--------------+----------+------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+----------------------------+------------+--------------+  # noqa
    # |SQL Type \ Python Value(Type)|None(NoneType)|True(bool)|1(int)|         a(str)|    1970-01-01(date)|1970-01-01 00:00:00(datetime)|1.0(float)|array('i', [1])(array)|[1](list)|         (1,)(tuple)|bytearray(b'ABC')(bytearray)|  1(Decimal)|{'a': 1}(dict)|  # noqa
    # +-----------------------------+--------------+----------+------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+----------------------------+------------+--------------+  # noqa
    # |                      boolean|          None|      True|  None|           None|                None|                         None|      None|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                      tinyint|          None|      None|     1|           None|                None|                         None|      None|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                     smallint|          None|      None|     1|           None|                None|                         None|      None|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                          int|          None|      None|     1|           None|                None|                         None|      None|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                       bigint|          None|      None|     1|           None|                None|                         None|      None|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                       string|          None|    'true'|   '1'|            'a'|'java.util.Gregor...|         'java.util.Gregor...|     '1.0'|         '[I@120d813a'|    '[1]'|'[Ljava.lang.Obje...|               '[B@48571878'|         '1'|       '{a=1}'|  # noqa
    # |                         date|          None|         X|     X|              X|datetime.date(197...|         datetime.date(197...|         X|                     X|        X|                   X|                           X|           X|             X|  # noqa
    # |                    timestamp|          None|         X|     X|              X|                   X|         datetime.datetime...|         X|                     X|        X|                   X|                           X|           X|             X|  # noqa
    # |                        float|          None|      None|  None|           None|                None|                         None|       1.0|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                       double|          None|      None|  None|           None|                None|                         None|       1.0|                  None|     None|                None|                        None|        None|          None|  # noqa
    # |                       binary|          None|      None|  None|bytearray(b'a')|                None|                         None|      None|                  None|     None|                None|           bytearray(b'ABC')|        None|          None|  # noqa
    # |                decimal(10,0)|          None|      None|  None|           None|                None|                         None|      None|                  None|     None|                None|                        None|Decimal('1')|          None|  # noqa
    # +-----------------------------+--------------+----------+------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+----------------------------+------------+--------------+  # noqa
    # Note: Python 3.9.15, Pandas 1.5.2 and PyArrow 10.0.1 are used.
    # Note: The values of 'SQL Type' are DDL formatted strings, which can be used as `returnType`s.
    # Note: The values inside the table are generated by `repr`. 'X' means it throws an exception
    # during the conversion.
    if useArrow is not None:
        is_arrow_enabled = useArrow
    else:
        from pyspark.sql import SparkSession

        # Consult the session config only if a session already exists; never create one here.
        active_session = SparkSession._instantiatedSession
        is_arrow_enabled = (
            active_session is not None
            and active_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"
        )

    if is_arrow_enabled:
        try:
            require_minimum_pandas_version()
            require_minimum_pyarrow_version()
        except ImportError:
            # Degrade gracefully to the pickle-based UDF rather than failing UDF creation.
            is_arrow_enabled = False
            warnings.warn(
                "Arrow optimization failed to enable because PyArrow or Pandas is not installed. "
                "Falling back to a non-Arrow-optimized UDF.",
                RuntimeWarning,
            )

    eval_type: int = (
        PythonEvalType.SQL_ARROW_BATCHED_UDF
        if is_arrow_enabled
        else PythonEvalType.SQL_BATCHED_UDF
    )
    return _create_udf(f, returnType, eval_type)
class UserDefinedFunction:
    """
    User defined function in Python

    .. versionadded:: 1.3

    Notes
    -----
    The constructor of this class is not supposed to be directly called.
    Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
    to create this instance.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        returnType: "DataTypeOrString" = StringType(),
        name: Optional[str] = None,
        evalType: int = PythonEvalType.SQL_BATCHED_UDF,
        deterministic: bool = True,
    ):
        # Only the argument *types* are validated here. A DDL-string return type is parsed
        # lazily by the ``returnType`` property, because parsing goes through the JVM.
        if not callable(func):
            raise PySparkTypeError(
                errorClass="NOT_CALLABLE",
                messageParameters={"arg_name": "func", "arg_type": type(func).__name__},
            )
        if not isinstance(returnType, (DataType, str)):
            raise PySparkTypeError(
                errorClass="NOT_DATATYPE_OR_STR",
                messageParameters={
                    "arg_name": "returnType",
                    "arg_type": type(returnType).__name__,
                },
            )
        if not isinstance(evalType, int):
            raise PySparkTypeError(
                errorClass="NOT_INT",
                messageParameters={"arg_name": "evalType", "arg_type": type(evalType).__name__},
            )
        self.func = func
        self._returnType = returnType
        # Resolved DataType for ``_returnType``; populated on first access of ``returnType``.
        self._returnType_placeholder: Optional[DataType] = None
        # Stores UserDefinedPythonFunctions jobj, once initialized
        self._judf_placeholder = None
        # Callables without ``__name__`` (e.g. functools.partial) fall back to their class name.
        self._name = name or (
            func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
        )
        self.evalType = evalType
        self.deterministic = deterministic

    @staticmethod
    def _check_return_type(returnType: DataType, evalType: int) -> None:
        """
        Validate that ``returnType`` is usable with the given ``evalType``.

        For the Arrow- and pandas-based eval types, the type must be convertible by
        ``to_arrow_type``. The grouped-map, map-iterator and cogrouped-map eval types also
        require a :class:`StructType`. Grouped-aggregate pandas UDFs reject
        :class:`StructType`. Any other eval type (e.g. ``SQL_BATCHED_UDF``) is not checked.

        Raises
        ------
        PySparkNotImplementedError
            If the type cannot be converted to an Arrow type for this eval type, or if it is
            a StructType for a grouped-aggregate pandas UDF.
        PySparkTypeError
            If a StructType is required but ``returnType`` is not one.
        """
        if evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
            try:
                to_arrow_type(returnType)
            except TypeError:
                raise PySparkNotImplementedError(
                    errorClass="NOT_IMPLEMENTED",
                    messageParameters={
                        "feature": f"Invalid return type with Arrow-optimized Python UDF: "
                        f"{returnType}"
                    },
                )
        elif (
            evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
            or evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
        ):
            try:
                to_arrow_type(returnType)
            except TypeError:
                raise PySparkNotImplementedError(
                    errorClass="NOT_IMPLEMENTED",
                    messageParameters={
                        "feature": f"Invalid return type with scalar Pandas UDFs: " f"{returnType}"
                    },
                )
        elif (
            evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
            or evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE
        ):
            if isinstance(returnType, StructType):
                try:
                    to_arrow_type(returnType)
                except TypeError:
                    raise PySparkNotImplementedError(
                        errorClass="NOT_IMPLEMENTED",
                        messageParameters={
                            "feature": f"Invalid return type with grouped map Pandas UDFs or "
                            f"at groupby.applyInPandas(WithState): {returnType}"
                        },
                    )
            else:
                raise PySparkTypeError(
                    errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
                    messageParameters={
                        "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
                        "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
                        "return_type": str(returnType),
                    },
                )
        elif (
            evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
            or evalType == PythonEvalType.SQL_MAP_ARROW_ITER_UDF
        ):
            if isinstance(returnType, StructType):
                try:
                    to_arrow_type(returnType)
                except TypeError:
                    raise PySparkNotImplementedError(
                        errorClass="NOT_IMPLEMENTED",
                        messageParameters={
                            "feature": f"Invalid return type in mapInPandas: " f"{returnType}"
                        },
                    )
            else:
                raise PySparkTypeError(
                    errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
                    messageParameters={
                        "eval_type": "SQL_MAP_PANDAS_ITER_UDF or SQL_MAP_ARROW_ITER_UDF",
                        "return_type": str(returnType),
                    },
                )
        elif evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
            if isinstance(returnType, StructType):
                try:
                    to_arrow_type(returnType)
                except TypeError:
                    raise PySparkNotImplementedError(
                        errorClass="NOT_IMPLEMENTED",
                        messageParameters={
                            "feature": "Invalid return type with grouped map Arrow UDFs or "
                            f"at groupby.applyInArrow: {returnType}"
                        },
                    )
            else:
                raise PySparkTypeError(
                    errorClass="INVALID_RETURN_TYPE_FOR_ARROW_UDF",
                    messageParameters={
                        "eval_type": "SQL_GROUPED_MAP_ARROW_UDF",
                        "return_type": str(returnType),
                    },
                )
        elif evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            if isinstance(returnType, StructType):
                try:
                    to_arrow_type(returnType)
                except TypeError:
                    raise PySparkNotImplementedError(
                        errorClass="NOT_IMPLEMENTED",
                        messageParameters={
                            "feature": f"Invalid return type in cogroup.applyInPandas: "
                            f"{returnType}"
                        },
                    )
            else:
                raise PySparkTypeError(
                    errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
                    messageParameters={
                        "eval_type": "SQL_COGROUPED_MAP_PANDAS_UDF",
                        "return_type": str(returnType),
                    },
                )
        elif evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
            if isinstance(returnType, StructType):
                try:
                    to_arrow_type(returnType)
                except TypeError:
                    raise PySparkNotImplementedError(
                        errorClass="NOT_IMPLEMENTED",
                        messageParameters={
                            "feature": "Invalid return type in cogroup.applyInArrow: "
                            f"{returnType}"
                        },
                    )
            else:
                raise PySparkTypeError(
                    errorClass="INVALID_RETURN_TYPE_FOR_ARROW_UDF",
                    messageParameters={
                        "eval_type": "SQL_COGROUPED_MAP_ARROW_UDF",
                        "return_type": str(returnType),
                    },
                )
        elif evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
            # StructType is not yet allowed as a return type, explicitly check here to fail fast
            if isinstance(returnType, StructType):
                raise PySparkNotImplementedError(
                    errorClass="NOT_IMPLEMENTED",
                    messageParameters={
                        "feature": f"Invalid return type with grouped aggregate Pandas UDFs: "
                        f"{returnType}"
                    },
                )
            try:
                to_arrow_type(returnType)
            except TypeError:
                raise PySparkNotImplementedError(
                    errorClass="NOT_IMPLEMENTED",
                    messageParameters={
                        "feature": f"Invalid return type with grouped aggregate Pandas UDFs: "
                        f"{returnType}"
                    },
                )

    @property
    def returnType(self) -> DataType:
        """
        The resolved return :class:`DataType`, validated against ``evalType``.

        A DDL string is parsed on first access and cached. Validation runs on every access,
        so an invalid type keeps raising.
        """
        # Make sure this is called after SparkContext is initialized.
        # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
        # TODO: PythonEvalType.SQL_BATCHED_UDF
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, DataType):
                self._returnType_placeholder = self._returnType
            else:
                self._returnType_placeholder = _parse_datatype_string(self._returnType)
        UserDefinedFunction._check_return_type(self._returnType_placeholder, self.evalType)
        return self._returnType_placeholder

    @property
    def _judf(self) -> "JavaObject":
        """The JVM UserDefinedPythonFunction for ``self.func``, created on first access."""
        # It is possible that concurrent access, to newly created UDF,
        # will initialize multiple UserDefinedPythonFunctions.
        # This is unlikely, doesn't affect correctness,
        # and should have a minimal performance impact.
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf(self.func)
        return self._judf_placeholder

    def _create_judf(self, func: Callable[..., Any]) -> "JavaObject":
        """
        Create a JVM ``UserDefinedPythonFunction`` wrapping ``func``.

        This gets or creates the active SparkSession. The return type is resolved through
        ``self.returnType`` and converted to a JVM data type via its JSON form.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext
        wrapped_func = _wrap_function(sc, func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        assert sc._jvm is not None
        judf = getattr(sc._jvm, "org.apache.spark.sql.execution.python.UserDefinedPythonFunction")(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic
        )
        return judf

    def _apply_profiled(
        self,
        sc: "SparkContext",
        func: Callable[..., Any],
        jcols: list,
        profiler: Any,
    ) -> "JavaObject":
        """
        Build a JVM UDF expression for the profiling wrapper ``func`` over ``jcols``.

        ``profiler`` is registered with ``sc.profiler_collector`` under the expression's
        result id. Returns the JVM Python-UDF object created from that expression.
        """
        from pyspark.sql.classic.column import _to_seq

        judf = self._create_judf(func)
        jUDFExpr = judf.builderWithColumns(_to_seq(sc, jcols))
        jPythonUDF = judf.fromUDFExpr(jUDFExpr)
        # Key the profiler by this expression's result id.
        result_id = jUDFExpr.resultId().id()
        sc.profiler_collector.add_profiler(result_id, profiler)
        return jPythonUDF

    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
        """
        Apply this UDF to the given columns and return the resulting :class:`Column`.

        Keyword arguments are passed to the JVM as named-argument expressions. When
        ``spark.python.profile`` or ``spark.python.profile.memory`` is ``"true"``, the
        function is wrapped with the corresponding profiler. Iterator-based UDFs are not
        profiled; they only trigger a warning.

        Raises
        ------
        PySparkRuntimeError
            If both profiler configurations are enabled at the same time.
        """
        from pyspark.sql.classic.column import _to_java_column, _to_seq

        sc = get_active_spark_context()
        assert sc._jvm is not None
        jcols = [_to_java_column(arg) for arg in args] + [
            sc._jvm.PythonSQLUtils.namedArgumentExpression(key, _to_java_column(value))
            for key, value in kwargs.items()
        ]

        profiler_enabled = sc._conf.get("spark.python.profile", "false") == "true"
        memory_profiler_enabled = sc._conf.get("spark.python.profile.memory", "false") == "true"

        if profiler_enabled or memory_profiler_enabled:
            # Disable profiling Pandas UDFs with iterators as input/output.
            if self.evalType in [
                PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
            ]:
                warnings.warn(
                    "Profiling UDFs with iterators input/output is not supported.",
                    UserWarning,
                )
                judf = self._judf
                return Column(judf.apply(_to_seq(sc, jcols)))

            # Disallow enabling two profilers at the same time.
            if profiler_enabled and memory_profiler_enabled:
                # When both profilers are enabled, they interfere with each other,
                # that makes the result profile misleading.
                raise PySparkRuntimeError(
                    errorClass="CANNOT_SET_TOGETHER",
                    messageParameters={
                        "arg_list": "'spark.python.profile' and "
                        "'spark.python.profile.memory' configuration"
                    },
                )
            elif profiler_enabled:
                f = self.func
                profiler = sc.profiler_collector.new_udf_profiler(sc)

                @functools.wraps(f)
                def func(*args: Any, **kwargs: Any) -> Any:
                    assert profiler is not None
                    return profiler.profile(f, *args, **kwargs)

                # Expose the original function's signature explicitly on the wrapper.
                func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
                jPythonUDF = self._apply_profiled(sc, func, jcols, profiler)
            else:  # memory_profiler_enabled
                f = self.func
                memory_profiler = sc.profiler_collector.new_memory_profiler(sc)
                (sub_lines, start_line) = inspect.getsourcelines(f.__code__)

                @functools.wraps(f)
                def func(*args: Any, **kwargs: Any) -> Any:
                    assert memory_profiler is not None
                    return memory_profiler.profile(
                        sub_lines, start_line, f, *args, **kwargs  # type: ignore[arg-type]
                    )

                # Expose the original function's signature explicitly on the wrapper.
                func.__signature__ = inspect.signature(f)  # type: ignore[attr-defined]
                jPythonUDF = self._apply_profiled(sc, func, jcols, memory_profiler)
        else:
            judf = self._judf
            jPythonUDF = judf.apply(_to_seq(sc, jcols))
        return Column(jPythonUDF)

    # This function is for improving the online help system in the interactive interpreter.
    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
    # argument annotation. (See: SPARK-19161)
    def _wrapped(self) -> "UserDefinedFunctionLike":
        """
        Wrap this udf with a function and attach docstring from func
        """
        # It is possible for a callable instance without __name__ attribute or/and
        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
        # we should avoid wrapping the attributes from the wrapped function to the wrapper
        # function. So, we take out these attribute names from the default names to set and
        # then manually assign it after being wrapped.
        assignments = tuple(
            a for a in functools.WRAPPER_ASSIGNMENTS if a != "__name__" and a != "__module__"
        )

        @functools.wraps(self.func, assigned=assignments)
        def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
            return self(*args, **kwargs)

        wrapper.__name__ = self._name
        wrapper.__module__ = (
            self.func.__module__
            if hasattr(self.func, "__module__")
            else self.func.__class__.__module__
        )
        wrapper.func = self.func  # type: ignore[attr-defined]
        wrapper.returnType = self.returnType  # type: ignore[attr-defined]
        wrapper.evalType = self.evalType  # type: ignore[attr-defined]
        wrapper.deterministic = self.deterministic  # type: ignore[attr-defined]
        wrapper.asNondeterministic = functools.wraps(  # type: ignore[attr-defined]
            self.asNondeterministic
        )(lambda: self.asNondeterministic()._wrapped())
        wrapper._unwrapped = self  # type: ignore[attr-defined]
        return wrapper  # type: ignore[return-value]

    def asNondeterministic(self) -> "UserDefinedFunction":
        """
        Updates UserDefinedFunction to nondeterministic.

        .. versionadded:: 2.3
        """
        # Here, we explicitly clean the cache to create a JVM UDF instance
        # with 'deterministic' updated. See SPARK-23233.
        self._judf_placeholder = None
        self.deterministic = False
        return self
class UDFRegistration:
    """
    Wrapper for user-defined function registration. This instance can be accessed by
    :attr:`spark.udf` or :attr:`sqlContext.udf`.

    .. versionadded:: 1.3.1
    """

    def __init__(self, sparkSession: "SparkSession"):
        self.sparkSession = sparkSession

    def register(
        self,
        name: str,
        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
        returnType: Optional["DataTypeOrString"] = None,
    ) -> "UserDefinedFunctionLike":
        """Register a Python function (including lambda function) or a user-defined function
        as a SQL function.

        .. versionadded:: 1.3.1

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        name : str
            name of the user-defined function in SQL statements.
        f : function, :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
            a Python function, or a user-defined function. The user-defined function can
            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
            :meth:`pyspark.sql.functions.pandas_udf`.
        returnType : :class:`pyspark.sql.types.DataType` or str, optional
            the return type of the registered user-defined function. The value can
            be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
            `returnType` can be optionally specified when `f` is a Python function but not
            when `f` is a user-defined function. Please see the examples below.

        Returns
        -------
        function
            a user-defined function

        Notes
        -----
        To register a nondeterministic Python function, users need to first build
        a nondeterministic user-defined function for the Python function and then register it
        as a SQL function.

        Examples
        --------
        1. When `f` is a Python function:

            `returnType` defaults to string type and can be optionally specified. The produced
            object must match the specified type. In this case, this API works as if
            `register(name, f, returnType=StringType())`.

            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
            >>> spark.sql("SELECT stringLengthString('test')").collect()
            [Row(stringLengthString(test)='4')]

            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
            [Row(stringLengthString(text)='3')]

            >>> from pyspark.sql.types import IntegerType
            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
            >>> spark.sql("SELECT stringLengthInt('test')").collect()
            [Row(stringLengthInt(test)=4)]

        2. When `f` is a user-defined function (from Spark 2.3.0):

            Spark uses the return type of the given user-defined function as the return type of
            the registered user-defined function. `returnType` should not be specified.
            In this case, this API works as if `register(name, f)`.

            >>> from pyspark.sql.types import IntegerType
            >>> from pyspark.sql.functions import udf
            >>> slen = udf(lambda s: len(s), IntegerType())
            >>> _ = spark.udf.register("slen", slen)
            >>> spark.sql("SELECT slen('test')").collect()
            [Row(slen(test)=4)]

            >>> import random
            >>> from pyspark.sql.functions import udf
            >>> from pyspark.sql.types import IntegerType
            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
            [Row(random_udf()=82)]

            >>> import pandas as pd  # doctest: +SKIP
            >>> from pyspark.sql.functions import pandas_udf
            >>> @pandas_udf("integer")  # doctest: +SKIP
            ... def add_one(s: pd.Series) -> pd.Series:
            ...     return s + 1
            ...
            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

            >>> @pandas_udf("integer")  # doctest: +SKIP
            ... def sum_udf(v: pd.Series) -> int:
            ...     return v.sum()
            ...
            >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
            >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
            >>> spark.sql(q).collect()  # doctest: +SKIP
            [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]
        """
        # This is to check whether the input function is from a user-defined function or
        # Python function.
        if hasattr(f, "asNondeterministic"):
            if returnType is not None:
                raise PySparkTypeError(
                    errorClass="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
                    messageParameters={"arg_name": "f", "return_type": str(returnType)},
                )
            f = cast("UserDefinedFunctionLike", f)
            if f.evalType not in [
                PythonEvalType.SQL_BATCHED_UDF,
                PythonEvalType.SQL_ARROW_BATCHED_UDF,
                PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
            ]:
                raise PySparkTypeError(
                    errorClass="INVALID_UDF_EVAL_TYPE",
                    messageParameters={
                        "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
                        "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
                        "SQL_GROUPED_AGG_PANDAS_UDF"
                    },
                )
            # Re-create the UDF under the SQL name, keeping its function, return type,
            # eval type and determinism.
            source_udf = _create_udf(
                f.func,
                returnType=f.returnType,
                name=name,
                evalType=f.evalType,
                deterministic=f.deterministic,
            )
            register_udf = source_udf._unwrapped  # type: ignore[attr-defined]
            # Note: this branch returns the UserDefinedFunction object itself, whereas the
            # Python-function branch below returns the wrapper function.
            return_udf = register_udf
        else:
            if returnType is None:
                returnType = StringType()
            return_udf = _create_udf(
                f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
            )
            register_udf = return_udf._unwrapped
        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf

    def registerJavaFunction(
        self,
        name: str,
        javaClassName: str,
        returnType: Optional["DataTypeOrString"] = None,
    ) -> None:
        """Register a Java user-defined function as a SQL function.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not specified we would infer it via reflection.

        .. versionadded:: 2.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        name : str
            name of the user-defined function
        javaClassName : str
            fully qualified name of java class
        returnType : :class:`pyspark.sql.types.DataType` or str, optional
            the return type of the registered Java function. The value can be either
            a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        Examples
        --------
        >>> from pyspark.sql.types import IntegerType
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
        ... # doctest: +SKIP
        >>> spark.sql("SELECT javaStringLength('test')").collect()  # doctest: +SKIP
        [Row(javaStringLength(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
        ... # doctest: +SKIP
        >>> spark.sql("SELECT javaStringLength2('test')").collect()  # doctest: +SKIP
        [Row(javaStringLength2(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
        ... # doctest: +SKIP
        >>> spark.sql("SELECT javaStringLength3('test')").collect()  # doctest: +SKIP
        [Row(javaStringLength3(test)=4)]
        """
        # ``None`` is passed through to the JVM when no return type is given.
        jdt = None
        if returnType is not None:
            if not isinstance(returnType, DataType):
                returnType = _parse_datatype_string(returnType)
            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

    def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
        """Register a Java user-defined aggregate function as a SQL function.

        .. versionadded:: 2.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        name : str
            name of the user-defined aggregate function
        javaClassName : str
            fully qualified name of java class

        Examples
        --------
        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
        ... # doctest: +SKIP
        >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
        >>> df.createOrReplaceTempView("df")
        >>> q = "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
        >>> spark.sql(q).collect()  # doctest: +SKIP
        [Row(name='b', avg=102.0), Row(name='a', avg=102.0)]
        """
        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
def _test() -> None:
    """
    Run this module's doctests against a local ``local[4]`` SparkSession.

    The session is exposed to the doctests as ``spark`` and is always stopped afterwards.
    The process exits with status -1 if any doctest fails.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udf

    globs = pyspark.sql.udf.__dict__.copy()
    spark = SparkSession.builder.master("local[4]").appName("sql.udf tests").getOrCreate()
    globs["spark"] = spark
    try:
        failure_count, _ = doctest.testmod(
            pyspark.sql.udf,
            globs=globs,
            optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
        )
    finally:
        # Stop the session even if doctest itself raises, so the session is not leaked.
        spark.stop()
    if failure_count:
        sys.exit(-1)
# Executing this module directly runs its doctests via ``_test``.
if __name__ == "__main__":
    _test()