use crate::internal::*;

pub use super::{InletId, Model, Node, OutletId};

/// Extensions on Model to explore and build graph models more easily.
pub trait ModelDsl<TI: TensorInfo> {
    /// Find the lone precursor of a node, if applicable.
    fn single_prec(&self, id: usize) -> TractResult<Option<&Node<TI>>>;
    /// Find the count-th precursor of a node `id` in a chain of single tensor
    /// operation, if applicable.
    fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<TI>>>;
    /// Find the lone succesor of a node, if applicable.
    fn single_succ(&self, id: usize) -> TractResult<Option<&Node<TI>>>;
    /// Find the count-th successor of a node `id` in a chain of single tensor
    /// operation, if applicable.
    fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<TI>>>;

    /// Adds a source op to the network.
    ///
    /// The model will assume this is an input.
    fn add_source(&mut self, name: impl Into<String>, fact: TI) -> TractResult<usize>;
    /// Chain a node to the latest inserted node.
    ///
    /// * creates a node with name and op
    /// * connect the 0-th input of the new node to the 0-th outlet of the
    /// latest previously inserted node.
    fn chain(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
        facts: TVec<TI>,
    ) -> TractResult<usize>;

    /// Chain a node to an arbitrary node.
    ///
    /// * creates a node with name and op
    /// * connect the 0-th input of the new node to `tap`
    fn chain_after(
        &mut self,
        tap: OutletId,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
        facts: TVec<TI>,
    ) -> TractResult<usize>;
}

impl<TI: TensorInfo> ModelDsl<TI> for Model<TI> {
    fn add_source(&mut self, name: impl Into<String>, fact: TI) -> TractResult<usize> {
        let id = self.add_node(name, crate::ops::source::Source::new(), tvec!(fact))?;
        Ok(id)
    }

    fn chain(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
        facts: TVec<TI>,
    ) -> TractResult<usize> {
        let previous_id = self.nodes().len() - 1;
        self.chain_after(OutletId::new(previous_id, 0), name, op.into(), facts)
    }

    fn single_prec(&self, id: usize) -> TractResult<Option<&Node<TI>>> {
        let node = &self.nodes()[id];
        if node.inputs.len() != 1 {
            return Ok(None);
        }
        let prec = &self.nodes()[node.inputs[0].node];
        if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        Ok(Some(prec))
    }

    fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<TI>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_prec(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }

    fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<TI>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_succ(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }

    fn single_succ(&self, id: usize) -> TractResult<Option<&Node<TI>>> {
        let node = &self.nodes()[id];
        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        let succ = node.outputs[0].successors[0];
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    fn chain_after(
        &mut self,
        tap: OutletId,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
        facts: TVec<TI>,
    ) -> TractResult<usize> {
        let id = self.add_node(name, op, facts)?;
        self.add_edge(tap, InletId::new(id, 0))?;
        Ok(id)
    }
}

/// Extension for models that accept constant nodes.
pub trait ModelDslConst {
    /// Insert a constant node named `name` that holds the tensor `v`, and
    /// return the new node's id.
    fn add_const(&mut self, name: impl Into<String>, v: impl IntoArcTensor) -> TractResult<usize>;
    /// Insert a constant node named `name` that holds the tensor `v`, then
    /// connect its 0-th output to `inlet`.
    fn plug_const(
        &mut self,
        inlet: InletId,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<()>;
}

impl ModelDslConst for super::InferenceModel {
    fn add_const(&mut self, name: impl Into<String>, v: impl IntoArcTensor) -> TractResult<usize> {
        // The op and the output fact share the same underlying tensor.
        let tensor = v.into_arc_tensor();
        let const_op = crate::ops::konst::Const::new(tensor.clone());
        self.add_node(name, const_op, tvec!(tensor.into()))
    }
    fn plug_const(
        &mut self,
        inlet: InletId,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<()> {
        let const_id = self.add_const(name, v)?;
        let source = OutletId::new(const_id, 0);
        self.add_edge(source, inlet)?;
        Ok(())
    }
}

impl ModelDslConst for super::TypedModel {
    fn add_const(&mut self, name: impl Into<String>, v: impl IntoArcTensor) -> TractResult<usize> {
        // The op and the output fact share the same underlying tensor.
        let tensor = v.into_arc_tensor();
        let const_op = crate::ops::konst::Const::new(tensor.clone());
        self.add_node(name, const_op, tvec!(tensor.into()))
    }
    fn plug_const(
        &mut self,
        inlet: InletId,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<()> {
        let const_id = self.add_const(name, v)?;
        let source = OutletId::new(const_id, 0);
        self.add_edge(source, inlet)?;
        Ok(())
    }
}

/// Model extension for InferenceModel
pub trait ModelDslInfer {
    /// Add a source with no tensor information.
    fn add_source_default(&mut self, name: impl Into<String>) -> TractResult<usize>;
    /// Add a node without tensor information.
    fn add_node_default(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
    ) -> TractResult<usize>;
    /// Chain a node without tensor information.
    fn chain_default(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
    ) -> TractResult<usize>;
}

impl ModelDslInfer for super::InferenceModel {
    fn add_source_default(&mut self, name: impl Into<String>) -> TractResult<usize> {
        self.add_source(name, TensorFact::default())
    }
    fn add_node_default(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
    ) -> TractResult<usize> {
        self.add_node(name, op, tvec!(TensorFact::default()))
    }
    fn chain_default(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<Op>>,
    ) -> TractResult<usize> {
        self.chain(name, op, tvec!(TensorFact::default()))
    }
}
