Commit 9b82476

Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models) (ggml-org#7461)
* convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel)
* llama : add inference support for LLM_ARCH_GPTNEOX
* llama : add model types for every Pythia variant and GPT-NeoX

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent a61a94e commit 9b82476

2 files changed (+273, -1)


convert-hf-to-gguf.py (+38)
@@ -673,6 +673,44 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name):
+            # Map bloom-style qkv_linear to gpt-style qkv_linear
+            # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
+            # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
+            qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed))
+            data_torch = torch.cat(
+                (
+                    qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                    qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                    qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+                ),
+                dim=0,
+            )
+            logger.info("re-format attention.linear_qkv.weight")
+        elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name):
+            qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
+            data_torch = torch.cat(
+                (
+                    qkv_bias[:, 0, :].reshape((n_embed,)),
+                    qkv_bias[:, 1, :].reshape((n_embed,)),
+                    qkv_bias[:, 2, :].reshape((n_embed,)),
+                ),
+                dim=0,
+            )
+            logger.info("re-format attention.linear_qkv.bias")
+
+        tensors.append((self.map_tensor_name(name), data_torch))
+
+        return tensors
+
 
 @Model.register("BloomForCausalLM")
 class BloomModel(Model):
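The modify_tensors() addition above converts the checkpoint's fused query_key_value layout, in which each head contributes its Q, K and V rows back-to-back (the bloom/gpt-neox convention), into the layout llama.cpp expects, where all Q rows come first, then all K rows, then all V rows. A minimal sketch of the same permutation on a tiny made-up tensor (sizes are illustrative, not from a real model):

    # Sketch of the bloom-style -> gpt-style QKV reordering done by modify_tensors() above,
    # using toy dimensions instead of real model weights.
    import torch

    n_head, head_dim = 2, 3                    # hypothetical sizes
    n_embed = n_head * head_dim

    # HF gpt_neox fuses Q/K/V as (3*n_embed, n_embed), rows ordered per head:
    # [q_head0, k_head0, v_head0, q_head1, k_head1, v_head1, ...]
    fused = torch.arange(3 * n_embed * n_embed, dtype=torch.float32).reshape(3 * n_embed, n_embed)

    qkv = fused.reshape(n_head, 3, head_dim, n_embed)
    gpt_style = torch.cat(
        (
            qkv[:, 0, :, :].reshape(-1, n_embed),   # all Q rows, every head
            qkv[:, 1, :, :].reshape(-1, n_embed),   # all K rows
            qkv[:, 2, :, :].reshape(-1, n_embed),   # all V rows
        ),
        dim=0,
    )

    assert gpt_style.shape == fused.shape
    # head 0's Q rows stay at the top ...
    assert torch.equal(gpt_style[:head_dim], fused[:head_dim])
    # ... while head 0's K rows move from right after them to the start of the K block
    assert torch.equal(gpt_style[n_embed:n_embed + head_dim], fused[head_dim:2 * head_dim])

The bias path in the diff is the same permutation applied to a (3*n_embed,) vector instead of a matrix.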

llama.cpp (+235, -1)
@@ -1692,17 +1692,24 @@ static llama_state g_state;
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
+    MODEL_14M,
     MODEL_17M,
     MODEL_22M,
     MODEL_33M,
+    MODEL_70M,
     MODEL_109M,
     MODEL_137M,
+    MODEL_160M,
     MODEL_335M,
+    MODEL_410M,
     MODEL_0_5B,
     MODEL_1B,
+    MODEL_1_4B,
     MODEL_2B,
+    MODEL_2_8B,
     MODEL_3B,
     MODEL_4B,
+    MODEL_6_9B,
     MODEL_7B,
     MODEL_8B,
     MODEL_12B,
@@ -1734,6 +1741,7 @@ static const size_t GiB = 1024*MiB;
 struct llama_hparams {
     bool vocab_only;
     bool rope_finetuned;
+    bool use_par_res;
 
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
@@ -3773,17 +3781,24 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
 
 static const char * llama_model_type_name(e_model type) {
     switch (type) {
+        case MODEL_14M:   return "14M";
         case MODEL_17M:   return "17M";
         case MODEL_22M:   return "22M";
         case MODEL_33M:   return "33M";
+        case MODEL_70M:   return "70M";
         case MODEL_109M:  return "109M";
         case MODEL_137M:  return "137M";
+        case MODEL_160M:  return "160M";
         case MODEL_335M:  return "335M";
+        case MODEL_410M:  return "410M";
         case MODEL_0_5B:  return "0.5B";
         case MODEL_1B:    return "1B";
+        case MODEL_1_4B:  return "1.4B";
         case MODEL_2B:    return "2B";
+        case MODEL_2_8B:  return "2.8B";
         case MODEL_3B:    return "3B";
         case MODEL_4B:    return "4B";
+        case MODEL_6_9B:  return "6.9B";
         case MODEL_7B:    return "7B";
         case MODEL_8B:    return "8B";
         case MODEL_12B:   return "12B";
@@ -4282,6 +4297,52 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL,   hparams.use_par_res);
+                switch (hparams.n_layer) {
+                    case 6:
+                        switch (hparams.n_ff) {
+                            case 512: model.type = e_model::MODEL_14M; break;
+                            case 2048: model.type = e_model::MODEL_70M; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 12:
+                        switch (hparams.n_ff) {
+                            case 3072: model.type = e_model::MODEL_160M; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 16:
+                        switch (hparams.n_ff) {
+                            case 8192: model.type = e_model::MODEL_1B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 24:
+                        switch (hparams.n_ff) {
+                            case 4096: model.type = e_model::MODEL_410M; break;
+                            case 8192: model.type = e_model::MODEL_1_4B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 32:
+                        switch (hparams.n_ff) {
+                            case 10240: model.type = e_model::MODEL_2_8B; break;
+                            case 16384: model.type = e_model::MODEL_6_9B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 36:
+                        switch (hparams.n_ff) {
+                            case 20480: model.type = e_model::MODEL_12B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 44:
+                        switch (hparams.n_ff) {
+                            case 24576: model.type = e_model::MODEL_20B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
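The nested switch above boils down to a lookup from (n_layer, n_ff) to a Pythia/GPT-NeoX size label, used when reporting the model type. A rough Python equivalent of the same table (the dict and helper below are illustrative, not part of llama.cpp):

    # Hypothetical mirror of the (n_layer, n_ff) -> model-type switch above; not llama.cpp API.
    GPTNEOX_SIZES = {
        (6, 512): "14M",
        (6, 2048): "70M",
        (12, 3072): "160M",
        (16, 8192): "1B",
        (24, 4096): "410M",
        (24, 8192): "1.4B",
        (32, 10240): "2.8B",
        (32, 16384): "6.9B",
        (36, 20480): "12B",
        (44, 24576): "20B",
    }

    def gptneox_model_type(n_layer: int, n_ff: int) -> str:
        """Return the size label used by llama_model_type_name(), or UNKNOWN."""
        return GPTNEOX_SIZES.get((n_layer, n_ff), "UNKNOWN")

    assert gptneox_model_type(32, 16384) == "6.9B"   # Pythia-6.9B
    assert gptneox_model_type(44, 24576) == "20B"    # GPT-NeoX-20B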

@@ -6033,6 +6094,41 @@ static bool llm_load_tensors(
                     layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
                 }
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                // output
+                {
+                    model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                    model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                    layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                    layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                    layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                    layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                    layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff});
+                }
+            } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
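One detail worth noting about the fused QKV tensor created above: GPT-NeoX has no grouped-query attention, so n_embd_gqa equals n_embd and the weight ends up with n_embd inputs and 3*n_embd outputs, matching the tensor written by the modify_tensors() conversion in convert-hf-to-gguf.py. A tiny sanity-check sketch (toy numbers, not llama.cpp code):

    # Toy check: without GQA the fused QKV weight is effectively (n_embd, 3*n_embd).
    n_embd = 6                       # hypothetical embedding width
    n_embd_gqa = n_embd              # GPT-NeoX: n_head_kv == n_head, so no reduction
    qkv_ne = (n_embd, n_embd + 2 * n_embd_gqa)   # the size pair passed to create_tensor above
    assert qkv_ne == (n_embd, 3 * n_embd)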
@@ -10560,6 +10656,140 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_gptneox() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // ffn
+            if (hparams.use_par_res) {
+                // attention and ffn are computed in parallel
+                // x = x + attn(ln1(x)) + ffn(ln2(x))
+
+                struct ggml_tensor * attn_out = cur;
+
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, inpL);
+                cb(cur, "ffn_out", il);
+
+                inpL = ggml_add(ctx0, cur, attn_out);
+                cb(inpL, "l_out", il);
+            } else {
+                // attention and ffn are computed sequentially
+                // x = x + attn(ln1(x))
+                // x = x + ffn(ln2(x))
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+                cb(ffn_inp, "ffn_inp", il);
+
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+                inpL = ggml_add(ctx0, cur, ffn_inp);
+                cb(inpL, "l_out", il);
+            }
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
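The use_par_res branch in build_gptneox() is the only structural fork: with parallel residual the attention and FFN branches both read the same layer input and their outputs are summed with it, otherwise the FFN reads the attention-augmented hidden state. A toy per-layer sketch of the two dataflows (plain Python with stand-in callables, not ggml ops):

    # Toy illustration of the two residual layouts built above; attn/ffn/ln1/ln2 are stand-ins.
    def gptneox_layer(x, attn, ffn, ln1, ln2, use_par_res: bool):
        if use_par_res:
            # parallel residual: x = x + attn(ln1(x)) + ffn(ln2(x))
            return x + attn(ln1(x)) + ffn(ln2(x))
        # sequential residual: x = x + attn(ln1(x)); x = x + ffn(ln2(x))
        h = x + attn(ln1(x))
        return h + ffn(ln2(h))

    # Example with scalar stand-ins:
    out = gptneox_layer(1.0, attn=lambda v: 2 * v, ffn=lambda v: v + 3,
                        ln1=lambda v: v, ln2=lambda v: v, use_par_res=True)
    assert out == 1.0 + 2.0 + 4.0   # 7.0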
@@ -10770,6 +11000,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                result = llm.build_gptneox();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -15762,7 +15996,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
         case LLM_ARCH_GPTJ:
-        case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_MPT:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_BLOOM:
@@ -15798,6 +16031,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_GPTNEOX:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
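The two hunks above move LLM_ARCH_GPTNEOX out of the "no RoPE" group and into the group that reports LLAMA_ROPE_TYPE_NEOX, i.e. rotary embeddings applied to pairs split at half the rotary width rather than to interleaved adjacent pairs. A simplified sketch of that rotation for a single vector (it assumes the rotary width covers the whole head and is not the ggml implementation):

    # NEOX-style rotation sketch: pair element i with element i + d/2.
    import math

    def rope_neox(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
        d = len(x)                    # rotary width, assumed even
        half = d // 2
        out = list(x)
        for i in range(half):
            theta = pos * base ** (-2 * i / d)
            c, s = math.cos(theta), math.sin(theta)
            out[i]        = x[i] * c - x[i + half] * s
            out[i + half] = x[i] * s + x[i + half] * c
        return out

    # position 0 is the identity rotation (cos 0 = 1, sin 0 = 0)
    assert rope_neox([1.0, 2.0, 3.0, 4.0], pos=0) == [1.0, 2.0, 3.0, 4.0]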
