From d2d9b1915828256bd468328bd0c55d24800651fb Mon Sep 17 00:00:00 2001 From: Dongkyu Kim Date: Mon, 18 Dec 2023 11:17:53 +0900 Subject: [PATCH 1/5] Add eALS (#77) --- buffalo/__init__.py | 3 +- buffalo/algo/_eals.pyx | 63 ++++++ buffalo/algo/eals.py | 192 ++++++++++++++++ buffalo/algo/options.py | 37 ++++ include/buffalo/algo_impl/eals/eals.hpp | 84 +++++++ include/buffalo/misc/blas.hpp | 79 +++++++ lib/algo_impl/eals/eals.cc | 281 ++++++++++++++++++++++++ setup.py | 7 + tests/algo/test_eals.py | 85 +++++++ 9 files changed, 830 insertions(+), 1 deletion(-) create mode 100644 buffalo/algo/_eals.pyx create mode 100644 buffalo/algo/eals.py create mode 100644 include/buffalo/algo_impl/eals/eals.hpp create mode 100644 include/buffalo/misc/blas.hpp create mode 100644 lib/algo_impl/eals/eals.cc create mode 100644 tests/algo/test_eals.py diff --git a/buffalo/__init__.py b/buffalo/__init__.py index 2d82821..e89530e 100644 --- a/buffalo/__init__.py +++ b/buffalo/__init__.py @@ -3,10 +3,11 @@ __version__ = importlib.metadata.version('buffalo') from buffalo.algo.als import ALS, inited_CUALS +from buffalo.algo.eals import EALS from buffalo.algo.base import Algo from buffalo.algo.bpr import BPRMF, inited_CUBPR from buffalo.algo.cfr import CFR -from buffalo.algo.options import (AlgoOption, ALSOption, BPRMFOption, +from buffalo.algo.options import (AlgoOption, ALSOption, EALSOption, BPRMFOption, CFROption, PLSIOption, W2VOption, WARPOption) from buffalo.algo.plsi import PLSI from buffalo.algo.w2v import W2V diff --git a/buffalo/algo/_eals.pyx b/buffalo/algo/_eals.pyx new file mode 100644 index 0000000..c7448d4 --- /dev/null +++ b/buffalo/algo/_eals.pyx @@ -0,0 +1,63 @@ +# cython: language_level=3, boundscheck=False, wraparound=False +# distutils: language=c++ + +cimport numpy as np +from libc.stdint cimport int32_t, int64_t +from libcpp.pair cimport pair +from libcpp.string cimport string + +np.import_array() + + +cdef extern from "buffalo/algo_impl/eals/eals.hpp" namespace "eals": + cdef cppclass CEALS: + bint init(string) nogil except + + void initialize_model(float*, float*, float*, int, int) nogil except + + void precompute_cache(int, int64_t*, int32_t*, int) nogil except + + bint update(int64_t*, int32_t*, float*, int) nogil except + + pair[float, float] estimate_loss(int, + int64_t*, + int32_t*, + float*, + int) nogil except + + +cdef class CyEALS: + """CEALS object holder""" + cdef CEALS* _obj # C-EALS object + + def __cinit__(self): + self._obj = new CEALS() + + def __dealloc__(self): + del self._obj + + def init(self, option_path): + return self._obj.init(option_path) + + def initialize_model(self, + np.ndarray[np.float32_t, ndim=2] P, + np.ndarray[np.float32_t, ndim=2] Q, + np.ndarray[np.float32_t, ndim=1] C): + self._obj.initialize_model(&P[0, 0], &Q[0, 0], &C[0], P.shape[0], Q.shape[0]) + + def precompute_cache(self, + int nnz, + np.ndarray[np.int64_t, ndim=1] indptr, + np.ndarray[np.int32_t, ndim=1] keys, + int axis): + self._obj.precompute_cache(nnz, &indptr[0], &keys[0], axis) + + def update(self, + np.ndarray[np.int64_t, ndim=1] indptr, + np.ndarray[np.int32_t, ndim=1] keys, + np.ndarray[np.float32_t, ndim=1] vals, + int axis): + return self._obj.update(&indptr[0], &keys[0], &vals[0], axis) + + def estimate_loss(self, + int nnz, + np.ndarray[np.int64_t, ndim=1] indptr, + np.ndarray[np.int32_t, ndim=1] keys, + np.ndarray[np.float32_t, ndim=1] vals, + int axis): + return self._obj.estimate_loss(nnz, &indptr[0], &keys[0], &vals[0], axis) diff --git a/buffalo/algo/eals.py b/buffalo/algo/eals.py new file mode 100644 index 0000000..5308584 --- /dev/null +++ b/buffalo/algo/eals.py @@ -0,0 +1,192 @@ +import json +import time +from typing import Callable, Dict, Optional + +import numpy as np + +import buffalo.data +from buffalo.algo._eals import CyEALS +from buffalo.algo.base import Algo, Serializable +from buffalo.algo.options import EALSOption +from buffalo.data.base import Data +from buffalo.evaluate import Evaluable +from buffalo.misc import aux, log + + +class EALS(Algo, EALSOption, Evaluable, Serializable): + """Python implementation for C-EALS. + + Implementation of Fast Matrix Factorization for Online Recommendation. + + Reference: https://arxiv.org/abs/1708.05024""" + def __init__(self, opt_path=None, *args, **kwargs): + Algo.__init__(self, *args, **kwargs) + EALSOption.__init__(self, *args, **kwargs) + Evaluable.__init__(self, *args, **kwargs) + Serializable.__init__(self, *args, **kwargs) + if opt_path is None: + opt_path = EALSOption().get_default_option() + + self.logger = log.get_logger("EALS") + self.group2axis = {"rowwise": 0, "colwise": 1} + self.opt, self.opt_path = self.get_option(opt_path) + self.obj = CyEALS() + assert self.obj.init(bytes(self.opt_path, "utf-8")), "cannot parse option file: %s" % opt_path + + self.data = None + data = kwargs.get("data") + data_opt = self.opt.get("data_opt") + data_opt = kwargs.get("data_opt", data_opt) + if data_opt: + self.data = buffalo.data.load(data_opt) + self.data.create() + elif isinstance(data, Data): + self.data = data + self.logger.info("eALS(%s)" % json.dumps(self.opt, indent=2)) + if self.data: + self.logger.info(self.data.show_info()) + assert self.data.data_type in ["matrix"] + + @staticmethod + def new(path, data_fields=[]): + return EALS.instantiate(EALSOption, path, data_fields) + + def set_data(self, data): + assert isinstance(data, aux.data.Data), "Wrong instance: {}".format(type(data)) + self.data = data + + def normalize(self, group="item"): + if group == "item" and not self.opt._nrz_Q: + self.Q = self._normalize(self.Q) + self.opt._nrz_Q = True + elif group == "user" and not self.opt._nrz_P: + self.P = self._normalize(self.P) + self.opt._nrz_P = True + + def initialize(self): + super().initialize() + self.init_factors() + + def init_factors(self): + assert self.data, "Data is not set" + self.vdim = self.opt.d + header = self.data.get_header() + self._nnz = header["num_nnz"] + for name, rows in [("P", header["num_users"]), ("Q", header["num_items"])]: + setattr(self, name, None) + setattr(self, name, np.random.normal(scale=1.0 / (self.opt.d ** 2), + size=(rows, self.vdim)).astype("float32")) + self.P[:, self.opt.d:] = 0.0 + self.Q[:, self.opt.d:] = 0.0 + self.C = self._get_negative_weights() + + self.obj.initialize_model(self.P, self.Q, self.C) + + def _get_topk_recommendation(self, rows, topk, pool=None): + p = self.P[rows] + topks = super()._get_topk_recommendation( + p, self.Q, + pb=None, Qb=None, + pool=pool, topk=topk, num_workers=self.opt.num_workers) + return zip(rows, topks) + + def _get_most_similar_item(self, col, topk, pool): + return super()._get_most_similar_item(col, topk, self.Q, self.opt._nrz_Q, pool) + + def get_scores(self, row_col_pairs): + rets = {(r, c): self.P[r].dot(self.Q[c]) for r, c in row_col_pairs} + return rets + + def _get_scores(self, row, col): + scores = (self.P[row] * self.Q[col]).sum(axis=1) + return scores + + def _get_negative_weights(self): + # Get item popularity from self.data + indptr, _, _ = self._get_mm_data(group="colwise") + pop = np.array([indptr[i] - (0 if i == 0 else indptr[i - 1]) for i in range(len(indptr))], dtype="float32") + assert len(pop) == self.data.get_header()["num_items"] + # Return negative weights calculated by the power-law weighting scheme + pop /= max(pop) + pop_with_exponent = pop**self.opt.get("exponent", 0.0) + return self.opt.get("c0", 1.0) * pop_with_exponent / sum(pop_with_exponent) + + def _get_mm_data(self, group): + group = self.data.get_group(group) + indptr = group["indptr"][:] + keys = group["key"][:] + vals = group["val"][:] + return indptr, keys, vals + + def _precompute_cache(self): + for group in ["rowwise", "colwise"]: + indptr, keys, _ = self._get_mm_data(group) + axis = self.group2axis[group] + self.obj.precompute_cache(self._nnz, indptr, keys, axis) + + def _iterate(self, group): + indptr, keys, vals = self._get_mm_data(group) + axis = self.group2axis[group] + assert self.obj.update(indptr, keys, vals, axis) + + def _get_loss(self): + axis = self.group2axis[(group := "rowwise")] + indptr, keys, vals = self._get_mm_data(group=group) + # loss: RMSE / total_loss := RMSE^2 + L2-loss + Negative Feedbacks + loss, total_loss = self.obj.estimate_loss(self._nnz, indptr, keys, vals, axis) + return loss, total_loss + + def train(self, training_callback: Optional[Callable[[int, Dict[str, float]], None]] = None): + best_loss, loss, self.validation_result = float("inf"), None, {} + full_st = time.time() + + self._precompute_cache() + + for i in range(self.opt.num_iters): + start_t = time.time() + self._iterate(group="rowwise") + self._iterate(group="colwise") + loss, total_loss = self._get_loss() + + train_t = time.time() - start_t + metrics = {"train_loss": loss} + if self.opt.validation and \ + self.opt.evaluation_on_learning and \ + self.periodical(self.opt.evaluation_period, i): + start_t = time.time() + self.validation_result = self.get_validation_results() + vali_t = time.time() - start_t + val_str = " ".join([f"{k}:{v:0.5f}" for k, v in self.validation_result.items()]) + self.logger.info(f"Validation: {val_str} Elapsed {vali_t:0.3f} secs") + metrics.update({"val_%s" % k: v + for k, v in self.validation_result.items()}) + if training_callback is not None and callable(training_callback): + training_callback(i, metrics) + self.logger.info("Iteration %d: RMSE %.3f TotalLoss %.3f Elapsed %.3f secs" % (i + 1, loss, (total_loss / self._nnz), train_t)) + best_loss = self.save_best_only(loss, best_loss, i) + if self.early_stopping(loss): + break + + full_el = time.time() - full_st + self.logger.info(f"elapsed for full epochs: {full_el:.2f} sec") + ret = {"train_loss": loss} + ret.update({"val_%s" % k: v + for k, v in self.validation_result.items()}) + return ret + + def _get_feature(self, index, group="item"): + if group == "item": + return self.Q[index] + elif group == "user": + return self.P[index] + return None + + def _get_data(self): + data = super()._get_data() + data.extend([("opt", self.opt), + ("Q", self.Q), + ("P", self.P)]) + return data + + def get_evaluation_metrics(self): + return ["train_loss", "val_rmse", "val_ndcg", "val_map", "val_accuracy", "val_error"] diff --git a/buffalo/algo/options.py b/buffalo/algo/options.py index cdf8c00..ab60148 100644 --- a/buffalo/algo/options.py +++ b/buffalo/algo/options.py @@ -95,6 +95,43 @@ def is_valid_option(self, opt): return b +class EALSOption(AlgoOption): + def __init__(self, *args, **kwargs): + super(EALSOption, self).__init__(*args, **kwargs) + + def get_default_option(self): + """Options for Alternating Least Squares. + + :ivar bool save_factors: Set True, to save models. (default: False) + :ivar int d: The number of latent feature dimension. (default: 20) + :ivar int num_iters: The number of iterations for training. (default: 10) + :ivar int num_workers: The number of threads. (default: 1) + :ivar float reg_u: The L2 regularization coefficient for user embedding matrix. (default: 0.1) + :ivar float reg_i: The L2 regularization coefficient for item embedding matrix. (default: 0.1) + :ivar float alpha: The coefficient of giving more weights to losses on positive samples. (default: 8) + :ivar float c0: The strength of the negative feedbacks + :ivar float exponent: exponent to item popularity for the negative feedbacks + :ivar str model_path: Where to save model. + :ivar dict data_opt: This option will be used to load data if given. + """ + + opt = super().get_default_option() + opt.update({ + "save_factors": False, + "d": 20, + "num_iters": 10, + "num_workers": 1, + "reg_u": 0.1, + "reg_i": 0.1, + "alpha": 8.0, + "c0": 512.0, + "exponent": 0.5, + "model_path": "", + "data_opt": {} + }) + return aux.Option(opt) + + class CFROption(AlgoOption): def __init__(self, *args, **kwargs): super(CFROption, self).__init__(*args, **kwargs) diff --git a/include/buffalo/algo_impl/eals/eals.hpp b/include/buffalo/algo_impl/eals/eals.hpp new file mode 100644 index 0000000..42204f6 --- /dev/null +++ b/include/buffalo/algo_impl/eals/eals.hpp @@ -0,0 +1,84 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "buffalo/algo.hpp" +#include "buffalo/misc/blas.hpp" +using namespace std; +using namespace Eigen; + + +namespace eals { + + +class CEALS : public Algorithm { + public: + CEALS(); + virtual ~CEALS(); + + virtual bool init(string opt_path); + virtual bool parse_option(string opt_path); + void _leastsquare(Map& X, + int idx, + MatrixType& A, + VectorType& y) = delete; + void initialize_model(float* P_ptr, + float* Q_ptr, + float* C_ptr, + const int32_t P_rows, + const int32_t Q_rows); + void precompute_cache(const int32_t nnz, + const int64_t* indptr, + const int32_t* keys, + const int32_t axis); + bool update(const int64_t* indptr, + const int32_t* keys, + const float* vals, + const int32_t axis); + pair estimate_loss(const int32_t nnz, + const int64_t* indptr, + const int32_t* keys, + const float* vals, + const int32_t axis); + + private: + class IdxCoord { + public: + void set(int32_t row, int32_t col, int64_t key) { + row_ = row; + col_ = col; + key_ = key; + } + int32_t get_row() const { return row_; } + int32_t get_col() const { return col_; } + int64_t get_key() const { return key_; } + private: + int32_t row_, col_; + int64_t key_; + }; + + void update_P_(const int64_t* indptr, + const int32_t* keys, + const float* vals); + void update_Q_(const int64_t* indptr, + const int32_t* keys, + const float* vals); + + Json opt_; + int32_t P_rows_, Q_rows_; + bool is_P_cached_, is_Q_cached_; + float* P_ptr_, * Q_ptr_, * C_ptr_; + vector vhat_cache_u_, vhat_cache_i_; + vector ind_u2i_, ind_i2u_; + const float kOne, kZero; +}; + +} diff --git a/include/buffalo/misc/blas.hpp b/include/buffalo/misc/blas.hpp new file mode 100644 index 0000000..0986977 --- /dev/null +++ b/include/buffalo/misc/blas.hpp @@ -0,0 +1,79 @@ +#pragma once + + +extern "C" { +// blas subroutines +void ssyrk_(const char*, const char*, const int*, const int*, const float*, const float*, const int*, + const float*, float*, const int*); + +void dsyrk_(const char*, const char*, const int*, const int*, const double*, const double*, const int*, + const double*, double*, const int*); +} + +namespace blas { +namespace impl { +inline void syrk(const char uplo, + const char trans, + const int n, + const int k, + const float alpha, + const float* A, + const int lda, + const float beta, + float* C, + const int ldc) { + ssyrk_(&uplo, &trans, &n, &k, &alpha, A, &lda, &beta, C, &ldc); +} + +inline void syrk(const char uplo, + const char trans, + const int n, + const int k, + const double alpha, + const double* A, + const int lda, + const double beta, + double* C, + const int ldc) { + dsyrk_(&uplo, &trans, &n, &k, &alpha, A, &lda, &beta, C, &ldc); +} +} // end namespace impl + +namespace etc { +template T max(const T a, const T b) { return ((a > b) ? a : b); } +template T min(const T a, const T b) { return ((a > b) ? b : a); } +template +void fill_left_elems(T* A, const int m, const std::string uplo) { + if (!uplo.compare("u")) { + for (int i=0; i < (m - 1); ++i) { + for (int j=(i + 1); j < m; ++j) { + A[j*m + i] = A[i*m + j]; + } + } + } else if (!uplo.compare("l")) { + for (int i=1; i < m; ++i) { + for (int j=0; j < i; ++j) { + A[j*m + i] = A[i*m + j]; + } + } + } +} +} // end namespace etc + +template +void syrk(const std::string uplo, + const std::string trans, + const int n, + const int k, + const T alpha, + const T* A, + const T beta, + T* C) { + const char uplo_ = (uplo.c_str()[0] == 'u')? 'l' : 'u'; + const char trans_ = (trans.c_str()[0] == 't')? 'n' : 't'; + const int lda = (trans_ == 'n')? etc::max(1, n) : etc::max(1, k); + const int ldc = etc::max(1, n); + impl::syrk(uplo_, trans_, n, k, alpha, A, lda, beta, C, ldc); + etc::fill_left_elems(C, n, uplo); +} +} // end namespace blas diff --git a/lib/algo_impl/eals/eals.cc b/lib/algo_impl/eals/eals.cc new file mode 100644 index 0000000..9b5993d --- /dev/null +++ b/lib/algo_impl/eals/eals.cc @@ -0,0 +1,281 @@ +#include "buffalo/algo_impl/eals/eals.hpp" + + +namespace eals { + +CEALS::CEALS() + : P_ptr_(nullptr), + Q_ptr_(nullptr), + C_ptr_(nullptr), + P_rows_(-1), + Q_rows_(-1), + is_P_cached_(false), + is_Q_cached_(false), + kOne(1.0), + kZero(0.0) {} + +CEALS::~CEALS() {} + +bool CEALS::init(string opt_path) { + const bool ok = parse_option(opt_path); + if (ok) { + const int32_t num_workers = opt_["num_workers"].int_value(); + omp_set_num_threads(num_workers); + } + return ok; +} + +bool CEALS::parse_option(string opt_path) { + const bool ok = Algorithm::parse_option(opt_path, opt_); + return ok; +} + +void CEALS::initialize_model(float* P_ptr, + float* Q_ptr, + float* C_ptr, + const int32_t P_rows, + const int32_t Q_rows) { + const int32_t D = opt_["d"].int_value(); + P_ptr_ = P_ptr; + Q_ptr_ = Q_ptr; + C_ptr_ = C_ptr; + P_rows_ = P_rows; + Q_rows_ = Q_rows; + is_P_cached_ = false; + is_Q_cached_ = false; + DEBUG("P({} x {}) Q({} x {}) set", P_rows_, D, Q_rows_, D); +} + +void CEALS::precompute_cache(const int32_t nnz, + const int64_t* indptr, + const int32_t* keys, + const int32_t axis) { + bool& is_cached = axis == 0 ? is_P_cached_ : is_Q_cached_; + if (is_cached) { + return; + } + const float* P = axis == 0 ? P_ptr_ : Q_ptr_; + const float* Q = axis == 0 ? Q_ptr_ : P_ptr_; + auto& vhat_cache = axis == 0 ? vhat_cache_u_ : vhat_cache_i_; + if (nnz != vhat_cache.size()) { + vhat_cache.resize(nnz); + } + vector idx_xmajor(nnz); + const int32_t end_loop = axis == 0 ? P_rows_ : Q_rows_; + const int32_t D = opt_["d"].int_value(); + #pragma omp parallel for schedule(dynamic, 8) + for (int32_t xidx=0; xidx < end_loop; ++xidx) { + const int64_t beg = xidx == 0 ? 0 : indptr[xidx - 1]; + const int64_t end = indptr[xidx]; + int64_t data_size = end - beg; + if (data_size == 0) { + TRACE("No data exists for {}", xidx); + continue; + } + const float* p_ptr = &P[xidx * D]; + for (int64_t ind=beg; ind < end; ++ind) { + const int32_t yidx = keys[ind]; + const float* q_ptr = &Q[yidx * D]; + vhat_cache[ind] = inner_product(p_ptr, p_ptr + D, q_ptr, kZero); + idx_xmajor[ind].set(yidx, xidx, ind); + } + } + auto& ind_mapper = axis == 0 ? ind_u2i_ : ind_i2u_; + if (nnz != ind_mapper.size()) { + ind_mapper.resize(nnz); + } + sort(idx_xmajor.begin(), idx_xmajor.end(), + [](const IdxCoord& a, const IdxCoord& b) -> bool { + if (a.get_row() == b.get_row()) { + return a.get_col() < b.get_col(); + } else { + return a.get_row() < b.get_row(); + } + }); + for (int32_t ind=0; ind < nnz; ++ind) { + const int64_t key = idx_xmajor[ind].get_key(); + ind_mapper[key] = ind; + } + is_cached = true; +} + +bool CEALS::update(const int64_t* indptr, + const int32_t* keys, + const float* vals, + const int32_t axis) { + const bool is_cached = is_P_cached_ && is_Q_cached_; + if (is_cached) { + if (axis == 0) { + this->update_P_(indptr, keys, vals); + } else { + this->update_Q_(indptr, keys, vals); + } + } + return is_cached; +} + +pair CEALS::estimate_loss(const int32_t nnz, + const int64_t* indptr, + const int32_t* keys, + const float* vals, + const int32_t axis) { + // Loss := sum_{(u,i) \in R} (1 + alpha * v_{u,i}) * (v_{u,i} - vHat_{u,i})^2 + + // sum_{u} sum_{i \notin R_u} C_{i} * vHat_{u,i}^2 + + // reg_u * |P|^2 + reg_i * |Q|^2 + + const bool is_cached = is_P_cached_ && is_Q_cached_; + if (!is_cached) { + TRACE("Compute cache first(P: {}, Q: {}). Empty data is returned.", is_P_cached_, is_Q_cached_); + return pair(kZero, kZero); + } + const int32_t num_workers = opt_["num_workers"].int_value(); + const float alpha = opt_["alpha"].number_value(); + float feedbacks = kZero; + float squared_error = kZero; + auto& vhat_cache = axis == 0 ? vhat_cache_u_ : vhat_cache_i_; + const int32_t end_loop = axis == 0 ? P_rows_ : Q_rows_; + for (int32_t xidx=0; xidx < end_loop; ++xidx) { + const int32_t tid = omp_get_thread_num(); + const int64_t beg = xidx == 0 ? 0 : indptr[xidx - 1]; + const int64_t end = indptr[xidx]; + for (int64_t ind=beg; ind < end; ++ind) { + const int32_t yidx = keys[ind]; + const float v = vals[ind]; + const float vhat = vhat_cache[ind]; + const float error = (v - vhat); + feedbacks += (kOne + alpha * v) * error * error; + const int32_t iidx = axis == 0 ? yidx : xidx; + feedbacks -= C_ptr_[iidx] * vhat * vhat; // To avoid duplication in a term of negative feedbacks. + squared_error += error * error; + } + } + + // Add L2 regularization terms + const int32_t D = opt_["d"].int_value(); + const float reg_u = opt_["reg_u"].number_value(); + const float reg_i = opt_["reg_i"].number_value(); + auto op = [](const float & init, const float & v) -> float { + return init + v * v; + }; + const float reg = reg_u * accumulate(P_ptr_, P_ptr_ + (P_rows_ * D), kZero, op) + + reg_i * accumulate(Q_ptr_, Q_ptr_ + (Q_rows_ * D), kZero, op); + + // Add negative feedbacks: sum_{u} sum_{i \in R_u} C_{i} * vHat_{u,i}^2 + vector CQ(Q_rows_ * D, kZero); + #pragma omp parallel for schedule(dynamic, 8) + for (int32_t iidx=0; iidx < Q_rows_; ++iidx) { + const int32_t ridx = iidx * D; + for (int32_t d=0; d < D; ++d) { + CQ[ridx + d] = sqrt(C_ptr_[iidx]) * Q_ptr_[ridx + d]; + } + } + vector Sp(D * D, kZero); + vector Sq(D * D, kZero); + blas::syrk("u", "t", D, P_rows_, kOne, P_ptr_, kZero, Sp.data()); + blas::syrk("u", "t", D, Q_rows_, kOne, CQ.data(), kZero, Sq.data()); + feedbacks += inner_product(Sp.data(), Sp.data() + (D * D), Sq.data(), kZero); + const float rmse = sqrt(squared_error / nnz); + const float loss = feedbacks + reg; + return pair(rmse, loss); +} + +void CEALS::update_P_(const int64_t* indptr, + const int32_t* keys, + const float* vals) { + const int32_t D = opt_["d"].int_value(); + vector CQ(Q_rows_ * D, kZero); + for (int32_t iidx=0; iidx < Q_rows_; ++iidx) { + const float sqrt_C = sqrt(C_ptr_[iidx]); + const int32_t ridx = iidx * D; + const float* q_ptr = &Q_ptr_[ridx]; + float* cq_ptr = &CQ[ridx]; + transform(q_ptr, q_ptr + D, cq_ptr, + [sqrt_C](const float elem) -> float { + return sqrt_C * elem; + }); + } + vector Sq(D * D, kZero); + blas::syrk("u", "t", D, Q_rows_, kOne, CQ.data(), kZero, Sq.data()); + const float alpha = opt_["alpha"].number_value(); + const float reg_u = opt_["reg_u"].number_value(); + #pragma omp parallel for schedule(dynamic, 8) + for (int32_t uidx=0; uidx < P_rows_; ++uidx) { + float* p_ptr = &P_ptr_[uidx * D]; + const int64_t beg = uidx == 0 ? 0 : indptr[uidx - 1]; + const int64_t end = indptr[uidx]; + for (int32_t d=0; d < D; ++d) { + float numerator = kZero; + float denominator = kZero; + for (int64_t ind=beg; ind < end; ++ind) { + const int32_t iidx = keys[ind]; + const float v = vals[ind]; + const float* q_ptr = &Q_ptr_[iidx * D]; + const float vhat = vhat_cache_u_[ind]; + const float pq = p_ptr[d] * q_ptr[d]; + const float vf = vhat - pq; + const float w = (kOne + alpha * v); + const float wmc = w - C_ptr_[iidx]; + numerator += (w * v - wmc * vf) * q_ptr[d]; + denominator += wmc * q_ptr[d] * q_ptr[d]; + vhat_cache_u_[ind] -= pq; + vhat_cache_i_[ind_u2i_[ind]] -= pq; + } + numerator += -inner_product(p_ptr, p_ptr + D, &Sq[D * d], kZero) + p_ptr[d] * Sq[D * d + d]; + denominator += Sq[D * d + d] + reg_u; + p_ptr[d] = numerator / denominator; + for (int64_t ind=beg; ind < end; ++ind) { + const int32_t iidx = keys[ind]; + const float* q_ptr = &Q_ptr_[iidx * D]; + const float pq = p_ptr[d] * q_ptr[d]; + vhat_cache_u_[ind] += pq; + vhat_cache_i_[ind_u2i_[ind]] += pq; + } + } + } +} + +void CEALS::update_Q_(const int64_t* indptr, + const int32_t* keys, + const float* vals) { + const int32_t D = opt_["d"].int_value(); + vector Sp(D * D, kZero); + blas::syrk("u", "t", D, P_rows_, kOne, P_ptr_, kZero, Sp.data()); + const float alpha = opt_["alpha"].number_value(); + const float reg_i = opt_["reg_i"].number_value(); + #pragma omp parallel for schedule(dynamic, 8) + for (int32_t iidx=0; iidx < Q_rows_; ++iidx) { + float* q_ptr = &Q_ptr_[iidx * D]; + const int64_t beg = iidx == 0 ? 0 : indptr[iidx - 1]; + const int64_t end = indptr[iidx]; + for (int32_t d=0; d < D; ++d) { + float numerator = kZero; + float denominator = kZero; + for (int64_t ind=beg; ind < end; ++ind) { + const int32_t uidx = keys[ind]; + const float v = vals[ind]; + const float* p_ptr = &P_ptr_[uidx * D]; + const float vhat = vhat_cache_i_[ind]; + const float pq = p_ptr[d] * q_ptr[d]; + const float vf = vhat - pq; + const float w = (kOne + alpha * v); + const float wmc = w - C_ptr_[iidx]; + numerator += (w * v - wmc * vf) * p_ptr[d]; + denominator += wmc * p_ptr[d] * p_ptr[d]; + vhat_cache_i_[ind] -= pq; + vhat_cache_u_[ind_i2u_[ind]] -= pq; + } + numerator += -C_ptr_[iidx] * (inner_product(q_ptr, q_ptr + D, &Sp[D * d], kZero) - q_ptr[d] * Sp[D * d + d]); + denominator += C_ptr_[iidx] * Sp[D * d + d] + reg_i; + q_ptr[d] = numerator / denominator; + for (int64_t ind=beg; ind < end; ++ind) { + const int32_t uidx = keys[ind]; + const float* p_ptr = &P_ptr_[uidx * D]; + const float pq = p_ptr[d] * q_ptr[d]; + vhat_cache_i_[ind] += pq; + vhat_cache_u_[ind_i2u_[ind]] += pq; + } + } + } +} + +} diff --git a/setup.py b/setup.py index 414d796..f7f3d0f 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,13 @@ def get_compiler(name: str): libraries=["gomp"], extra_compile_args=["-fopenmp", "-std=c++14", "-O3"] + extended_compile_flags, define_macros=[("NPY_NO_DEPRECATED_API", "1")]), + Extension(name="buffalo.algo._eals", + sources=["buffalo/algo/_eals.pyx", "lib/algo_impl/eals/eals.cc"] + common_srcs, + language="c++", + include_dirs=["./include"] + extra_include_dirs, + libraries=["gomp", "openblas"], + extra_compile_args=["-fopenmp", "-std=c++14", "-O3"] + extended_compile_flags, + define_macros=[("NPY_NO_DEPRECATED_API", "1")]), Extension(name="buffalo.algo._cfr", sources=["buffalo/algo/_cfr.pyx", "lib/algo_impl/cfr/cfr.cc"] + common_srcs, language="c++", diff --git a/tests/algo/test_eals.py b/tests/algo/test_eals.py new file mode 100644 index 0000000..cf73d65 --- /dev/null +++ b/tests/algo/test_eals.py @@ -0,0 +1,85 @@ +import unittest + +from loguru import logger + +import buffalo +from buffalo import EALS, EALSOption, aux, set_log_level, MatrixMarketOptions, set_log_level + +from .base import TestBase + + +class TestEALS(TestBase): + def test00_get_default_option(self): + EALSOption().get_default_option() + self.assertTrue(True) + + def test01_is_valid_option(self): + opt = EALSOption().get_default_option() + self.assertTrue(EALSOption().is_valid_option(opt)) + opt["save_best"] = 1 + self.assertRaises(RuntimeError, EALSOption().is_valid_option, opt) + opt["save_best"] = False + self.assertTrue(EALSOption().is_valid_option(opt)) + + def test02_init_with_dict(self): + set_log_level(3) + opt = EALSOption().get_default_option() + EALS(opt) + self.assertTrue(True) + + def test03_init(self): + opt = EALSOption().get_default_option() + self._test3_init(EALS, opt) + + def test04_train(self): + opt = EALSOption().get_default_option() + opt.d = 20 + self._test4_train(EALS, opt) + + def test05_validation(self): + opt = EALSOption().get_default_option() + opt.d = 5 + opt.num_iters = 20 + opt.validation = aux.Option({"topk": 10}) + self._test5_validation(EALS, opt) + + def test05_1_validation_with_callback(self,): + opt = EALSOption().get_default_option() + opt.d = 5 + opt.num_iters = 20 + opt.validation = aux.Option({"topk": 10}) + self._test5_1_validation_with_callback(EALS, opt) + + def test06_topk(self): + opt = EALSOption().get_default_option() + opt.d = 5 + opt.validation = aux.Option({"topk": 10}) + self._test6_topk(EALS, opt) + + def test07_train_ml_20m(self): + opt = EALSOption().get_default_option() + opt.num_workers = 8 + opt.validation = aux.Option({"topk": 10}) + self._test7_train_ml_20m(EALS, opt) + + def test08_serialization(self): + opt = EALSOption().get_default_option() + opt.d = 5 + opt.validation = aux.Option({"topk": 10}) + self._test8_serialization(EALS, opt) + + def test09_compact_serialization(self): + opt = EALSOption().get_default_option() + opt.d = 5 + opt.validation = aux.Option({"topk": 10}) + self._test9_compact_serialization(EALS, opt) + + def test10_fast_most_similar(self): + opt = EALSOption().get_default_option() + opt.d = 5 + opt.validation = aux.Option({"topk": 10}) + self._test10_fast_most_similar(EALS, opt) + + +if __name__ == "__main__": + unittest.main() From d068b23aff2a1e6a63110020630fba71aaf7bd2b Mon Sep 17 00:00:00 2001 From: Dongkyu Kim Date: Mon, 18 Dec 2023 13:09:22 +0900 Subject: [PATCH 2/5] Fix imports (#79) * Fix imports * Fix lint error * Fix lint error(import rule) * Fix lint error(Flake8) * Fix lint error(cpplint) --- buffalo/__init__.py | 7 ++++--- buffalo/algo/eals.py | 24 ++++++++++++------------ include/buffalo/misc/blas.hpp | 3 +++ tests/algo/test_eals.py | 7 ++----- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/buffalo/__init__.py b/buffalo/__init__.py index e89530e..745b699 100644 --- a/buffalo/__init__.py +++ b/buffalo/__init__.py @@ -3,12 +3,13 @@ __version__ = importlib.metadata.version('buffalo') from buffalo.algo.als import ALS, inited_CUALS -from buffalo.algo.eals import EALS from buffalo.algo.base import Algo from buffalo.algo.bpr import BPRMF, inited_CUBPR from buffalo.algo.cfr import CFR -from buffalo.algo.options import (AlgoOption, ALSOption, EALSOption, BPRMFOption, - CFROption, PLSIOption, W2VOption, WARPOption) +from buffalo.algo.eals import EALS +from buffalo.algo.options import (AlgoOption, ALSOption, BPRMFOption, + CFROption, EALSOption, PLSIOption, W2VOption, + WARPOption) from buffalo.algo.plsi import PLSI from buffalo.algo.w2v import W2V from buffalo.algo.warp import WARP diff --git a/buffalo/algo/eals.py b/buffalo/algo/eals.py index 5308584..47b4cf6 100644 --- a/buffalo/algo/eals.py +++ b/buffalo/algo/eals.py @@ -75,7 +75,7 @@ def init_factors(self): for name, rows in [("P", header["num_users"]), ("Q", header["num_items"])]: setattr(self, name, None) setattr(self, name, np.random.normal(scale=1.0 / (self.opt.d ** 2), - size=(rows, self.vdim)).astype("float32")) + size=(rows, self.vdim)).astype("float32")) self.P[:, self.opt.d:] = 0.0 self.Q[:, self.opt.d:] = 0.0 self.C = self._get_negative_weights() @@ -151,17 +151,17 @@ def train(self, training_callback: Optional[Callable[[int, Dict[str, float]], No train_t = time.time() - start_t metrics = {"train_loss": loss} if self.opt.validation and \ - self.opt.evaluation_on_learning and \ - self.periodical(self.opt.evaluation_period, i): - start_t = time.time() - self.validation_result = self.get_validation_results() - vali_t = time.time() - start_t - val_str = " ".join([f"{k}:{v:0.5f}" for k, v in self.validation_result.items()]) - self.logger.info(f"Validation: {val_str} Elapsed {vali_t:0.3f} secs") - metrics.update({"val_%s" % k: v - for k, v in self.validation_result.items()}) - if training_callback is not None and callable(training_callback): - training_callback(i, metrics) + self.opt.evaluation_on_learning and \ + self.periodical(self.opt.evaluation_period, i): + start_t = time.time() + self.validation_result = self.get_validation_results() + vali_t = time.time() - start_t + val_str = " ".join([f"{k}:{v:0.5f}" for k, v in self.validation_result.items()]) + self.logger.info(f"Validation: {val_str} Elapsed {vali_t:0.3f} secs") + metrics.update({"val_%s" % k: v + for k, v in self.validation_result.items()}) + if training_callback is not None and callable(training_callback): + training_callback(i, metrics) self.logger.info("Iteration %d: RMSE %.3f TotalLoss %.3f Elapsed %.3f secs" % (i + 1, loss, (total_loss / self._nnz), train_t)) best_loss = self.save_best_only(loss, best_loss, i) if self.early_stopping(loss): diff --git a/include/buffalo/misc/blas.hpp b/include/buffalo/misc/blas.hpp index 0986977..62ab806 100644 --- a/include/buffalo/misc/blas.hpp +++ b/include/buffalo/misc/blas.hpp @@ -1,5 +1,8 @@ #pragma once +#include +#include + extern "C" { // blas subroutines diff --git a/tests/algo/test_eals.py b/tests/algo/test_eals.py index cf73d65..c62aaae 100644 --- a/tests/algo/test_eals.py +++ b/tests/algo/test_eals.py @@ -1,9 +1,6 @@ import unittest -from loguru import logger - -import buffalo -from buffalo import EALS, EALSOption, aux, set_log_level, MatrixMarketOptions, set_log_level +from buffalo import EALS, EALSOption, aux, set_log_level from .base import TestBase @@ -61,7 +58,7 @@ def test07_train_ml_20m(self): opt.num_workers = 8 opt.validation = aux.Option({"topk": 10}) self._test7_train_ml_20m(EALS, opt) - + def test08_serialization(self): opt = EALSOption().get_default_option() opt.d = 5 From 6a8312fb2682299267907858d7c7739feb89ba81 Mon Sep 17 00:00:00 2001 From: Dongkyu Kim Date: Fri, 22 Dec 2023 14:26:21 +0900 Subject: [PATCH 3/5] Fix type mismatch error in MatrixMarket (#81) --- buffalo/data/mm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/buffalo/data/mm.py b/buffalo/data/mm.py index fc64345..1d99db2 100644 --- a/buffalo/data/mm.py +++ b/buffalo/data/mm.py @@ -138,17 +138,17 @@ def get_max_column_length(fname): # if not given, assume id as is if uid_path: with open(uid_path) as fin: - idmap["rows"][:] = np.loadtxt(fin, dtype=f"S{uid_max_col}") + idmap["rows"][:] = np.loadtxt(fin, dtype=h5py.string_dtype("utf-8", length=uid_max_col)) else: idmap["rows"][:] = np.array([str(i) for i in range(1, num_users + 1)], - dtype=f"S{uid_max_col}") + dtype=h5py.string_dtype("utf-8", length=uid_max_col)) pbar.update(1) if iid_path: with open(iid_path) as fin: - idmap["cols"][:] = np.loadtxt(fin, dtype=f"S{iid_max_col}") + idmap["cols"][:] = np.loadtxt(fin, dtype=h5py.string_dtype("utf-8", length=iid_max_col)) else: idmap["cols"][:] = np.array([str(i) for i in range(1, num_items + 1)], - dtype=f"S{iid_max_col}") + dtype=h5py.string_dtype("utf-8", length=iid_max_col)) pbar.update(1) num_header_lines = 0 with open(main_path) as fin: @@ -252,8 +252,8 @@ def create(self) -> h5py.File: self.logger.debug("Building meta part...") db, num_header_lines = self._create(data_path, {"main_path": mm_main_path, - "uid_path": mm_uid_path, - "iid_path": mm_iid_path}, + "uid_path": mm_uid_path, + "iid_path": mm_iid_path}, header) try: num_header_lines += 1 # add metaline From efd7d0c9072be65fa8883bf60617bd2fac947d03 Mon Sep 17 00:00:00 2001 From: yupyub Date: Tue, 9 Jan 2024 20:47:12 +0900 Subject: [PATCH 4/5] fix stream data newline char error. fix typo (#82) --- buffalo/data/fileio.hpp | 2 +- buffalo/data/stream.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/buffalo/data/fileio.hpp b/buffalo/data/fileio.hpp index ad57cff..49b2da7 100644 --- a/buffalo/data/fileio.hpp +++ b/buffalo/data/fileio.hpp @@ -322,7 +322,7 @@ vector _sort_and_compressed_binarization( records.insert(end(records), begin(v), end_it); } - assert(records.size == total_lines); + assert(records.size() == total_lines); omp_set_num_threads(num_workers); diff --git a/buffalo/data/stream.py b/buffalo/data/stream.py index 3cdd283..3dc1fdd 100644 --- a/buffalo/data/stream.py +++ b/buffalo/data/stream.py @@ -105,6 +105,8 @@ def get_max_column_length(fname): with open(main_path) as fin: for line in log.ProgressBar(level=log.DEBUG, iterable=fin): data = line.strip().split() + if not data: + continue if not iid_path: itemids |= set(data) @@ -246,7 +248,7 @@ def _create_working_data(self, db, stream_main_path, itemids, for col in train_data: w.write(f"{user} {col} 1\n") for col in vali_data: - vali_lines.append(f"{user} {col} {val}") + vali_lines.append(f"{user} {col} 1") else: for col, val in Counter(train_data).items(): w.write(f"{user} {col} {val}\n") From 764adf32737fdc3afb891b4fea0504ca7a630e66 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Mon, 15 Jan 2024 11:06:32 +0900 Subject: [PATCH 5/5] bump version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index ebb8ba5..dfa3a95 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = buffalo -version = 2.0.3 +version = 2.0.4 author = recoteam author_email = recoteam@kakaocorp.com url = https://github.com/kakao/buffalo