# -*- coding: utf-8 -*-
# Copyright 2018, CS GROUP - France, https://www.csgroup.eu/
#
# This file is part of EODAG project
#     https://www.github.com/CS-SI/EODAG
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import logging
import os
import pathlib
import shutil
import tempfile
import zipfile

import geojson
import requests
from lxml import html
from shapely import geometry

from tests import EODagTestCase
from tests.context import (
    DEFAULT_STREAM_REQUESTS_TIMEOUT,
    NOT_AVAILABLE,
    USER_AGENT,
    DatasetDriver,
    Download,
    EOProduct,
    HTTPDownload,
    MisconfiguredError,
    ProgressCallback,
    config,
)
from tests.utils import mock


class TestEOProduct(EODagTestCase):
    NOT_ASSOCIATED_PRODUCT_TYPE = "EODAG_DOES_NOT_SUPPORT_THIS_PRODUCT_TYPE"

    def setUp(self):
        """Run the parent fixture setup, then create a scratch output directory."""
        super().setUp()
        self.output_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Run the parent teardown, then delete the scratch output directory if it exists."""
        super().tearDown()
        scratch_dir = self.output_dir
        if not os.path.isdir(scratch_dir):
            return
        shutil.rmtree(scratch_dir)

    def get_mock_downloader(self):
        """Build a ``Download``-specced mock carrying a minimal plugin configuration.

        The attached configuration has type ``Foo`` and an ``output_dir`` pointing
        at the system temporary directory.
        """
        download_spec = Download(provider=self.provider, config=None)
        fake_downloader = mock.MagicMock(spec_set=download_spec)
        downloader_conf = config.PluginConfig.from_mapping(
            {"type": "Foo", "output_dir": tempfile.gettempdir()}
        )
        fake_downloader.config = downloader_conf
        return fake_downloader

    def test_eoproduct_search_intersection_geom(self):
        """EOProduct search_intersection attr must be its geom when no
        bbox_or_intersect param is given"""
        dummy = self._dummy_product()
        footprint = dummy.geometry
        self.assertEqual(footprint, dummy.search_intersection)

    def test_eoproduct_default_geom(self):
        """EOProduct needs a geometry, or falls back to a configured defaultGeometry"""
        # Without any fallback, building the product must fail
        expected_failure = self.assertRaisesRegex(
            MisconfiguredError, "No geometry available"
        )
        with expected_failure:
            self._dummy_product(properties={"geometry": NOT_AVAILABLE})

        # With a defaultGeometry given as bounds, the product geometry uses them
        fallback_properties = {
            "geometry": NOT_AVAILABLE,
            "defaultGeometry": (0, 0, 1, 1),
        }
        product = self._dummy_product(properties=fallback_properties)
        self.assertEqual(product.geometry.bounds, (0.0, 0.0, 1.0, 1.0))

    def test_eoproduct_search_intersection_none(self):
        """EOProduct search_intersection attr must be None if
        shapely.errors.GEOSException is raised when intersecting"""
        # Invalid geometry: the first and third edges of this ring cross each other
        invalid_ring = [
            [10.469970703124998, 3.9957805129630373],
            [12.227783203124998, 4.740675384778385],
            [12.095947265625, 4.061535597066097],
            [10.491943359375, 4.412136788910175],
            [10.469970703124998, 3.9957805129630373],
        ]
        self.eoproduct_props["geometry"] = {
            "type": "Polygon",
            "coordinates": [invalid_ring],
        }

        # Rectangle given through the ``geometry`` argument of the dummy product
        west, east = 10.469970703124998, 12.227783203124998
        south, north = 3.9957805129630373, 4.740675384778385
        rectangle = geometry.Polygon(
            ((west, south), (west, north), (east, north), (east, south))
        )

        product = self._dummy_product(geometry=rectangle)
        self.assertIsNone(product.search_intersection)

    def test_eoproduct_default_driver_unsupported_product_type(self):
        """EOProduct driver attr must be set even if its product type is not supported"""
        unsupported_type = self.NOT_ASSOCIATED_PRODUCT_TYPE
        product = self._dummy_product(productType=unsupported_type)
        driver = product.driver
        self.assertIsInstance(driver, DatasetDriver)

    def test_eoproduct_geointerface(self):
        """The GeoJSON form of an EOProduct must be a Feature exposing eodag properties"""

        def as_lists(geom):
            # GeoJSON round-trips coordinates as lists, shapely mappings use tuples
            return self._tuples_to_lists(geometry.mapping(geom))

        product = self._dummy_product()
        feature = geojson.loads(geojson.dumps(product))

        self.assertEqual(feature["type"], "Feature")
        self.assertEqual(feature["geometry"], as_lists(self.geometry))

        feature_props = feature["properties"]
        self.assertEqual(feature_props["eodag_provider"], self.provider)
        self.assertEqual(
            feature_props["eodag_search_intersection"],
            as_lists(product.search_intersection),
        )
        self.assertEqual(feature_props["eodag_product_type"], self.product_type)

    def test_eoproduct_from_geointerface(self):
        """An EOProduct rebuilt from its own GeoJSON must carry the same key values"""

        def snapshot(prod):
            # Values compared between the original and the rebuilt product
            return [
                prod.provider,
                prod.location,
                prod.properties["title"],
                prod.properties["instrument"],
                self._tuples_to_lists(geometry.mapping(prod.geometry)),
                self._tuples_to_lists(geometry.mapping(prod.search_intersection)),
                prod.product_type,
                prod.properties["productType"],
                prod.properties["platformSerialIdentifier"],
            ]

        product = self._dummy_product()
        serialized = geojson.loads(geojson.dumps(product))
        same_product = EOProduct.from_geojson(serialized)

        self.assertSequenceEqual(snapshot(product), snapshot(same_product))

    def test_eoproduct_get_quicklook_no_quicklook_url(self):
        """EOProduct.get_quicklook must return an empty string if the product has
        no quicklook property"""
        product = self._dummy_product()
        product.properties["quicklook"] = None

        self.assertEqual(product.get_quicklook(), "")

    def test_eoproduct_get_quicklook_http_error(self):
        """EOProduct.get_quicklook must return an empty string if the retrieval
        fails with an HTTP error"""
        quicklook_url = "https://fake.url.to/quicklook"
        product = self._dummy_product()
        product.properties["quicklook"] = quicklook_url

        # Every mocked response fails when its status is checked
        mocked_response = self.requests_http_get.return_value.__enter__.return_value
        mocked_response.raise_for_status.side_effect = requests.HTTPError
        product.register_downloader(self.get_mock_downloader(), None)

        quicklook_file_path = product.get_quicklook()

        # Two requests are expected, the last one with these arguments
        self.assertEqual(self.requests_http_get.call_count, 2)
        expected_kwargs = {
            "stream": True,
            "auth": None,
            "headers": USER_AGENT,
            "timeout": DEFAULT_STREAM_REQUESTS_TIMEOUT,
            "verify": True,
        }
        self.requests_http_get.assert_called_with(quicklook_url, **expected_kwargs)
        self.assertEqual(quicklook_file_path, "")

    def test_eoproduct_get_quicklook_ok_without_auth(self):
        """EOProduct.get_quicklook must retrieve the quicklook without authentication."""
        quicklook_url = "https://fake.url.to/quicklook"
        product = self._dummy_product()
        product.properties["quicklook"] = quicklook_url

        # The first status check fails, the second one succeeds
        mocked_response = self.requests_http_get.return_value.__enter__.return_value
        mocked_response.raise_for_status.side_effect = (requests.HTTPError, None)
        product.register_downloader(self.get_mock_downloader(), None)

        quicklook_file_path = product.get_quicklook()

        self.assertEqual(self.requests_http_get.call_count, 2)
        expected_kwargs = {
            "stream": True,
            "auth": None,
            "headers": USER_AGENT,
            "timeout": DEFAULT_STREAM_REQUESTS_TIMEOUT,
            "verify": True,
        }
        self.requests_http_get.assert_called_with(quicklook_url, **expected_kwargs)
        os.remove(quicklook_file_path)

    def test_eoproduct_get_quicklook_ok(self):
        """EOProduct.get_quicklook must return the path to the successfully
        downloaded quicklook"""
        quicklook_url = "https://fake.url.to/quicklook"
        expected_kwargs = {
            "stream": True,
            "auth": None,
            "headers": USER_AGENT,
            "timeout": DEFAULT_STREAM_REQUESTS_TIMEOUT,
            "verify": True,
        }
        expected_dir = os.path.join(tempfile.gettempdir(), "quicklooks")

        product = self._dummy_product()
        product.properties["quicklook"] = quicklook_url
        self.requests_http_get.return_value = self._quicklook_response()
        product.register_downloader(self.get_mock_downloader(), None)

        # Default file name: the quicklook is named after the product id
        quicklook_file_path = product.get_quicklook()
        self.requests_http_get.assert_called_once_with(
            quicklook_url, **expected_kwargs
        )
        directory, basename = os.path.split(quicklook_file_path)
        self.assertEqual(basename, product.properties["id"])
        self.assertEqual(directory, expected_dir)
        os.remove(quicklook_file_path)

        # Same check with an explicit name given to the downloaded file
        quicklook_file_path = product.get_quicklook(filename="the_quicklook.png")
        self.requests_http_get.assert_called_with(quicklook_url, **expected_kwargs)
        self.assertEqual(self.requests_http_get.call_count, 2)
        directory, basename = os.path.split(quicklook_file_path)
        self.assertEqual(basename, "the_quicklook.png")
        self.assertEqual(directory, expected_dir)
        os.remove(quicklook_file_path)

        # Overall teardown: drop the now-empty quicklooks directory
        os.rmdir(os.path.dirname(quicklook_file_path))

    def test_eoproduct_get_quicklook_ok_existing(self):
        """EOProduct.get_quicklook must return the path to an already downloaded
        quicklook without issuing any request"""
        basename = "the_quicklook.png"
        quicklooks_dir = os.path.join(tempfile.gettempdir(), "quicklooks")
        existing_path = os.path.join(quicklooks_dir, basename)

        # Put a quicklook file in place before asking for it
        if not os.path.exists(quicklooks_dir):
            os.mkdir(quicklooks_dir)
        with open(existing_path, "wb") as existing_file:
            existing_file.write(b"content")

        product = self._dummy_product()
        product.properties["quicklook"] = "https://fake.url.to/quicklook"
        product.register_downloader(self.get_mock_downloader(), None)

        returned_path = product.get_quicklook(filename=basename)

        self.assertEqual(self.requests_http_get.call_count, 0)
        self.assertEqual(returned_path, existing_path)

        # Teardown
        os.remove(existing_path)
        os.rmdir(quicklooks_dir)

    @staticmethod
    def _quicklook_response():
        """Build a minimal stand-in for the response of a streamed quicklook request.

        The returned object works as a context manager, carries a
        ``content-length`` header of 32, yields 32 ``b"a"`` bytes through
        ``iter_content`` (in chunks of the requested ``chunk_size``) and does
        nothing in ``raise_for_status``.
        """

        class Response(object):
            """Emulation of a response to requests.get method for a quicklook"""

            def __init__(self):
                self.headers = {"content-length": 2**5}

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                pass

            def raise_for_status(self):
                pass

            @staticmethod
            def iter_content(**kwargs):
                with io.BytesIO(b"a" * 2**5) as payload:
                    # Read chunk after chunk until the buffer returns no more bytes
                    yield from iter(
                        lambda: payload.read(kwargs["chunk_size"]), b""
                    )

        return Response()

    def test_eoproduct_download_http_default(self):
        """eoproduct.download must save the product at output_dir and create a
        .downloaded dir"""
        product = self._dummy_downloadable_product()

        # First download: the url and the kept remote location are logged
        with self.assertLogs(level="INFO") as captured:
            downloaded_path = product.download()
            self.addCleanup(self._clean_product, downloaded_path)
            self.assertIn("Download url: %s" % product.remote_location, str(captured.output))
            self.assertIn(
                "Remote location of the product is still available",
                str(captured.output),
            )

        # The mocked request must have been issued exactly once
        self.requests_request.assert_called_once()

        # A .downloaded folder should sit next to the product, holding a single
        # text file whose content is the download url
        product_dir = pathlib.Path(downloaded_path)
        records_dir = product_dir.parent / ".downloaded"
        self.assertTrue(records_dir.is_dir())
        record_files = list(records_dir.iterdir())
        self.assertEqual(len(record_files), 1)
        self.assertEqual(record_files[0].read_text(), self.download_url)

        # Extraction is True by default, so the returned path is the product's
        # directory
        self.assertTrue(os.path.isdir(downloaded_path))

        # The ZIP file is still there, next to the extracted directory
        archive_path = product_dir.parent / (product_dir.name + ".zip")
        self.assertTrue(zipfile.is_zipfile(archive_path))

        # Downloading again must not fetch the product a second time
        with self.assertLogs(level="INFO") as captured:
            product.download()
            self.assertIn(
                "Product already present on this platform", str(captured.output)
            )

        # Same when the location has not been updated to the local path
        product.location = product.remote_location
        with self.assertLogs(level="INFO") as captured:
            product.download()
            self.assertIn("Product already downloaded", str(captured.output))
            self.assertIn(
                "Extraction cancelled, destination directory already exists",
                str(captured.output),
            )

    def test_eoproduct_download_http_delete_archive(self):
        """eoproduct.download must delete the downloaded archive when the
        downloader is configured with ``delete_archive=True``"""
        # Setup
        product = self._dummy_downloadable_product(delete_archive=True)

        # Download
        product_dir_path = product.download()
        # Register the cleanup only once the path is known: a failing download
        # then surfaces its own error instead of a NameError raised while
        # cleaning up a path that was never assigned.
        self.addCleanup(self._clean_product, product_dir_path)

        # Check that the mocked request was properly called.
        self.requests_request.assert_called_once()
        # Check that the product's directory exists.
        self.assertTrue(os.path.isdir(product_dir_path))
        # Check that the ZIP file was deleted
        _product_dir_path = pathlib.Path(product_dir_path)
        product_zip = _product_dir_path.parent / (_product_dir_path.name + ".zip")
        self.assertFalse(os.path.exists(product_zip))

        # check that product is not downloaded again even if location has not
        # been updated
        product.location = product.remote_location
        with self.assertLogs(level="INFO") as cm:
            product.download()
            self.assertIn("Product already downloaded", str(cm.output))
            self.assertIn(
                "Extraction cancelled, destination directory already exists",
                str(cm.output),
            )

    def test_eoproduct_download_http_extract(self):
        """eoproduct.download must be able to extract a product"""
        # Setup
        product = self._dummy_downloadable_product(extract=True)
        product_dir_path = product.download()
        self.addCleanup(self._clean_product, product_dir_path)
        product_dir_path = pathlib.Path(product_dir_path)
        # The returned path must be a directory.
        self.assertTrue(product_dir_path.is_dir())
        # Check that the extracted dir has at least one file, there are more
        # but that should be enough.
        self.assertGreaterEqual(len(list(product_dir_path.glob("**/*"))), 1)
        # The zip file should be around, next to the extracted directory. The
        # name is built by appending ".zip" (as in the default download test)
        # so that a dot in the product name cannot be mistaken for a suffix.
        product_zip_file = product_dir_path.parent / (product_dir_path.name + ".zip")
        # is_file must be *called*: the bound method itself is always truthy
        self.assertTrue(product_zip_file.is_file())

    # TODO: add a test on tarfiles extraction

    def test_eoproduct_download_http_dynamic_options(self):
        """eoproduct.download must accept download options given at call time"""
        # Setup
        product = self._dummy_product()
        self._set_download_simulation()
        dl_config = config.PluginConfig.from_mapping(
            {
                "type": "HTTPDownload",
                "base_uri": "fake_base_uri",
                "output_dir": "will_be_overriden",
            }
        )
        downloader = HTTPDownload(provider=self.provider, config=dl_config)
        product.register_downloader(downloader, None)

        output_dir_name = "_testeodag"
        output_dir = pathlib.Path(tempfile.gettempdir()) / output_dir_name
        try:
            if output_dir.is_dir():
                shutil.rmtree(output_dir)
            output_dir.mkdir()

            # Download
            # NOTE(review): delete_archive=False is passed explicitly because
            # the plugin configuration above does not set it and the archive
            # is checked below; this assumes download() honours that keyword
            # like the other dynamic options - confirm against EOProduct.download.
            product_dir_path = product.download(
                output_dir=str(output_dir),
                extract=True,
                delete_archive=False,
                dl_url_params={"fakeparam": "dummy"},
            )
            # Check that the mocked request was issued exactly once
            self.requests_request.assert_called_once()
            # Check that "output_dir" is respected.
            product_dir_path = pathlib.Path(product_dir_path)
            self.assertEqual(product_dir_path.parent.name, output_dir_name)
            # We've asked to extract the product so there should be a directory.
            self.assertTrue(product_dir_path.is_dir())
            # Check that the extracted dir has at least one file, there are more
            # but that should be enough.
            self.assertGreaterEqual(len(list(product_dir_path.glob("**/*"))), 1)
            # The downloaded zip file is still around, next to the extracted
            # directory (name built by appending ".zip", as in the default
            # download test).
            product_zip_file = product_dir_path.parent / (
                product_dir_path.name + ".zip"
            )
            # is_file must be *called*: the bound method itself is always truthy
            self.assertTrue(product_zip_file.is_file())
        finally:
            # Teardown (all the created files are within output_dir)
            shutil.rmtree(output_dir)

    def test_eoproduct_download_progress_bar(self):
        """eoproduct.download must drive the given progress callback to completion"""
        product = self._dummy_downloadable_product()
        product.properties["id"] = 12345
        bar = ProgressCallback()

        # Nothing has been reported before the download starts
        self.assertEqual(bar.n, 0)

        # extract=False is kept on purpose: extract=true would replace the bar
        # description with the extraction status
        download_options = {"output_dir": self.output_dir, "extract": False}
        product.download(progress_callback=bar, **download_options)

        # The description should be the product id cast to str
        self.assertEqual(bar.desc, "12345")

        # The bar reached its (non-zero) total
        self.assertEqual(bar.n, bar.total)
        self.assertGreater(bar.total, 0)

    def test_eoproduct_register_downloader(self):
        """eoproduct.register_downloader must set download and auth plugins"""
        product = self._dummy_product()

        # A fresh product has no plugin attached
        for plugin in (product.downloader, product.downloader_auth):
            self.assertIsNone(plugin)

        fake_downloader, fake_auth = mock.MagicMock(), mock.MagicMock()
        product.register_downloader(fake_downloader, fake_auth)

        self.assertEqual(product.downloader, fake_downloader)
        self.assertEqual(product.downloader_auth, fake_auth)

    def test_eoproduct_register_downloader_resolve_ok(self):
        """eoproduct.register_downloader must resolve locations and properties"""
        templated_properties = {
            **self.eoproduct_props,
            "downloadLink": "%(base_uri)s/is/resolved",
            "otherProperty": "%(output_dir)s/also/resolved",
        }
        downloadable_product = self._dummy_downloadable_product(
            product=self._dummy_product(properties=templated_properties)
        )

        # Both locations and the downloadLink property use the downloader base_uri
        self.assertEqual(
            downloadable_product.location,
            f"{downloadable_product.downloader.config.base_uri}/is/resolved",
        )
        self.assertEqual(
            downloadable_product.remote_location,
            f"{downloadable_product.downloader.config.base_uri}/is/resolved",
        )
        resolved_properties = downloadable_product.properties
        self.assertEqual(
            resolved_properties["downloadLink"],
            f"{downloadable_product.downloader.config.base_uri}/is/resolved",
        )
        # Other properties are resolved too, here with the downloader output_dir
        self.assertEqual(
            resolved_properties["otherProperty"],
            f"{downloadable_product.downloader.config.output_dir}/also/resolved",
        )

    def test_eoproduct_register_downloader_resolve_ignored(self):
        """eoproduct.register_downloader must ignore unresolvable locations and
        properties, and log them at debug level"""
        product_logger = logging.getLogger("eodag.product")
        unresolvable_properties = {
            **self.eoproduct_props,
            "downloadLink": "%(257B/cannot/be/resolved",
            "otherProperty": "%(/%s/neither/resolved",
        }

        with mock.patch.object(product_logger, "debug") as mock_debug:
            downloadable_product = self._dummy_downloadable_product(
                product=self._dummy_product(properties=unresolvable_properties)
            )

            # Values that cannot be resolved are left untouched
            self.assertEqual(downloadable_product.location, "%(257B/cannot/be/resolved")
            self.assertEqual(
                downloadable_product.remote_location, "%(257B/cannot/be/resolved"
            )
            self.assertEqual(
                downloadable_product.properties["downloadLink"],
                "%(257B/cannot/be/resolved",
            )
            self.assertEqual(
                downloadable_product.properties["otherProperty"],
                "%(/%s/neither/resolved",
            )

            # Each ignored value must have been reported through logger.debug
            debug_calls = str(mock_debug.call_args_list)
            for needed_log in (
                f"Could not resolve product.location ({downloadable_product.location})",
                f"Could not resolve product.remote_location ({downloadable_product.remote_location})",
                f"Could not resolve downloadLink property ({downloadable_product.properties['downloadLink']})",
                f"Could not resolve otherProperty property ({downloadable_product.properties['otherProperty']})",
            ):
                self.assertIn(needed_log, debug_calls)

    def test_eoproduct_repr_html(self):
        """eoproduct html repr must be correctly formatted"""
        product = self._dummy_product()
        product_repr = html.fromstring(product._repr_html_())
        self.assertIn("EOProduct", product_repr.xpath("//thead/tr/td")[0].text)

        # assets dict
        product.assets.update({"foo": {"href": "foo.href"}})
        assets_dict_repr = html.fromstring(product.assets._repr_html_())
        self.assertIn("AssetsDict", assets_dict_repr.xpath("//thead/tr/td")[0].text)

        # asset: render the single "foo" entry, not the whole assets dict again
        # (the dict repr contains "AssetsDict", which would satisfy the "Asset"
        # check below without exercising the asset repr at all)
        asset_repr = html.fromstring(product.assets["foo"]._repr_html_())
        self.assertIn("Asset", asset_repr.xpath("//thead/tr/td")[0].text)

    def test_eoproduct_assets_get_values(self):
        """eoproduct.assets.get_values must return the expected values"""
        product = self._dummy_product()
        product.assets.update(
            {
                "foo": {"href": "foo.href"},
                "fooo": {"href": "fooo.href"},
                "foo?o,o": {"href": "foooo.href"},
            }
        )
        assets = product.assets

        # No filter, or a pattern matching every key: all three assets
        self.assertEqual(len(assets.get_values()), 3)
        self.assertEqual(len(assets.get_values("foo.*")), 3)

        # A filter equal to an existing key returns that single asset only
        exact_key_cases = (("foo", "foo.href"), ("foo?o,o", "foooo.href"))
        for asset_filter, expected_href in exact_key_cases:
            self.assertEqual(len(assets.get_values(asset_filter)), 1)
            self.assertEqual(
                assets.get_values(asset_filter)[0]["href"], expected_href
            )
