#include <cstdlib>
#include <string>
#include <sstream>
#include <iostream>
#include <cmath>

using namespace std;
#include "Nodo.hpp"

// Writes msg to standard error after a leading newline and the "DEBUG: "
// tag, then terminates the line and flushes the stream.
void Debug(std::string msg)
{
	std::cerr << "\nDEBUG: " << msg;
	std::cerr << '\n';
	std::cerr.flush();
}

// ------- Const ----------
// Builds a constant node that stores the numeric value v.
Const::Const(double v)
	: Nodo()
{
	this->val = v;
}
	
// Returns a newly allocated Const holding the same value as this node.
Nodo* Const::copy(void) const
{
	Nodo* duplicate = new Const(val);
	return duplicate;
}

// A constant evaluates to its stored value; the argument x is not used.
double Const::eval(double x) const
{
	(void) x;
	return val;
}

// The derivative of a constant is a newly allocated Const with value 0.
Nodo* Const::deriv(void) const
{
	Nodo* zero = new Const(0);
	return zero;
}

// Renders the stored value as text using the default stream formatting
// of an ostringstream.
string Const::print(void) const
{
	std::ostringstream text;
	text << val;
	std::string rendered = text.str();
	return rendered;
}


// ------- VarX  ----------
// Builds the node that stands for the variable x; it sets up no state
// beyond default-constructing the Nodo base.
VarX::VarX(void)
	: Nodo()
{
	// nothing to initialize
}

// Returns a newly allocated VarX node.
Nodo* VarX::copy(void) const
{
	Nodo* duplicate = new VarX();
	return duplicate;
}

// The variable evaluates to the value supplied for x.
double VarX::eval(double x) const
{
	const double value = x;
	return value;
}

// The variable is always rendered as the single character 'x'.
string VarX::print(void) const
{
	return std::string(1, 'x');
}

// d/dx of x is a newly allocated Const with value 1.
Nodo* VarX::deriv(void) const
{
	Nodo* one = new Const(1);
	return one;
}

// ------- OpBin ----------

// Builds a binary-operator node, storing l as the left operand and r as
// the right operand.
OpBin::OpBin(Nodo *l, Nodo *r)
	: Nodo()
{
	this->left = l;
	this->right = r;
}

// ------- Add ----------

// Builds an addition node; both operands are forwarded to the OpBin
// constructor.
Add::Add(Nodo *i, Nodo *d)
	: OpBin(i, d)
{
	// nothing else to initialize
}

// Returns a newly allocated Add built from copies of both operands.
Nodo* Add::copy(void) const
{
	Nodo* leftCopy = Lft()->copy();
	Nodo* rightCopy = Rgt()->copy();
	return new Add(leftCopy, rightCopy);
}

// Renders the sum as "<left>+ <right>" with no parentheses added.
string Add::print(void) const
{
	std::string text = Lft()->print();
	text += "+ ";
	text += Rgt()->print();
	return text;
}

// Evaluates both operands at x and returns their sum.
double Add::eval(double x) const
{
	const double leftValue = Lft()->eval(x);
	const double rightValue = Rgt()->eval(x);
	return leftValue + rightValue;
}

// Sum rule: (f + g)' = f' + g'.
Nodo* Add::deriv(void) const
{
	Nodo* dLeft = Lft()->deriv();
	Nodo* dRight = Rgt()->deriv();
	return new Add(dLeft, dRight);
}

// Returns a simplified version of this sum as a new tree:
//   - two numeric operands are folded into a single Const;
//   - a numeric 0 on either side is dropped, leaving the other operand;
//   - otherwise a new Add over the simplified operands is returned.
// Simplified operands that do not end up in the result are deleted here.
Nodo* Add::simp(void) const
{
	Nodo* lhs = Lft()->simp();
	Nodo* rhs = Rgt()->simp();
	const bool lhsIsNum = (lhs->type() == NUM);
	const bool rhsIsNum = (rhs->type() == NUM);

	if (lhsIsNum && rhsIsNum) {
		Nodo* folded = new Const(lhs->eval() + rhs->eval());
		// NOTE(review): this is the only simp() that emits a debug trace;
		// confirm whether it is still wanted.
		Debug(folded->print());
		delete rhs;
		delete lhs;
		return folded;
	}

	// 0 + g  ->  g
	if (lhsIsNum && lhs->eval() == 0.0) {
		delete lhs;
		return rhs;
	}

	// f + 0  ->  f
	if (rhsIsNum && rhs->eval() == 0.0) {
		delete rhs;
		return lhs;
	}

	return new Add(lhs, rhs);
}

// ------- Sub ----------

// Builds a subtraction node; both operands are forwarded to the OpBin
// constructor.
Sub::Sub(Nodo *i, Nodo *d)
	: OpBin(i, d)
{
	// nothing else to initialize
}

// Returns a newly allocated Sub built from copies of both operands.
Nodo* Sub::copy(void) const
{
	Nodo* leftCopy = Lft()->copy();
	Nodo* rightCopy = Rgt()->copy();
	return new Sub(leftCopy, rightCopy);
}

// Renders the difference as "<left>- <right>". The right operand is
// wrapped in parentheses when it is itself an ADD or SUB node, so that the
// printed text keeps the grouping of the tree.
string Sub::print(void) const
{
	const int rightKind = Rgt()->type();
	const bool wrapRight = (rightKind == ADD || rightKind == SUB);

	std::string text = Lft()->print();
	text += "- ";
	if (wrapRight) {
		text += "(";
		text += Rgt()->print();
		text += ")";
	} else {
		text += Rgt()->print();
	}
	return text;
}

// Evaluates both operands at x and returns left minus right.
double Sub::eval(double x) const
{
	const double leftValue = Lft()->eval(x);
	const double rightValue = Rgt()->eval(x);
	return leftValue - rightValue;
}

// Difference rule: (f - g)' = f' - g'.
Nodo* Sub::deriv(void) const
{
	Nodo* dLeft = Lft()->deriv();
	Nodo* dRight = Rgt()->deriv();
	return new Sub(dLeft, dRight);
}

// Returns a simplified version of this difference as a new tree:
//   - two numeric operands are folded; a negative result is expressed as
//     Chs(Const(|value|)) rather than as a negative Const;
//   - 0 - g becomes Chs(g);
//   - f - 0 becomes f;
//   - otherwise a new Sub over the simplified operands is returned.
// Simplified operands that do not end up in the result are deleted here.
Nodo* Sub::simp(void) const
{
	Nodo* lhs = Lft()->simp();
	Nodo* rhs = Rgt()->simp();
	const bool lhsIsNum = (lhs->type() == NUM);
	const bool rhsIsNum = (rhs->type() == NUM);

	if (lhsIsNum && rhsIsNum) {
		const double diff = lhs->eval() - rhs->eval();
		delete rhs;
		delete lhs;
		if (diff < 0)
			return new Chs(new Const(-diff));
		return new Const(diff);
	}

	// 0 - g  ->  -g
	if (lhsIsNum && lhs->eval() == 0.0) {
		delete lhs;
		return new Chs(rhs);
	}

	// f - 0  ->  f
	if (rhsIsNum && rhs->eval() == 0.0) {
		delete rhs;
		return lhs;
	}

	return new Sub(lhs, rhs);
}

// ------- Mul ----------

// Builds a multiplication node; both operands are forwarded to the OpBin
// constructor.
Mul::Mul(Nodo *i, Nodo *d)
	: OpBin(i, d)
{
	// nothing else to initialize
}

// Returns a newly allocated Mul built from copies of both operands.
Nodo* Mul::copy(void) const
{
	Nodo* leftCopy = Lft()->copy();
	Nodo* rightCopy = Rgt()->copy();
	return new Mul(leftCopy, rightCopy);
}

// Renders the product as "<left>* <right>". The left operand is
// parenthesized when it is an ADD or SUB node; the right operand when it is
// an ADD, SUB or CHS node.
string Mul::print(void) const
{
	const int leftKind = Lft()->type();
	const int rightKind = Rgt()->type();
	const bool wrapLeft = (leftKind == ADD || leftKind == SUB);
	const bool wrapRight =
		(rightKind == ADD || rightKind == SUB || rightKind == CHS);

	std::string text;
	if (wrapLeft)
		text += "(" + Lft()->print() + ")";
	else
		text += Lft()->print();

	text += "* ";

	if (wrapRight)
		text += "(" + Rgt()->print() + ")";
	else
		text += Rgt()->print();

	return text;
}

// Evaluates both operands at x and returns their product.
double Mul::eval(double x) const
{
	const double leftValue = Lft()->eval(x);
	const double rightValue = Rgt()->eval(x);
	return leftValue * rightValue;
}

// Product rule: (f * g)' = f * g' + f' * g.
Nodo* Mul::deriv(void) const
{
	Nodo* leftTimesDRight = new Mul(Lft()->copy(), Rgt()->deriv());
	Nodo* dLeftTimesRight = new Mul(Lft()->deriv(), Rgt()->copy());
	return new Add(leftTimesDRight, dLeftTimesRight);
}

// Returns a simplified version of this product as a new tree:
//   - two numeric operands are folded into a single Const;
//   - a numeric factor 1 is dropped, leaving the other operand;
//   - a numeric factor 0 makes the whole product that 0 node;
//   - otherwise a new Mul over the simplified operands is returned.
// Simplified operands that do not end up in the result are deleted here.
Nodo* Mul::simp(void) const
{
	Nodo* lhs = Lft()->simp();
	Nodo* rhs = Rgt()->simp();
	const bool lhsIsNum = (lhs->type() == NUM);
	const bool rhsIsNum = (rhs->type() == NUM);

	if (lhsIsNum && rhsIsNum) {
		Nodo* folded = new Const(lhs->eval() * rhs->eval());
		delete rhs;
		delete lhs;
		return folded;
	}

	if (lhsIsNum) {
		const double factor = lhs->eval();
		if (factor == 1.0) {	// 1 * g -> g
			delete lhs;
			return rhs;
		}
		if (factor == 0.0) {	// 0 * g -> 0
			delete rhs;
			return lhs;
		}
	}

	if (rhsIsNum) {
		const double factor = rhs->eval();
		if (factor == 1.0) {	// f * 1 -> f
			delete rhs;
			return lhs;
		}
		if (factor == 0.0) {	// f * 0 -> 0
			delete lhs;
			return rhs;
		}
	}

	return new Mul(lhs, rhs);
}
// ------- Div ----------

// Builds a division node; both operands are forwarded to the OpBin
// constructor.
Div::Div(Nodo *i, Nodo *d)
	: OpBin(i, d)
{
	// nothing else to initialize
}

// Returns a newly allocated Div built from copies of both operands.
Nodo* Div::copy(void) const
{
	Nodo* leftCopy = Lft()->copy();
	Nodo* rightCopy = Rgt()->copy();
	return new Div(leftCopy, rightCopy);
}

// Quotient rule: (f / g)' = (g * f' - f * g') / g^2.
Nodo* Div::deriv(void) const
{
	Nodo* gTimesDf = new Mul(Rgt()->copy(), Lft()->deriv());
	Nodo* fTimesDg = new Mul(Lft()->copy(), Rgt()->deriv());
	Nodo* numerator = new Sub(gTimesDf, fTimesDg);
	Nodo* denominator = new Pow(Rgt()->copy(), new Const(2));
	return new Div(numerator, denominator);
}

// Renders the quotient as "<left>/ <right>". The left operand is
// parenthesized when it is an ADD or SUB node; the right operand when it is
// an ADD, SUB, MUL, DIV or CHS node.
string Div::print(void) const
{
	const int leftKind = Lft()->type();
	const int rightKind = Rgt()->type();
	const bool wrapLeft = (leftKind == ADD || leftKind == SUB);
	const bool wrapRight =
		(rightKind == ADD || rightKind == SUB ||
		 rightKind == MUL || rightKind == DIV ||
		 rightKind == CHS);

	std::string text;
	if (wrapLeft)
		text += "(" + Lft()->print() + ")";
	else
		text += Lft()->print();

	text += "/ ";

	if (wrapRight)
		text += "(" + Rgt()->print() + ")";
	else
		text += Rgt()->print();

	return text;
}

// Evaluates both operands at x and returns left divided by right. No check
// is made for a zero divisor.
double Div::eval(double x) const
{
	const double dividend = Lft()->eval(x);
	const double divisor = Rgt()->eval(x);
	return dividend / divisor;
}

// Simplifies each operand (left first, then right) and returns a new Div
// over the results; no quotient-specific rewriting is applied.
Nodo* Div::simp(void) const
{
	Nodo* simplifiedLeft = Lft()->simp();
	Nodo* simplifiedRight = Rgt()->simp();
	Nodo* quotient = new Div(simplifiedLeft, simplifiedRight);
	return quotient;
}

// ------- Pow ----------

// Builds a power node (i raised to d); both operands are forwarded to the
// OpBin constructor.
Pow::Pow(Nodo *i, Nodo *d)
	: OpBin(i, d)
{
	// nothing else to initialize
}

// Returns a newly allocated Pow built from copies of base and exponent.
Nodo* Pow::copy(void) const
{
	Nodo* baseCopy = Lft()->copy();
	Nodo* exponentCopy = Rgt()->copy();
	return new Pow(baseCopy, exponentCopy);
}

// Renders the power as "<base>^<exponent>". The base is parenthesized when
// it is an ADD, SUB, MUL, DIV, POW or CHS node; the exponent when it is an
// ADD, SUB, MUL, DIV or CHS node.
string Pow::print(void) const
{
	const int baseKind = Lft()->type();
	const bool wrapBase =
		(baseKind == ADD || baseKind == SUB || baseKind == MUL ||
		 baseKind == DIV || baseKind == POW || baseKind == CHS);

	std::string text;
	if (wrapBase)
		text += "(" + Lft()->print() + ")";
	else
		text += Lft()->print();

	text += "^";

	const int expoKind = Rgt()->type();
	const bool wrapExpo =
		(expoKind == ADD || expoKind == SUB || expoKind == MUL ||
		 expoKind == DIV || expoKind == CHS);

	if (wrapExpo)
		text += "(" + Rgt()->print() + ")";
	else
		text += Rgt()->print();

	return text;
}

// Evaluates base and exponent at x and returns base raised to exponent.
// Both operands are kept as double: storing them in int locals would
// truncate fractional values (e.g. 2.5^2 would become 2^2, and x^0.5 would
// become x^0), and would disagree with the constant folding done in
// Pow::simp, which works on doubles.
double Pow::eval(double x) const
{
	const double base = Lft()->eval(x);
	const double expo = Rgt()->eval(x);
	return std::pow(base, expo);
}

// General power rule for f^g where both f and g may depend on x:
//   (f^g)' = (g * f^(g-1)) * f'  +  f^g * (log(f) * g')
Nodo* Pow::deriv(void) const
{
	// First term: (g * f^(g-1)) * f'
	Nodo* gMinusOne = new Sub(Rgt()->copy(), new Const(1));
	Nodo* loweredPower = new Pow(Lft()->copy(), gMinusOne);
	Nodo* scaledPower = new Mul(Rgt()->copy(), loweredPower);
	Nodo* baseTerm = new Mul(scaledPower, Lft()->deriv());

	// Second term: f^g * (log(f) * g')
	Nodo* samePower = new Pow(Lft()->copy(), Rgt()->copy());
	Nodo* logTimesDg = new Mul(new Log(Lft()->copy()), Rgt()->deriv());
	Nodo* exponentTerm = new Mul(samePower, logTimesDg);

	return new Add(baseTerm, exponentTerm);
}

// Returns a simplified version of this power as a new tree:
//   - numeric base and exponent are folded into a single Const;
//   - a numeric base of 1 or 0 is returned as-is (1^g -> 1, 0^g -> 0);
//   - exponent 1 yields the base, exponent 0 yields Const(1.0);
//   - (b^e)^g is rewritten as b^(e*g), with the product simplified;
//   - otherwise a new Pow over the simplified operands is returned.
// Simplified operands that do not end up in the result are deleted here.
Nodo* Pow::simp(void) const
{
	Nodo* base = Lft()->simp();
	Nodo* expo = Rgt()->simp();
	const bool baseIsNum = (base->type() == NUM);
	const bool expoIsNum = (expo->type() == NUM);

	if (baseIsNum && expoIsNum) {
		Nodo* folded = new Const(pow(base->eval(), expo->eval()));
		delete expo;
		delete base;
		return folded;
	}

	if (baseIsNum) {
		const double b = base->eval();
		// 1^something -> 1   and   0^something -> 0
		if (b == 1.0 || b == 0.0) {
			delete expo;
			return base;
		}
	}

	if (expoIsNum) {
		const double e = expo->eval();
		if (e == 1.0) {		// f^1 -> f
			delete expo;
			return base;
		}
		if (e == 0.0) {		// f^0 -> 1
			delete base;
			delete expo;
			return new Const(1.0);
		}
	}

	if (base->type() == POW) {
		// (b^e)^g -> b^(e*g)
		Pow* inner = static_cast<Pow*>(base);
		Nodo* innerBase = inner->Lft();
		Nodo* innerExpo = inner->Rgt();

		// Build e*g only to simplify it; the temporary product is deleted
		// right after its simplified form has been obtained.
		Nodo* product = new Mul(innerExpo->copy(), expo);
		Nodo* mergedExpo = product->simp();
		delete product;

		Nodo* rewritten = new Pow(innerBase->copy(), mergedExpo);
		delete base;
		return rewritten;
	}

	return new Pow(base, expo);
}

// ------- Chs ----------

// Builds a change-of-sign (unary minus) node whose operand is d.
Chs::Chs(Nodo *d)
{
	this->right = d;
}

// Returns a newly allocated Chs wrapping a copy of the operand.
Nodo* Chs::copy(void) const
{
	Nodo* operandCopy = Rgt()->copy();
	return new Chs(operandCopy);
}

// Evaluates the operand at x and returns its negation.
double Chs::eval(double x) const
{
	const double inner = Rgt()->eval(x);
	return -inner;
}

// (-f)' = -(f').
Nodo* Chs::deriv(void) const
{
	Nodo* dOperand = Rgt()->deriv();
	return new Chs(dOperand);
}

// Renders the negation as "-<operand>". The operand is parenthesized when
// it is an ADD, SUB or CHS node.
string Chs::print(void) const
{
	const int kind = Rgt()->type();
	const bool wrap = (kind == ADD || kind == SUB || kind == CHS);

	std::string text("-");
	if (wrap) {
		text += "(";
		text += Rgt()->print();
		text += ")";
	} else {
		text += Rgt()->print();
	}
	return text;
}

// Simplifies the operand and returns a new Chs around the result; no
// sign-specific rewriting is applied.
Nodo* Chs::simp(void) const
{
	Nodo* simplified = Rgt()->simp();
	Nodo* negated = new Chs(simplified);
	return negated;
}

// ------- Fun ----------

// Builds a one-argument function node whose argument is r.
Fun::Fun(Nodo *r)
	: Nodo()
{
	this->right = r;
}

// ------- Sin ----------

// Builds a sine node; the argument is forwarded to the Fun constructor.
Sin::Sin(Nodo *d)
	: Fun(d)
{
	// nothing else to initialize
}

// Returns a newly allocated Sin wrapping a copy of the argument.
Nodo* Sin::copy(void) const
{
	Nodo* argCopy = Rgt()->copy();
	return new Sin(argCopy);
}

// Renders the node as "sin(<argument>)".
string Sin::print(void) const
{
	std::string text("sin(");
	text += Rgt()->print();
	text += ")";
	return text;
}

// Evaluates the argument at x and returns its sine.
double Sin::eval(double x) const
{
	const double arg = Rgt()->eval(x);
	return std::sin(arg);
}

// Chain rule: sin(u)' = cos(u) * u'.
Nodo* Sin::deriv(void) const
{
	Nodo* outer = new Cos(Rgt()->copy());
	Nodo* inner = Rgt()->deriv();
	return new Mul(outer, inner);
}

// Simplifies the argument and returns a new Sin around the result.
Nodo* Sin::simp(void) const
{
	Nodo* simplifiedArg = Rgt()->simp();
	Nodo* result = new Sin(simplifiedArg);
	return result;
}

// ------- Cos ----------

// Builds a cosine node; the argument is forwarded to the Fun constructor.
Cos::Cos(Nodo *d)
	: Fun(d)
{
	// nothing else to initialize
}

// Returns a newly allocated Cos wrapping a copy of the argument.
Nodo* Cos::copy(void) const
{
	Nodo* argCopy = Rgt()->copy();
	return new Cos(argCopy);
}

// Renders the node as "cos(<argument>)".
string Cos::print(void) const
{
	std::string text("cos(");
	text += Rgt()->print();
	text += ")";
	return text;
}

// Evaluates the argument at x and returns its cosine.
double Cos::eval(double x) const
{
	const double arg = Rgt()->eval(x);
	return std::cos(arg);
}

// Chain rule: cos(u)' = -(sin(u) * u').
Nodo* Cos::deriv(void) const
{
	Nodo* outer = new Sin(Rgt()->copy());
	Nodo* inner = Rgt()->deriv();
	Nodo* product = new Mul(outer, inner);
	return new Chs(product);
}

// Simplifies the argument and returns a new Cos around the result.
Nodo* Cos::simp(void) const
{
	Nodo* simplifiedArg = Rgt()->simp();
	Nodo* result = new Cos(simplifiedArg);
	return result;
}

// ------- Atan ----------

// Builds an arctangent node; the argument is forwarded to the Fun
// constructor.
Atan::Atan(Nodo *d)
	: Fun(d)
{
	// nothing else to initialize
}

// Returns a newly allocated Atan wrapping a copy of the argument.
Nodo* Atan::copy(void) const
{
	Nodo* argCopy = Rgt()->copy();
	return new Atan(argCopy);
}

// Renders the node as "atan(<argument>)".
string Atan::print(void) const
{
	std::string text("atan(");
	text += Rgt()->print();
	text += ")";
	return text;
}

// Evaluates the argument at x and returns its arctangent.
// This must be atan, not acos: the node prints as "atan(...)" and its
// derivative is u' / (1 + u^2), which is the derivative of the arctangent.
double Atan::eval(double x) const
{
	const double arg = Rgt()->eval(x);
	return std::atan(arg);
}

// Chain rule: atan(u)' = u' / (1 + u^2).
Nodo* Atan::deriv(void) const
{
	Nodo* dArg = Rgt()->deriv();
	Nodo* argSquared = new Pow(Rgt()->copy(), new Const(2.0));
	Nodo* onePlusSquare = new Add(new Const(1.0), argSquared);
	return new Div(dArg, onePlusSquare);
}

// Simplifies the argument and returns a new Atan around the result.
Nodo* Atan::simp(void) const
{
	Nodo* simplifiedArg = Rgt()->simp();
	Nodo* result = new Atan(simplifiedArg);
	return result;
}
// ------- Exp ----------

// Builds an exponential node; the argument is forwarded to the Fun
// constructor.
Exp::Exp(Nodo *d)
	: Fun(d)
{
	// nothing else to initialize
}

// Returns a newly allocated Exp wrapping a copy of the argument.
Nodo* Exp::copy(void) const
{
	Nodo* argCopy = Rgt()->copy();
	return new Exp(argCopy);
}

// Renders the node as "exp(<argument>)".
string Exp::print(void) const
{
	std::string text("exp(");
	text += Rgt()->print();
	text += ")";
	return text;
}

// Evaluates the argument at x and returns e raised to that value.
double Exp::eval(double x) const
{
	const double arg = Rgt()->eval(x);
	return std::exp(arg);
}

// Chain rule: exp(u)' = exp(u) * u'.
Nodo* Exp::deriv(void) const
{
	Nodo* outer = new Exp(Rgt()->copy());
	Nodo* inner = Rgt()->deriv();
	return new Mul(outer, inner);
}

// Simplifies the argument and returns a new Exp around the result.
Nodo* Exp::simp(void) const
{
	Nodo* simplifiedArg = Rgt()->simp();
	Nodo* result = new Exp(simplifiedArg);
	return result;
}

// ------- Log ----------

// Builds a logarithm node; the argument is forwarded to the Fun
// constructor.
Log::Log(Nodo *d)
	: Fun(d)
{
	// nothing else to initialize
}

// Returns a newly allocated Log wrapping a copy of the argument.
Nodo* Log::copy(void) const
{
	Nodo* argCopy = Rgt()->copy();
	return new Log(argCopy);
}

// Renders the node as "log(<argument>)".
string Log::print(void) const
{
	std::string text("log(");
	text += Rgt()->print();
	text += ")";
	return text;
}

// Evaluates the argument at x and returns its natural logarithm.
// This must be log, not exp: the node prints as "log(...)" and its
// derivative is u' / u, which is the derivative of the natural logarithm.
// No domain check is made; non-positive arguments follow std::log's rules.
double Log::eval(double x) const
{
	const double arg = Rgt()->eval(x);
	return std::log(arg);
}

// Chain rule: log(u)' = u' / u.
Nodo* Log::deriv(void) const
{
	Nodo* dArg = Rgt()->deriv();
	Nodo* arg = Rgt()->copy();
	return new Div(dArg, arg);
}

// Simplifies the argument and returns a new Log around the result.
Nodo* Log::simp(void) const
{
	Nodo* simplifiedArg = Rgt()->simp();
	Nodo* result = new Log(simplifiedArg);
	return result;
}

