# @copyright (c) 2002-2016 Acronis International GmbH. All rights reserved.
# EULA: https://www.acronis.com/en-us/download/docs/eula/corporate/

import abc
import asyncio
import binascii
import construct
import io
import logging
import os
from collections import namedtuple
from ssl import SSLError
from .exceptions import BaseError, ConnectError, make_server_error
from .requests import *


log = logging.getLogger(__name__)

# Retry policy of BaseProtocol.connect(): how many connection attempts are
# made and how long (in seconds, passed to asyncio.sleep) to wait between them.
REATTEMPTS_COUNT = 30
REATTEMPTS_DELAY = 1

# An outstanding request: the future handed to the caller, the request
# parameters, and the builder/parser objects used to serialize the request
# and to parse its response.
Promise = namedtuple(
    'Promise',
    'future request_data request_builder request_parser',
)


class Hexify:
    """Render a bytes-like object as colon-separated hex pairs, e.g. ``01:ab:ff``.

    Formatting happens only in ``__str__``. Passing an instance as a lazy
    ``%s`` logging argument therefore defers the cost until a record is
    actually emitted.
    """

    def __init__(self, data):
        self.__data = data

    def __str__(self):
        digits = binascii.hexlify(self.__data).decode('ascii')
        # hexlify always yields an even number of digits: pair them up.
        pairs = zip(digits[::2], digits[1::2])
        return ':'.join(map(''.join, pairs))


def get_stream_size(stream):
    """Return the offset of the end of a seekable ``stream``.

    The current position is remembered and restored, so the caller's read or
    write position is left unchanged.
    """
    origin = stream.tell()
    stream.seek(0, os.SEEK_END)
    end_offset = stream.tell()
    stream.seek(origin)
    return end_offset


def is_stream_empty(stream):
    """Return True when ``stream`` has zero total length, regardless of its current position."""
    return not get_stream_size(stream)


def is_fatal_error(exc):
    """Return True if a connection error should stop further connect attempts.

    Only SSL errors are treated as fatal. Presumably a TLS failure (e.g. a
    certificate problem) will not go away by retrying.
    """
    fatal_types = (SSLError,)
    return isinstance(exc, fatal_types)


class BaseProtocol(asyncio.Protocol):
    """Common base for the reconnecting client protocols in this module.

    Handles the following:

    * connecting with retries;
    * reconnecting automatically after an unexpected disconnect;
    * buffering received bytes;
    * sending built requests.

    Subclasses keep track of outstanding requests and implement
    ``_fail_requests``, ``_resend_requests``, ``_has_requests`` and
    ``_process_data``.
    """

    def __init__(self, host, port, ssl_ctx, loop=None):
        """
        :param host: server host to connect to.
        :param port: server port to connect to.
        :param ssl_ctx: passed as ``ssl=`` to ``loop.create_connection``.
        :param loop: event loop to use; defaults to ``asyncio.get_event_loop()``.
        """
        self._host = host
        self._port = port
        self._ssl_ctx = ssl_ctx
        self._loop = loop or asyncio.get_event_loop()

        self._transport = None
        # Set by close(); while True, connection_lost() does not reconnect.
        self._is_closed = False
        # Received bytes not yet consumed by _process_data().
        self._buffer = io.BytesIO()

    @property
    def connected(self):
        """True while a transport is attached."""
        return self._transport is not None

    @property
    def busy(self):
        """True when the subclass reports outstanding requests."""
        return self._has_requests()

    @asyncio.coroutine
    def connect(self):
        """Connect to the server, retrying up to REATTEMPTS_COUNT times.

        Does nothing if already connected. Attempts are REATTEMPTS_DELAY
        seconds apart. An error for which ``is_fatal_error`` is true stops
        retrying immediately.

        :raises Exception: the error of the last failed attempt. Before it is
            raised, all outstanding requests are failed with ConnectError.
        """
        if self.connected:
            return

        connect_exc = None
        for attempt_num in range(1, REATTEMPTS_COUNT + 1):
            try:
                log.debug('Connecting to server %s:%d. Attempt number %d.', self._host, self._port, attempt_num)
                self._transport, _ = yield from self._loop.create_task(
                    self._loop.create_connection(lambda: self, self._host, self._port, ssl=self._ssl_ctx)
                )
            except asyncio.CancelledError:
                # Before Python 3.8 CancelledError derives from Exception. Without
                # this clause a cancelled connect() would be logged and retried.
                raise
            except Exception as exc:
                connect_exc = exc
                log.exception('Failed connect to server:')
                if is_fatal_error(connect_exc):
                    break
            else:
                break
            if attempt_num < REATTEMPTS_COUNT:
                # Wait between attempts, but not after the last one: we give up then.
                yield from asyncio.sleep(REATTEMPTS_DELAY, loop=self._loop)

        if not self.connected:
            error = ConnectError(self._host, self._port, connect_exc)
            self._fail_requests(error)
            # connect_exc is None only when no attempt was made (REATTEMPTS_COUNT < 1).
            raise connect_exc if connect_exc is not None else error

    def close(self):
        """Mark the protocol closed and close the transport, if any.

        Because the protocol is marked closed first, the resulting
        connection_lost() does not schedule a reconnect.
        """
        self._is_closed = True
        if self.connected:
            self._transport.close()

    def _add_data_to_buffer(self, data):
        """Append ``data`` at the end of the receive buffer."""
        self._buffer.seek(0, os.SEEK_END)
        self._buffer.write(data)

    def _send_request(self, request_data, request_builder):
        """Serialize ``request_data`` with ``request_builder`` and write it to the transport."""
        log.debug('Executing request: %r', request_data)
        request = request_builder.build(request_data)
        self._transport.write(request)
        log.debug('Data sent: %s', Hexify(request))

    # NOTE: this class does not use abc.ABCMeta, so the abstractmethod
    # markers below are not enforced at instantiation time. They document
    # the hooks that subclasses must override.
    @abc.abstractmethod
    def _fail_requests(self, exc):
        """Fail every outstanding request with ``exc`` (called when connect() gives up)."""

    @abc.abstractmethod
    def _resend_requests(self):
        """Re-send outstanding requests (called from connection_made())."""

    @abc.abstractmethod
    def _has_requests(self):
        """Return True if any request is outstanding."""

    @abc.abstractmethod
    def _process_data(self):
        """Handle one message from the start of ``self._buffer``.

        Called with the buffer positioned at 0. The implementations in this
        module do one of two things:

        * replace ``self._buffer`` with the unconsumed remainder once a
          message has been handled;
        * let ``construct.ConstructError`` propagate when the message cannot
          be parsed yet.
        """

    def connection_made(self, transport):
        """
        Called when a connection is made.
        Method of asyncio.BaseProtocol parent class.

        Attaches the transport, resets the receive buffer and re-sends any
        outstanding requests.
        """
        peerinfo = transport.get_extra_info('peername')
        log.debug('Connection made with: %s:%d', peerinfo[0], peerinfo[1])

        self._transport = transport
        self._is_closed = False
        self._buffer = io.BytesIO()

        self._resend_requests()

    def connection_lost(self, err):
        """
        Called when the connection is lost or closed.
        Method of asyncio.BaseProtocol parent class.

        Detaches the transport and drops buffered bytes. Unless close() was
        called, it schedules a reconnect.
        """
        log.debug('Connection lost with reason: %s', err)

        self._transport = None
        self._buffer = io.BytesIO()

        if not self._is_closed:
            self._loop.create_task(self.connect())

    def pause_writing(self):
        """
        Called when the transport's buffer goes over the high-water mark.
        Method of asyncio.BaseProtocol parent class.
        """
        log.debug('Writing paused')

    def resume_writing(self):
        """
        Called when the transport's buffer drains below the low-water mark.
        Method of asyncio.BaseProtocol parent class.
        """
        log.debug('Writing resumed')

    def data_received(self, data):
        """
        Called when some data is received.
        Method of asyncio.Protocol parent class.

        Appends ``data`` to the buffer and lets the subclass consume as many
        complete messages as possible.
        """
        log.debug('Data of size %d received: %s', len(data), Hexify(data))

        self._add_data_to_buffer(data)
        try:
            while True:
                pending = get_stream_size(self._buffer)
                if pending == 0:
                    break
                self._buffer.seek(0)
                self._process_data()
                # A handled message leaves a smaller buffer behind. If nothing
                # was consumed (e.g. the subclass closed the connection on
                # unexpected data), looping again would re-process the same
                # bytes forever and block the event loop.
                if get_stream_size(self._buffer) >= pending:
                    break
        except construct.ConstructError:
            # Presumably the buffer holds an incomplete message. The bytes are
            # kept and re-parsed from the start when more data arrives.
            pass

    def eof_received(self):
        """
        Called when the other end calls write_eof() or equivalent.
        Method of asyncio.Protocol parent class.

        Returns None, so the transport closes itself.
        """
        log.debug('EOF received')


class InputStreamProtocol(BaseProtocol):
    """Protocol that executes one streaming ReadFile request at a time.

    Request lifecycle (``self._state``):

    * STATE_IDLE: no request in progress.
    * STATE_WAIT_RESPONSE: request sent; waiting for a StreamingResponseHeader
      followed by the request-specific response body.
    * STATE_WAIT_DATA: receiving InputChunk messages until one has
      ``is_last`` set.
    * STATE_WAIT_RESULT: waiting for the final StreamingResponseHeader.

    A header with a positive ``result`` fails the request with a server error.
    On success the future resolves to all chunk data concatenated, or to None
    if no data was received. If the connection drops mid-stream, the request
    is re-sent on reconnect with offset/size advanced past the data already
    received.
    """
    STATE_IDLE          = 0
    STATE_WAIT_RESPONSE = 1
    STATE_WAIT_DATA     = 2
    STATE_WAIT_RESULT   = 3

    def __init__(self, host, port, ssl_ctx, loop=None):
        super().__init__(host, port, ssl_ctx, loop)

        self._state = self.STATE_IDLE
        # Promise of the request in progress, or None.
        self._promise = None
        # Chunk data received so far for the current request.
        self._data = io.BytesIO()
        # Bytes of self._data already applied to the request's offset/size by
        # earlier resends.
        self._resumed_size = 0

    def _send_request(self, request_data, request_builder):
        """Send the request and start waiting for its response header."""
        super()._send_request(request_data, request_builder)
        self._state = self.STATE_WAIT_RESPONSE

    def _reset(self):
        """Forget the current request (if any) and return to STATE_IDLE."""
        self._promise = None
        self._data = io.BytesIO()
        self._resumed_size = 0
        self._state = self.STATE_IDLE

    def _fail_requests(self, exc):
        """Fail the request in progress with ``exc`` and go idle.

        State and partial data are cleared as well. This keeps stale chunks
        out of the next request's result and stops ``busy`` from staying True.
        """
        promise = self._promise
        self._reset()
        # The caller may already have cancelled the future (e.g. on timeout).
        # set_exception() on a done future raises InvalidStateError.
        if promise is not None and not promise.future.done():
            promise.future.set_exception(exc)

    def _resend_requests(self):
        """After a reconnect, re-issue the request in progress for the data not yet received."""
        self._state = self.STATE_IDLE
        if self._promise is None:
            return
        received = self._data.tell()
        # Advance only by the bytes received since the previous (re)send. The
        # earlier ones were already applied to offset/size, and applying the
        # running total again would skip data on the second and later reconnects.
        delta = received - self._resumed_size
        self._resumed_size = received
        self._promise.request_data.offset += delta
        self._promise.request_data.size -= delta
        self._send_request(self._promise.request_data, self._promise.request_builder)

    def _has_requests(self):
        return self._state != self.STATE_IDLE

    def _process_data(self):
        """Consume one message from ``self._buffer`` according to the current state."""
        if self._state in (self.STATE_WAIT_RESPONSE, self.STATE_WAIT_RESULT):
            header = StreamingResponseHeader.parse_stream(self._buffer)
            if header.result > 0:
                log.debug('Server error code received: %d', header.result)
                self._buffer = io.BytesIO(self._buffer.read())
                promise = self._promise
                self._reset()
                if not promise.future.done():
                    promise.future.set_exception(
                        make_server_error(self._host, self._port, promise.request_data, header.result)
                    )
                return

            if self._state == self.STATE_WAIT_RESULT:
                log.debug('Streaming successfuly completed, bytes read=%d', self._data.tell())
                self._buffer = io.BytesIO(self._buffer.read())
                promise = self._promise
                # An empty stream is reported as None rather than b''.
                payload = self._data.getvalue() or None
                self._reset()
                if not promise.future.done():
                    promise.future.set_result(payload)
                return

            self._promise.request_parser.parse_stream(self._buffer)
            self._buffer = io.BytesIO(self._buffer.read())
            self._state = self.STATE_WAIT_DATA

        elif self._state == self.STATE_WAIT_DATA:
            chunk = InputChunk.parse_stream(self._buffer)
            self._buffer = io.BytesIO(self._buffer.read())
            log.debug('Chunk received, size=%d, is_last=%s', chunk.size, chunk.is_last)
            self._data.write(chunk.data)
            if chunk.is_last:
                self._state = self.STATE_WAIT_RESULT
        else:
            log.error('Data received in idle state. Disconnecting.')
            # Drop the unexpected bytes. Otherwise data_received() would keep
            # seeing a non-empty buffer and call us again with the same data.
            self._buffer = io.BytesIO()
            self.close()

    def execute_read_request(self, name, prefix, offset, size):
        """Issue a ReadFile request and return a future for its data.

        ``offset``/``size`` presumably describe the byte range to read. They
        are advanced by the amount of data already received when the request
        is re-sent after a reconnect.

        :return: asyncio.Future resolved with the received bytes (None if no
            data was received), or failed with a server error.
        :raises BaseError: if the protocol is not connected or a request is
            already in progress.
        """
        if not self.connected:
            raise BaseError(self._host, self._port, 'Cannot execute read request on closed protocol')

        if self.busy:
            raise BaseError(self._host, self._port, 'Cannot execute read request on busy protocol')

        future = asyncio.Future(loop=self._loop)
        request_info = Requests[RequestID.ReadFile]
        request_data = construct.Container(name=name, prefix=prefix, offset=offset, size=size)

        # Start every request from a clean slate.
        self._data = io.BytesIO()
        self._resumed_size = 0
        self._promise = Promise(
            future=future,
            request_data=request_data,
            request_builder=request_info.builder_factory(),
            request_parser=request_info.parser_factory(),
        )

        self._send_request(self._promise.request_data, self._promise.request_builder)
        return future


class ClientProtocol(BaseProtocol):
    """Multiplexing request/response protocol.

    Every request is tagged with a unique serial number. Responses are matched
    to pending requests by the serial in their ResponseHeader, so several
    requests may be in flight at once. Requests issued while disconnected are
    kept and sent on the next connection. Pending requests are re-sent after
    every reconnect.
    """

    def __init__(self, host, port, ssl_ctx, loop=None):
        super().__init__(host, port, ssl_ctx, loop)

        # Last serial handed out; serials start at 1.
        self._serial = 0
        # serial -> Promise of every request awaiting a response.
        self._promises = dict()

    def _fail_requests(self, exc):
        """Fail and forget every pending request."""
        while self._promises:
            _, promise = self._promises.popitem()
            # A caller may have cancelled its future already (e.g. on timeout).
            # set_exception() would raise InvalidStateError and mask ``exc``.
            if not promise.future.done():
                promise.future.set_exception(exc)

    def _resend_requests(self):
        """Send every pending request again (called on connection)."""
        for promise in self._promises.values():
            self._send_request(promise.request_data, promise.request_builder)

    def _has_requests(self):
        return bool(self._promises)

    def _next_serial(self):
        """Return a new, monotonically increasing request serial."""
        self._serial += 1
        return self._serial

    def _process_data(self):
        """Parse one response from ``self._buffer`` and resolve the matching request."""
        header = ResponseHeader.parse_stream(self._buffer)

        promise = self._promises.get(header.serial)
        if promise is None:
            log.warning('Got response with unknown serial: %r', header)
            # The stream cannot be resynchronised. Drop the buffered bytes so
            # data_received() does not re-process them, then close the
            # transport. connection_lost() schedules a reconnect unless the
            # protocol itself was closed.
            self._buffer = io.BytesIO()
            self._transport.close()
            return

        if header.result > 0:
            log.debug('Server error code received: %d', header.result)
            self._buffer = io.BytesIO(self._buffer.read())
            self._promises.pop(header.serial)
            if not promise.future.done():
                promise.future.set_exception(
                    make_server_error(self._host, self._port, promise.request_data, header.result)
                )
            return

        # Parse even if the caller has given up, so the bytes are consumed.
        response = promise.request_parser.parse_stream(self._buffer)
        self._buffer = io.BytesIO(self._buffer.read())
        self._promises.pop(header.serial)
        if not promise.future.done():
            promise.future.set_result(response)

    def execute_request(self, request_info, **request_params):
        """Queue a request and send it right away if connected.

        :param request_info: provides ``builder_factory()`` and
            ``parser_factory()`` for this request type.
        :param request_params: request fields. A ``serial`` field is added
            automatically.
        :return: asyncio.Future resolved with the parsed response, or failed
            with a server or connection error.
        """
        future = asyncio.Future(loop=self._loop)
        serial = self._next_serial()
        promise = Promise(
            future=future,
            request_data=construct.Container(serial=serial, **request_params),
            request_builder=request_info.builder_factory(),
            request_parser=request_info.parser_factory()
        )
        self._promises[serial] = promise

        if self.connected:
            self._send_request(promise.request_data, promise.request_builder)

        return future


# Public API of this module: only the concrete protocols are exported.
__all__ = (
    'InputStreamProtocol',
    'ClientProtocol',
)
