Thanks to visit codestin.com
Credit goes to github.com

Skip to content

当模型predict参数为ndarray的时候会将参数修改成list #100

@danerlt

Description

@danerlt

我有一个推理函数如下:

@log_execution_time
def batch_predict(src_data: pd.DataFrame) -> list:
    torque_angle_trace: ndarray = preprocess(src_data)
    start = time.time()
    predict_res = model.predict(torque_angle_trace)
    ene = time.time()
    logger.info(f"单个预测耗时:{ene - start}")
    res = predict_res.tolist()
    return res

其中model.predict接收的参数是一个ndarry
下面是使用ThreadedStreamer之后的代码:

stream = ThreadedStreamer(model.predict, batch_size=10, max_latency=0.1)


def batch_predict_stream(src_data: pd.DataFrame) -> list:
    start = time.time()
    torque_angle_trace: ndarray = preprocess(src_data)
    predict_res = stream.predict([torque_angle_trace])
    ene = time.time()
    logger.info(f"stream预测耗时:{ene - start}")
    return predict_res

在调用stream.predict的时候我已经将数据处理成ndarray传进去了。
然后在运行的时候提示TypeError: X is not of a supported input data type.X must be in a supported mtype format for Panel, found <class 'list'>Use datatypes.check_is_mtype to check conformance with specifications.

我查看源码之后发现问题在下图将队列中取到的数据放到了一个list中然后传递给predict函数

image

请问是我使用ThreadedStreamer方法不对,还是predict函数不支持ndarray的参数。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions