// Generated by paddle/phi/api/generator/tensor_operants_gen.py

#pragma once

#include "paddle/phi/api/include/operants_base.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/common/macros.h"
#include "paddle/utils/test_macros.h"


namespace paddle {

// Short names for the types used in the OperantsManager interface below.
// They are declared at namespace scope, so they are visible to every file that
// includes this header.
using TensorOperantsBase = paddle::operants::TensorOperantsBase;
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArray;
using Scalar = paddle::experimental::Scalar;

/**
 * [ Why need OperantsManager? ]
 *
 * Ideally, overloaded tensor operators would call the Tensor API directly.
 * However, we faced two problems:
 *
 * 1. Support multiple modes: Tensor operator overloading needs to support
 * [static mode / autograd mode / custom operator mode] at the same time.
 *
 * 2. Decouple phi and fluid: Tensor belongs to the phi library, but its
 * overloaded operators rely on functions in fluid.
 *
 * We design OperantsManager to solve these two problems:
 *
 * 1. Use `FLAGS_tensor_operants_mode` to select the overloading mode. The flag
 * is set at the entry point of each mode:
 *
 * - FLAGS_tensor_operants_mode = "static": in the constructor of
 * `CompositeGradOpMakerBase`.
 * - FLAGS_tensor_operants_mode = "eager": at the beginning of dygraph_function.
 * - FLAGS_tensor_operants_mode = "phi": at the beginning of the
 * `eager_api_run_custom_op` function in eager mode and at the beginning of
 * calling kernels in static mode.
 *
 * For performance, OperantsManager holds a separate pointer for each of the
 * three modes.
 *
 * 2. Decouple phi from fluid through polymorphism: TensorOperantsBase has three
 * derived classes, PhiTensorOperants, EagerTensorOperants, and
 * StaticTensorOperants. The eager and static tensor operants are set in the
 * fluid library, and the phi operants are set in the phi library.
 *
 */
class TEST_API OperantsManager {
 private:
  OperantsManager() = default;
  DISABLE_COPY_AND_ASSIGN(OperantsManager);

 public:
  std::unique_ptr<TensorOperantsBase> eager_operants{nullptr};
  std::unique_ptr<TensorOperantsBase> static_operants{nullptr};
  std::unique_ptr<TensorOperantsBase> phi_operants{nullptr};

 public:
  static OperantsManager& Instance();

  Tensor add(const Tensor& x, const Scalar& y);

  Tensor subtract(const Tensor& x, const Scalar& y);

  Tensor multiply(const Tensor& x, const Scalar& y);

  Tensor divide(const Tensor& x, const Scalar& y);

  Tensor add(const Scalar& x, const Tensor& y);

  Tensor subtract(const Scalar& x, const Tensor& y);

  Tensor multiply(const Scalar& x, const Tensor& y);

  Tensor divide(const Scalar& x, const Tensor& y);

  Tensor pow(const Tensor& x, const Tensor& y);

  Tensor pow(const Tensor& x, const Scalar& y);


  Tensor abs(const Tensor& x);

  Tensor bitwise_and(const Tensor& x, const Tensor& y);

  Tensor bitwise_not(const Tensor& x);

  Tensor bitwise_or(const Tensor& x, const Tensor& y);

  Tensor bitwise_xor(const Tensor& x, const Tensor& y);

  Tensor exp(const Tensor& x);

  Tensor expand(const Tensor& x, const IntArray& shape = {});

  Tensor floor(const Tensor& x);

  Tensor gather_nd(const Tensor& x, const Tensor& index);

  Tensor log(const Tensor& x);

  Tensor max(const Tensor& x, const IntArray& axis = {}, bool keepdim = false);

  Tensor roll(const Tensor& x, const IntArray& shifts = {}, const std::vector<int64_t>& axis = {});

  Tensor scale(const Tensor& x, const Scalar& scale = 1.0, const Scalar& bias = 0.0, bool bias_after_scale = true);

  Tensor scatter(const Tensor& x, const Tensor& index, const Tensor& updates, bool overwrite = true);

  Tensor scatter_nd_add(const Tensor& x, const Tensor& index, const Tensor& updates);

  Tensor sum(const Tensor& x, const IntArray& axis = {}, DataType dtype = DataType::UNDEFINED, bool keepdim = false);

  Tensor add(const Tensor& x, const Tensor& y);

  Tensor assign(const Tensor& x);

  Tensor divide(const Tensor& x, const Tensor& y);

  Tensor elementwise_pow(const Tensor& x, const Tensor& y);

  Tensor equal(const Tensor& x, const Tensor& y);

  Tensor greater_equal(const Tensor& x, const Tensor& y);

  Tensor greater_than(const Tensor& x, const Tensor& y);

  Tensor less_equal(const Tensor& x, const Tensor& y);

  Tensor less_than(const Tensor& x, const Tensor& y);

  Tensor matmul(const Tensor& x, const Tensor& y, bool transpose_x = false, bool transpose_y = false);

  Tensor maximum(const Tensor& x, const Tensor& y);

  Tensor minimum(const Tensor& x, const Tensor& y);

  Tensor multiply(const Tensor& x, const Tensor& y);

  Tensor not_equal(const Tensor& x, const Tensor& y);

  Tensor subtract(const Tensor& x, const Tensor& y);

  Tensor tile(const Tensor& x, const IntArray& repeat_times = {});
};

}  // namespace paddle

