# @copyright (c) 2002-2016 Acronis International GmbH. All rights reserved.
# EULA: https://www.acronis.com/en-us/download/docs/eula/corporate/

import asyncio
import os
import ssl
from .protocol import InputStreamProtocol, ClientProtocol
from .requests import *


class InputStream:
    """Read-only, seekable view of one remote file.

    Every instance owns its own ``InputStreamProtocol`` connection to
    ``host:port``; ``Client.open_input_stream`` constructs these.  The stream
    keeps a local read position and each ``read`` issues a single ranged read
    request at that position.

    NOTE(review): ``read`` is an ``@asyncio.coroutine`` generator; that
    decorator was removed in Python 3.11, so this code needs Python < 3.11.
    """

    def __init__(self, name, prefix, size, host, port, ssl_ctx, loop=None):
        # Remote file identity plus its total length as reported when the
        # stream was opened (presumably in bytes -- TODO confirm).
        self._name, self._prefix, self._size = name, prefix, size
        # Current read position; starts at the beginning of the file.
        self._offset = 0
        self._protocol = InputStreamProtocol(host, port, ssl_ctx, loop)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always closes; returns None, so exceptions are never suppressed.
        self.close()

    def connect(self):
        """Return whatever the protocol's ``connect()`` returns (``read`` awaits it)."""
        return self._protocol.connect()

    def close(self):
        """Close the underlying protocol and pass its result through."""
        return self._protocol.close()

    def closed(self):
        """Return True when the protocol reports that it is not connected."""
        is_live = self._protocol.connected()
        return not is_live

    def size(self):
        """Return the total file size recorded when the stream was opened."""
        return self._size

    def tell(self):
        """Return the current read position."""
        return self._offset

    def seek(self, offset, whence=os.SEEK_SET):
        """Move the read position and return the new position.

        ``whence`` follows the ``os.SEEK_*`` convention: ``SEEK_END`` is
        relative to the file size, ``SEEK_CUR`` to the current position, and
        ``SEEK_SET`` (or any other value) makes ``offset`` absolute.  The
        target must lie in ``range(size)``; otherwise ``EOFError`` is raised
        and the position is left unchanged.
        """
        if whence == os.SEEK_END:
            target = self._size + offset
        elif whence == os.SEEK_CUR:
            target = self._offset + offset
        else:
            target = offset

        # NOTE(review): the half-open range rejects seeking to exactly
        # ``size`` and any seek on an empty file.  The builtin EOFError takes
        # no keyword arguments, so this presumably relies on an EOFError
        # provided by the ``.requests`` star import -- confirm.
        if target in range(self._size):
            self._offset = target
            return self._offset
        raise EOFError(name=self._name, prefix=self._prefix, size=self._size, offset=target)

    @asyncio.coroutine
    def read(self, size=-1):
        """Coroutine: read up to ``size`` units starting at the current position.

        A negative ``size`` means "everything that remains"; any request is
        clamped to what remains.  When nothing remains, returns ``None``
        without connecting.  Otherwise awaits ``connect()``, issues one read
        request and returns the protocol's result.
        """
        remaining = self._size - self._offset
        size = remaining if size < 0 else min(size, remaining)
        if size == 0:
            return None

        yield from self.connect()
        chunk = yield from self._protocol.execute_read_request(
            name=self._name,
            prefix=self._prefix,
            offset=self._offset,
            size=size,
        )
        # Advance by the requested length, not by the length of ``chunk``.
        self._offset += size
        return chunk


class Client:
    """Client for a remote file service, reached over TLS with a client certificate.

    Every request method forwards its arguments, together with the matching
    entry of ``Requests``, to ``ClientProtocol.execute_request``. It returns
    whatever that call returns; callers in this module await that result with
    ``yield from``. ``create_client`` builds and connects an instance in one
    step.

    NOTE(review): ``open_input_stream`` uses ``@asyncio.coroutine``, which was
    removed in Python 3.11, so this class needs Python < 3.11 as written.
    """

    def __init__(self, host, port, cert_file, loop=None):
        self._host = host
        self._port = port
        self._ssl_ctx = self._new_ssl_context()
        # ``cert_file`` is the client certificate chain presented to the server.
        self._ssl_ctx.load_cert_chain(certfile=cert_file)
        self._loop = loop or asyncio.get_event_loop()
        self._protocol = ClientProtocol(self._host, self._port, self._ssl_ctx, self._loop)

    @staticmethod
    def _new_ssl_context():
        """Return a client-side TLS context that negotiates the protocol version.

        The context is deliberately not pinned to one version.
        ``PROTOCOL_TLSv1`` forces TLS 1.0, which is deprecated (RFC 8996;
        Python since 3.6). That constant is also missing when OpenSSL is built
        without TLS 1.0. Negotiation picks the highest version both peers
        allow.
        """
        protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
        if protocol is None:
            # Python < 3.6: PROTOCOL_SSLv23 is the version-negotiating method.
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
        else:
            ctx = ssl.SSLContext(protocol)
            # PROTOCOL_TLS_CLIENT turns on hostname checking and certificate
            # verification by default. A version-pinned context such as
            # PROTOCOL_TLSv1 does neither. Keep that behaviour so the same
            # servers stay reachable. check_hostname must be cleared before
            # verify_mode can be set to CERT_NONE.
            # NOTE(review): the server certificate is not verified; consider
            # loading a CA bundle and requiring verification.
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE
        return ctx

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # NOTE(review): if ClientProtocol.close() returns a coroutine it is
        # never awaited here -- confirm it is synchronous.
        self.close()

    def connect(self):
        """Return the protocol's ``connect()`` result (``create_client`` awaits it)."""
        return self._protocol.connect()

    def close(self):
        """Close the underlying protocol and pass its result through."""
        return self._protocol.close()

    @asyncio.coroutine
    def open_input_stream(self, name, prefix=''):
        """Coroutine: query the file size and return an ``InputStream`` over ``name``.

        The stream gets its own protocol instance built from this client's
        host, port, TLS context and loop. It does not share this client's
        connection.
        """
        file_size_info = yield from self.get_file_size(name, prefix)
        return InputStream(name, prefix, file_size_info.size, self._host, self._port, self._ssl_ctx, self._loop)

    def rename_file(self, old_name, new_name, old_prefix='', new_prefix='', lock_id=0):
        """Send a ``RenameFile`` request from ``old_name`` (in ``old_prefix``) to ``new_name`` (in ``new_prefix``)."""
        return self._protocol.execute_request(
            Requests[RequestID.RenameFile],
            old_name=old_name,
            new_name=new_name,
            old_prefix=old_prefix,
            new_prefix=new_prefix,
            lock_id=lock_id
        )

    def swap_files(self, one_name, two_name, one_prefix='', two_prefix='', one_lock_id=0, two_lock_id=0):
        """Send a ``SwapFiles`` request for the two named files."""
        return self._protocol.execute_request(
            Requests[RequestID.SwapFiles],
            one_name=one_name,
            two_name=two_name,
            one_prefix=one_prefix,
            two_prefix=two_prefix,
            one_lock_id=one_lock_id,
            two_lock_id=two_lock_id
        )

    def delete_file(self, name, prefix='', lock_id=0):
        """Send a ``DeleteFile`` request for ``name``."""
        return self._protocol.execute_request(
            Requests[RequestID.DeleteFile],
            name=name,
            prefix=prefix,
            lock_id=lock_id
        )

    def get_file_size(self, name, prefix=''):
        """Send a ``GetFileSize`` request; the awaited result exposes ``.size``."""
        return self._protocol.execute_request(
            Requests[RequestID.GetFileSize],
            name=name, prefix=prefix
        )

    def get_file_info(self, name, prefix=''):
        """Send a ``GetFileInfo`` request for ``name``."""
        return self._protocol.execute_request(
            Requests[RequestID.GetFileInfo],
            name=name, prefix=prefix
        )

    def get_file_info_list(self, prefix=''):
        """Send a ``GetFileInfoList`` request for ``prefix``."""
        return self._protocol.execute_request(
            Requests[RequestID.GetFileInfoList],
            prefix=prefix
        )

    def lock_file(self, name, prefix='', lock_id=0, lock_type=LockType.NoLock):
        """Send a ``LockFile`` request for ``name`` with the given lock id and type."""
        return self._protocol.execute_request(
            Requests[RequestID.LockFile],
            name=name,
            prefix=prefix,
            lock_id=lock_id,
            lock_type=lock_type
        )

    def unlock_file(self, name, prefix='', lock_id=0):
        """Send an ``UnlockFile`` request for ``name``."""
        return self._protocol.execute_request(
            Requests[RequestID.UnlockFile],
            name=name,
            prefix=prefix,
            lock_id=lock_id
        )

    def get_quota(self, prefix=''):
        """Send a ``GetQuota`` request for ``prefix``."""
        return self._protocol.execute_request(
            Requests[RequestID.GetQuota],
            prefix=prefix
        )

    def get_allocation_info(self, name, prefix='', offset=0):
        """Send a ``GetAllocationInfo`` request for ``name`` with the given ``offset``."""
        return self._protocol.execute_request(
            Requests[RequestID.GetAllocationInfo],
            name=name,
            prefix=prefix,
            offset=offset
        )

    def extended_open_create(self, name, prefix='', flags=0, lock_id=0, lock_type=LockType.NoLock):
        """Send an ``ExtendedOpenCreate`` request for ``name`` with the given flags and lock."""
        return self._protocol.execute_request(
            Requests[RequestID.ExtendedOpenCreate],
            name=name,
            prefix=prefix,
            flags=flags,
            lock_id=lock_id,
            lock_type=lock_type
        )


@asyncio.coroutine
def create_client(host, port, cert_file, loop=None):
    """Coroutine: construct a ``Client`` for ``host:port`` and connect it.

    ``cert_file`` is passed to ``Client`` as its certificate chain. ``loop``
    is also passed through; when it is None, ``Client`` falls back to
    ``asyncio.get_event_loop()``. Errors raised while connecting propagate
    to the caller.

    NOTE(review): if ``connect()`` fails, the half-built client is not closed.
    """
    new_client = Client(host, port, cert_file, loop=loop)
    yield from new_client.connect()
    return new_client


# Public API for star-imports: create_client() returns a connected Client;
# Client can also be constructed directly and connected via connect().
__all__ = (
    'create_client',
    'Client',
)
