ml-rust/cl/forward.cl

15 lines
547 B
OpenCL C

/*
 * forward: one work-item per output node computes a weighted sum of the
 * inputs plus a per-node bias, then applies tanh. This looks like the
 * forward pass of a fully-connected layer.
 *
 * For work-item current_node = get_global_id(0):
 *     value = add[current_node]
 *           + sum_{i = 0 .. input_width-1} input[i] * mul[i * param_width + current_node]
 *     output[current_node]   = value
 *     activate[current_node] = tanh(value)
 *
 * Parameters:
 *   activate    - written: tanh(value) at index current_node.
 *   output      - written: the pre-activation sum `value` at index current_node.
 *   input       - read: input_width values.
 *   mul         - read: weight matrix; the element for (input_node, current_node)
 *                 is at input_node * param_width + current_node.
 *   add         - read: one bias per output node, indexed by current_node.
 *   input_width - number of input values summed over.
 *   param_width - row stride of `mul`. Presumably equal to the number of output
 *                 nodes. NOTE(review): confirm against the host code.
 *
 * There is no bounds check on current_node. The kernel assumes the host
 * launches exactly one work-item per output node.
 * NOTE(review): confirm that the global size is not rounded up.
 */
__kernel void forward(__global float* activate, __global float* output, __global const float* input, __global const float* mul, __global const float* add, int input_width, int param_width)
{
    const int current_node = get_global_id(0);

    /* Start from the bias so it is applied exactly once. Adding it inside the
       loop would scale it by input_width. */
    float value = add[current_node];

    for (int input_node = 0; input_node < input_width; input_node++)
    {
        /* Distinct name: a local called `input` would shadow the buffer
           parameter inside its own initializer, making `input[...]` index a
           float. */
        const float in_value = input[input_node];
        const float weight = mul[input_node * param_width + current_node];
        value += in_value * weight;
    }

    output[current_node] = value;
    activate[current_node] = tanh(value);
}