94 changes: 52 additions & 42 deletions src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -116,7 +116,8 @@ class SgDNNLSelfAttQKOp {
void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
const std::vector<NDArray>& outputs,
bool already_prepared);

void Backward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -163,10 +164,12 @@ static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
SgDNNLSelfAttQKOp& op = state_pointer.get_state<SgDNNLSelfAttQKOp>();
bool already_prepared = false;
if (!op.IsInitialized()) {
op.Initialize(ctx, inputs, req, outputs);
already_prepared = true;
}
op.Forward(ctx, inputs, req, outputs);
op.Forward(ctx, inputs, req, outputs, already_prepared);
}
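
The dispatch above is the heart of this change: on the first call, Initialize() builds the cached primitives and already binds the current data handles, so the Forward() call that follows can skip rebinding them. A minimal sketch of that control flow, assuming a hypothetical Op type in place of SgDNNLSelfAttQKOp and eliding the oneDNN details:

// Sketch of the lazy-init + already_prepared dispatch (hypothetical Op type).
class Op {
 public:
  bool IsInitialized() const { return initialized_; }

  void Initialize(/* ctx, inputs, req, outputs */) {
    // Build primitives and bind the current data handles once.
    initialized_ = true;
  }

  void Forward(/* ctx, inputs, req, outputs, */ bool already_prepared) {
    if (!already_prepared) {
      // Rebind data handles only when Initialize() did not just run
      // on the same inputs/outputs in this call.
    }
    // Execute the cached primitive(s).
  }

 private:
  bool initialized_ = false;
};

void OpForward(Op& op /*, ctx, inputs, req, outputs */) {
  bool already_prepared = false;
  if (!op.IsInitialized()) {
    op.Initialize();
    already_prepared = true;
  }
  op.Forward(already_prepared);
}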

static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
@@ -264,21 +267,23 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const size_t output_lin_dim = inputs[0].shape()[2];
const size_t embed_dim = output_lin_dim / QKV_NUM;

MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
DType* query_mem_ptr = inputs[0].data().dptr<DType>();
DType* key_mem_ptr = query_mem_ptr + embed_dim;
cached_query_mem_->set_data_handle(query_mem_ptr);
cached_key_mem_->set_data_handle(key_mem_ptr);
});

MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
});

const std::vector<NDArray>& outputs,
bool already_prepared) {
if (!already_prepared) {
const size_t output_lin_dim = inputs[0].shape()[2];
const size_t embed_dim = output_lin_dim / QKV_NUM;

MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
DType* query_mem_ptr = inputs[0].data().dptr<DType>();
DType* key_mem_ptr = query_mem_ptr + embed_dim;
cached_query_mem_->set_data_handle(query_mem_ptr);
cached_key_mem_->set_data_handle(key_mem_ptr);
});

MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
});
}
DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
DNNLStream::Get()->Submit();
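
The offset arithmetic above assumes the fused QKV projection stores query, key, and value contiguously along the last axis, each embed_dim wide (so QKV_NUM is presumably 3); the same layout explains the value_offset with the "multiply by 2" comment in DNNLSelfAttValAttOp::Forward further down. A hedged sketch of that split on a plain buffer, with hypothetical names:

// Sketch: locating query/key/value inside a fused QKV projection buffer.
// Assumes the last dimension is laid out as [query | key | value].
#include <cstddef>

template <typename DType>
struct QKVView {
  DType* query;
  DType* key;
  DType* value;
};

template <typename DType>
QKVView<DType> SplitQKV(DType* qkv_base, std::size_t output_lin_dim) {
  const std::size_t embed_dim = output_lin_dim / 3;  // QKV_NUM assumed to be 3
  return {qkv_base,                   // query starts at the base pointer
          qkv_base + embed_dim,       // key is one embed_dim further
          qkv_base + 2 * embed_dim};  // value skips queries and keys
}

Note that these are only base-pointer offsets: the cached memory descriptors created in Initialize() presumably carry the per-row strides, so shifting the data handle by embed_dim elements is enough to select the key (or value) sub-tensor.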

@@ -483,7 +488,8 @@ class DNNLSelfAttValAttOp {
void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
const std::vector<NDArray>& outputs,
bool already_prepared);

void Backward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -537,10 +543,12 @@ static void DNNLSelfAttValAttForward(const OpStatePtr& state_pointer,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
DNNLSelfAttValAttOp& op = state_pointer.get_state<DNNLSelfAttValAttOp>();
bool already_prepared = false;
if (!op.IsInitialized()) {
op.Initialize(ctx, inputs, req, outputs);
already_prepared = true;
}
op.Forward(ctx, inputs, req, outputs);
op.Forward(ctx, inputs, req, outputs, already_prepared);
}

void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
@@ -663,29 +671,31 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
// multiply by 2 as we need to skip queries and keys
const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;

auto att_buffer = inputs[0];
if (att_buffer.IsDNNLData())
att_buffer = att_buffer.Reorder2Default();

MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
DType* attention_ptr = att_buffer.data().dptr<DType>();
cached_att_mem_->set_data_handle(attention_ptr);
});

MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
DType* qkv_ptr = inputs[1].data().dptr<DType>();
DType* value_mem_ptr = qkv_ptr + value_offset;
cached_value_mem_->set_data_handle(value_mem_ptr);
});

MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
});

const std::vector<NDArray>& outputs,
bool already_prepared) {
if (!already_prepared) {
// multiply by 2 as we need to skip queries and keys
const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;

auto att_buffer = inputs[0];
if (att_buffer.IsDNNLData())
att_buffer = att_buffer.Reorder2Default();

MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
DType* attention_ptr = att_buffer.data().dptr<DType>();
cached_att_mem_->set_data_handle(attention_ptr);
});

MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
DType* qkv_ptr = inputs[1].data().dptr<DType>();
DType* value_mem_ptr = qkv_ptr + value_offset;
cached_value_mem_->set_data_handle(value_mem_ptr);
});

MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
});
}
DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
DNNLStream::Get()->RegisterPrimArgs(*reorder_, reorder_args);
DNNLStream::Get()->Submit();
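
Rebinding handles on cached memory objects and resubmitting cached primitives, rather than recreating them every iteration, is the standard oneDNN reuse pattern that DNNLStream wraps here. A standalone sketch of the same idea with the plain oneDNN C++ API (assuming the oneDNN 3.x matmul interface; names and shapes are illustrative only):

#include <vector>
#include <dnnl.hpp>

// Sketch: build a matmul primitive once, then only swap data handles
// between executions, mirroring the cached_*_mem_->set_data_handle()
// plus RegisterPrimArgs/Submit flow above.
void RunCachedMatmul(const std::vector<float*>& src_batches,
                     float* weights_data, float* dst_data) {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::desc src_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc wei_md({3, 4}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc dst_md({2, 4}, memory::data_type::f32, memory::format_tag::ab);

  // Created once and cached.
  matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
  matmul prim(pd);
  memory src_mem(src_md, eng, nullptr);      // handle bound later
  memory wei_mem(wei_md, eng, weights_data);
  memory dst_mem(dst_md, eng, dst_data);

  for (float* src : src_batches) {
    src_mem.set_data_handle(src);  // rebind input for this iteration
    prim.execute(strm, {{DNNL_ARG_SRC, src_mem},
                        {DNNL_ARG_WEIGHTS, wei_mem},
                        {DNNL_ARG_DST, dst_mem}});
    strm.wait();
  }
}

Before this change both Forward() implementations rebound the handles on every call; the already_prepared flag simply avoids doing that work twice on the call that also ran Initialize().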