import logging
import unittest
from dateutil.relativedelta import relativedelta

from py_jftech import autowired, parse_date, to_str, next_workday

from api import RebalanceSignal, PortfoliosRisk, RebalanceRuler, RoboReportor

# Module logger; every test below reports what it observed through logger.info.
# NOTE(review): with Python's default logging setup INFO records are not emitted,
# so this output is only visible if logging is configured elsewhere (possibly by
# py_jftech on import) — confirm.
logger = logging.getLogger(__name__)


class RebalanceTest(unittest.TestCase):
    """Smoke tests for the rebalance signals, the rebalance ruler and the signal report.

    Each test pulls a component in through ``py_jftech.autowired``, drives it with fixed
    historical dates, and logs what comes back. No test asserts anything, so a test only
    fails if the component raises.
    """

    def _log_trigger_days(self, builder: RebalanceSignal, start, risk, years: int = 3):
        """Poll ``builder.is_trigger`` once per workday and log each day it fires.

        Days are walked from ``start`` (inclusive) up to ``start + years`` (exclusive),
        advancing with ``next_workday``.

        :param builder: signal whose ``is_trigger(day, risk)`` is polled.
        :param start: first day checked. It is checked even if it is not itself a workday.
        :param risk: portfolio risk level passed to ``is_trigger``.
        :param years: length of the scanned window in years.
        """
        day = start
        end = start + relativedelta(years=years)
        while day < end:
            if builder.is_trigger(day, risk):
                logger.info(day)
            day = next_workday(day)

    @autowired(names={'builder': 'crisis_one'})
    def test_crisis_one(self, builder: RebalanceSignal = None):
        """Log every FT9 trigger day of the 'crisis_one' signal in the 3 years from 2008-03-12."""
        self._log_trigger_days(builder, parse_date('2008-03-12'), PortfoliosRisk.FT9)

    @autowired(names={'builder': 'crisis_two'})
    def test_crisis_two(self, builder: RebalanceSignal = None):
        """Log every FT9 trigger day of the 'crisis_two' signal in the 3 years from 2020-04-29."""
        self._log_trigger_days(builder, parse_date('2020-04-29'), PortfoliosRisk.FT9)

    @autowired(names={'builder': 'market-right'})
    def test_market_right(self, builder: RebalanceSignal = None):
        """Log the FT9 signal produced by 'market-right' on 2008-01-07."""
        signal = builder.get_signal(parse_date('2008-01-07'), PortfoliosRisk.FT9)
        logger.info(signal)

    @autowired(names={'builder': 'curve-drift'})
    def test_curve_drift(self, builder: RebalanceSignal = None):
        """Log the FT3 signal produced by 'curve-drift' on 2022-11-07."""
        signal = builder.get_signal(parse_date('2022-11-07'), PortfoliosRisk.FT3)
        logger.info(signal)

    @autowired(names={'builder': 'high-buy'})
    def test_high_buy(self, builder: RebalanceSignal = None):
        """Log the FT3 signal produced by 'high-buy' on 2022-09-10."""
        # The result used to be discarded; log it like the other get_signal tests do.
        signal = builder.get_signal(parse_date('2022-09-10'), PortfoliosRisk.FT3)
        logger.info(signal)

    @autowired
    def test_rebalance_builder(self, builder: RebalanceRuler = None):
        """Ask the autowired RebalanceRuler for its next FT9 signal from 2020-04-29 and log it."""
        # Log the result so the run leaves an observable trace, matching the other tests.
        signal = builder.take_next_signal(parse_date('2020-04-29'), PortfoliosRisk.FT9)
        logger.info(signal)

    @autowired(names={'reportor': 'signal-report'})
    def test_signal_report(self, reportor: RoboReportor = None):
        """Load the 'signal-report' report and log it via ``to_str(..., show_line=10)``."""
        result = reportor.load_report()
        logger.info(to_str(result, show_line=10))


if __name__ == '__main__':
    # Let this file run directly as a script, besides being picked up by a test runner.
    unittest.main()