Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - How to pass additional parameters to numba cfunc passed as LowLevelCallable to scipy.integrate.quad

The SciPy documentation describes using Numba cfuncs as the LowLevelCallable argument of scipy.integrate.quad. I need the same thing, but with an additional parameter.

I'm basically trying to do something like this:

import numpy as np
from numba import cfunc
import numba.types
voidp = numba.types.voidptr
def integrand(t, params):
    a = params[0] # this is additional parameter
    return np.exp(-t/a) / t**2
nb_integrand = cfunc(numba.float32(numba.float32, voidp))(integrand)

However, this does not work: params is typed as voidptr (void*), which Numba cannot index or convert to double. I get the following error message:

TypingError: Failed at nopython (nopython frontend)
Invalid usage of getitem with parameters (void*, int64)
 * parameterized

I didn't find any information on how to extract values from a void* in Numba. In C, it would be something like a = *((double*) params). Is it possible to do the same thing in Numba?
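Outside Numba, the cast from the question can be reproduced in plain Python with the standard ctypes module. The sketch below only illustrates what `*((double*) params)` means; it does not run inside a Numba cfunc, and the names `a_value` and `params` are illustrative.

```python
import ctypes

# A double living somewhere in memory
a_value = ctypes.c_double(2.0)

# Build an opaque void* pointing at it (what a C caller would hand over as params)
params = ctypes.cast(ctypes.pointer(a_value), ctypes.c_void_p)

# Equivalent of the C expression a = *((double*) params)
a = ctypes.cast(params, ctypes.POINTER(ctypes.c_double))[0]
print(a)  # 2.0
```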


Answer

1. Passing extra arguments through scipy.integrate.quad

The quad docs say:

If the user desires improved integration performance, then f may be a scipy.LowLevelCallable with one of the signatures:

double func(double x)

double func(double x, void *user_data)

double func(int n, double *xx)

double func(int n, double *xx, void *user_data)

The user_data is the data contained in the scipy.LowLevelCallable. In the call forms with xx, n is the length of the xx array which contains xx[0] == x and the rest of the items are numbers contained in the args argument of quad.

Therefore, to pass an extra argument to the integrand through quad, you are better off using the double func(int n, double *xx) signature.
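To see the xx layout described in the quote without involving Numba, here is a minimal sketch that uses a plain ctypes callback instead (so it calls back into Python on every evaluation and is slow). It assumes a Linux-like platform where ctypes.c_int maps to the C type name "int", and it passes a dummy second argument (5.0) only to show that every entry of args lands in xx after x. The names `PROTO`, `seen` and `py_integrand` are illustrative.

```python
import ctypes
import math

import numpy as np
import scipy.integrate as si
from scipy import LowLevelCallable

# ctypes prototype matching "double func(int n, double *xx)"
PROTO = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_int, ctypes.POINTER(ctypes.c_double))

seen = {}  # records n and the extra values from the first call, for inspection

def py_integrand(n, xx):
    # xx[0] is the integration variable; xx[1:n] are the values from quad's args
    if not seen:
        seen["n"] = n
        seen["extra"] = [xx[i] for i in range(1, n)]
    t, a = xx[0], xx[1]
    return math.exp(-t / a) / t**2

c_integrand = PROTO(py_integrand)  # keep a reference so it is not garbage-collected
result, abserr = si.quad(LowLevelCallable(c_integrand), 1, np.inf, args=(2.0, 5.0))
print(seen)
print(result)
```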

You can write a decorator to your integrand function to transform it to a LowLevelCallable like so:

import numpy as np
import scipy.integrate as si
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable


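def jit_integrand_function(integrand_function):
    # Compile the plain-Python integrand in nopython mode so the cfunc can call it
    jitted_function = numba.jit(integrand_function, nopython=True)

    # C callback with the signature double func(int n, double *xx):
    # xx[0] is the integration variable, xx[1] is the first value from quad's args
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])
    return LowLevelCallable(wrapped.ctypes)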

@jit_integrand_function
def integrand(t, *args):
    a = args[0]
    return np.exp(-t/a) / t**2

def do_integrate(func, a):
    """
    Integrate the given function from 1.0 to +inf with additional argument a.
    """
    return si.quad(func, 1, np.inf, args=(a,))

print(do_integrate(integrand, 2.))
>>>(0.326643862324553, 1.936891932288535e-10)

Or if you don't want the decorator, create the LowLevelCallable manually and pass it to quad.

2. Wrapping the integrand function

I am not sure whether the following meets your requirements, but you could also wrap your integrand function in a closure that captures the parameter, which gives the same result:

import numpy as np
from numba import cfunc
import numba.types

def get_integrand(*args):
    a = args[0]
    def integrand(t):
        return np.exp(-t/a) / t**2
    return integrand

nb_integrand = cfunc(numba.float64(numba.float64))(get_integrand(2.))

import scipy.integrate as si
def do_integrate(func):
    """
    Integrate the given function from 1.0 to +inf.
    """
    return si.quad(func, 1, np.inf)

print(do_integrate(get_integrand(2)))
>>>(0.326643862324553, 1.936891932288535e-10)
print(do_integrate(nb_integrand.ctypes))
>>>(0.326643862324553, 1.936891932288535e-10)

3. Casting from voidptr to a python type

I don't think this is possible yet. Judging from this discussion from 2016, voidptr seems to exist only to pass a context through to a C callback.

The void * pointer case would be for APIs where foreign C code does not ever try to dereference the pointer, but simply passes it back to the callback as way for the callback to retain state between calls. I don't think it is particularly important at the moment, but I wanted to raise the issue.

And trying the following:

numba.types.RawPointer('p').can_convert_to(
    numba.typing.context.Context(), CPointer(numba.types.Any))
>>>None

doesn't seem encouraging either!
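For comparison, the user_data route itself does work when the callback is written with ctypes rather than Numba, because ctypes can cast the address back to a double pointer. The sketch below uses the double func(double x, void *user_data) signature from the quad docs; it calls back into Python on every evaluation, so it gives up the speed that motivated using Numba. The names `UserDataProto`, `a_value` and `py_integrand` are illustrative.

```python
import ctypes
import math

import numpy as np
import scipy.integrate as si
from scipy import LowLevelCallable

# ctypes prototype matching "double func(double x, void *user_data)"
UserDataProto = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_void_p)

def py_integrand(t, user_data):
    # ctypes hands the void* over as an integer address;
    # cast it back and read it, like a = *((double*) user_data) in C
    a = ctypes.cast(user_data, ctypes.POINTER(ctypes.c_double))[0]
    return math.exp(-t / a) / t**2

a_value = ctypes.c_double(2.0)  # must stay alive while quad runs
user_data = ctypes.cast(ctypes.pointer(a_value), ctypes.c_void_p)
c_integrand = UserDataProto(py_integrand)  # keep a reference to the callback

result, abserr = si.quad(LowLevelCallable(c_integrand, user_data), 1, np.inf)
print(result, abserr)
```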

