# Copyright 2009-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Binary wrapper."""

import base64
import copy
import pickle
import sys
import uuid

sys.path[0:0] = [""]

import bson

from bson.binary import *
from bson.codec_options import CodecOptions
from bson.py3compat import u
from bson.son import SON
from test import client_context, unittest
from pymongo.mongo_client import MongoClient


class TestBinary(unittest.TestCase):
    """Tests for Binary, UUIDLegacy and the legacy UUID representations."""

    @classmethod
    def setUpClass(cls):
        """Decode the base64 BSON fixtures produced by other drivers.

        Each fixture is a concatenation of BSON documents that hold a
        16-byte binary (subtype 3) ``newguid`` field and the same UUID
        written out as the string ``newguidstring``. Decoders can therefore
        be checked against the string form.
        """
        # Generated by the Java driver
        from_java = (
            b'bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu'
            b'Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND'
            b'ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+'
            b'XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1'
            b'aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR'
            b'jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA'
            b'AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z'
            b'DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf'
            b'aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx'
            b'29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My'
            b'1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB'
            b'W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp'
            b'bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc'
            b'0MQAA')
        cls.java_data = base64.b64decode(from_java)

        # Generated by the .net driver
        from_csharp = (
            b'ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl'
            b'iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2'
            b'ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V'
            b'pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl'
            b'AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A'
            b'ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z'
            b'oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU'
            b'zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn'
            b'dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA'
            b'CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT'
            b'QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP'
            b'MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00'
            b'ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=')
        cls.csharp_data = base64.b64decode(from_csharp)

    # -- helpers -----------------------------------------------------------

    @staticmethod
    def _decode_all(data, uuid_representation):
        """Decode every BSON document in *data* into a SON.

        Datetimes are naive (tz_aware=False) and binary UUIDs are
        interpreted with *uuid_representation*.
        """
        return bson.decode_all(
            data, CodecOptions(SON, False, uuid_representation))

    @staticmethod
    def _encode_all(docs, uuid_representation):
        """BSON-encode each of *docs* with *uuid_representation*.

        Returns the concatenated bytes.
        """
        options = CodecOptions(uuid_representation=uuid_representation)
        return b''.join(bson.BSON.encode(doc, False, options) for doc in docs)

    def _check_legacy_uuid_fixture(self, data, native_representation):
        """Assert that only *native_representation* round-trips *data*.

        Decoding *data* with *native_representation* must yield UUIDs equal
        to the recorded ``newguidstring`` values. Re-encoding those documents
        with it must reproduce *data* byte-for-byte. Each of the other three
        representations must fail both checks.
        """
        others = [rep for rep in (PYTHON_LEGACY, STANDARD,
                                  JAVA_LEGACY, CSHARP_LEGACY)
                  if rep != native_representation]

        # Test decoding
        for rep in others:
            for d in self._decode_all(data, rep):
                self.assertNotEqual(
                    d['newguid'], uuid.UUID(d['newguidstring']),
                    "decoded with uuid_representation=%r" % (rep,))

        docs = self._decode_all(data, native_representation)
        for d in docs:
            self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring']))

        # Test encoding of the natively decoded documents
        for rep in others:
            self.assertNotEqual(
                data, self._encode_all(docs, rep),
                "encoded with uuid_representation=%r" % (rep,))
        self.assertEqual(data, self._encode_all(docs, native_representation))

    def _check_legacy_uuid_roundtrip(self, data, coll_name,
                                     native_representation):
        """Round-trip the fixture *data* through the server.

        The documents are written to *coll_name* with
        *native_representation*. The test then checks that their UUIDs read
        back correctly with that representation, and that they do not match
        when read with PYTHON_LEGACY.
        """
        docs = self._decode_all(data, native_representation)

        db = client_context.client.pymongo_test
        db.drop_collection(coll_name)
        # Registered up front so the collection is dropped even when an
        # assertion below fails.
        self.addCleanup(db.drop_collection, coll_name)

        coll = db.get_collection(
            coll_name, CodecOptions(uuid_representation=native_representation))
        coll.insert_many(docs)
        # Both driver fixtures contain five documents.
        self.assertEqual(5, coll.count())
        for d in coll.find():
            self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring']))

        coll = db.get_collection(
            coll_name, CodecOptions(uuid_representation=PYTHON_LEGACY))
        for d in coll.find():
            # Compare UUID to UUID: comparing against the raw string would
            # always be unequal and prove nothing.
            self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring']))

    # -- Binary ------------------------------------------------------------

    def test_binary(self):
        """Binary behaves like bytes; plain strings are not Binary."""
        a_string = "hello world"
        a_binary = Binary(b"hello world")
        self.assertTrue(a_binary.startswith(b"hello"))
        self.assertTrue(a_binary.endswith(b"world"))
        self.assertIsInstance(a_binary, Binary)
        self.assertNotIsInstance(a_string, Binary)

    def test_exceptions(self):
        """Binary rejects non-bytes data and subtypes outside [0, 255]."""
        self.assertRaises(TypeError, Binary, None)
        self.assertRaises(TypeError, Binary, u("hello"))
        self.assertRaises(TypeError, Binary, 5)
        self.assertRaises(TypeError, Binary, 10.2)
        self.assertRaises(TypeError, Binary, b"hello", None)
        self.assertRaises(TypeError, Binary, b"hello", "100")
        self.assertRaises(ValueError, Binary, b"hello", -1)
        self.assertRaises(ValueError, Binary, b"hello", 256)
        self.assertTrue(Binary(b"hello", 0))
        self.assertTrue(Binary(b"hello", 255))

    def test_subtype(self):
        """The subtype defaults to 0 and otherwise reflects the argument."""
        one = Binary(b"hello")
        self.assertEqual(one.subtype, 0)
        two = Binary(b"hello", 2)
        self.assertEqual(two.subtype, 2)
        three = Binary(b"hello", 100)
        self.assertEqual(three.subtype, 100)

    def test_equality(self):
        """Binary equality depends on both data and subtype."""
        two = Binary(b"hello")
        three = Binary(b"hello", 100)
        self.assertNotEqual(two, three)
        self.assertEqual(three, Binary(b"hello", 100))
        self.assertEqual(two, Binary(b"hello"))
        self.assertNotEqual(two, Binary(b"hello "))
        self.assertNotEqual(b"hello", Binary(b"hello"))

        # Explicitly test inequality
        self.assertFalse(three != Binary(b"hello", 100))
        self.assertFalse(two != Binary(b"hello"))

    def test_repr(self):
        """repr() is ``Binary(<repr of bytes>, <subtype>)``."""
        one = Binary(b"hello world")
        self.assertEqual(repr(one),
                         "Binary(%s, 0)" % (repr(b"hello world"),))
        two = Binary(b"hello world", 2)
        self.assertEqual(repr(two),
                         "Binary(%s, 2)" % (repr(b"hello world"),))
        three = Binary(b"\x08\xFF")
        self.assertEqual(repr(three),
                         "Binary(%s, 0)" % (repr(b"\x08\xFF"),))
        four = Binary(b"\x08\xFF", 2)
        self.assertEqual(repr(four),
                         "Binary(%s, 2)" % (repr(b"\x08\xFF"),))
        five = Binary(b"test", 100)
        self.assertEqual(repr(five),
                         "Binary(%s, 100)" % (repr(b"test"),))

    def test_hash(self):
        """Equal Binary values hash equally; the subtype affects the hash."""
        one = Binary(b"hello world")
        two = Binary(b"hello world", 42)
        self.assertEqual(hash(Binary(b"hello world")), hash(one))
        self.assertNotEqual(hash(one), hash(two))
        self.assertEqual(hash(Binary(b"hello world", 42)), hash(two))

    # -- legacy UUID representations ---------------------------------------

    def test_legacy_java_uuid(self):
        """Only JAVA_LEGACY decodes and re-encodes the Java fixture."""
        self._check_legacy_uuid_fixture(self.java_data, JAVA_LEGACY)

    @client_context.require_connection
    def test_legacy_java_uuid_roundtrip(self):
        """Java-encoded UUIDs survive a server round trip with JAVA_LEGACY."""
        self._check_legacy_uuid_roundtrip(
            self.java_data, 'java_uuid', JAVA_LEGACY)

    def test_legacy_csharp_uuid(self):
        """Only CSHARP_LEGACY decodes and re-encodes the .NET fixture."""
        self._check_legacy_uuid_fixture(self.csharp_data, CSHARP_LEGACY)

    @client_context.require_connection
    def test_legacy_csharp_uuid_roundtrip(self):
        """.NET-encoded UUIDs survive a server round trip with CSHARP_LEGACY."""
        self._check_legacy_uuid_roundtrip(
            self.csharp_data, 'csharp_uuid', CSHARP_LEGACY)

    def test_uri_to_uuid(self):
        """The ``uuidrepresentation`` URI option reaches collection options."""
        uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
        client = MongoClient(uri, connect=False)
        self.assertEqual(
            client.pymongo_test.test.codec_options.uuid_representation,
            CSHARP_LEGACY)

    @client_context.require_connection
    def test_uuid_queries(self):
        """Query subtype-3 and subtype-4 UUIDs through UUIDLegacy and uuid.UUID."""
        db = client_context.client.pymongo_test
        coll = db.test
        coll.drop()
        # Drop the collection even if an assertion below fails.
        self.addCleanup(db.drop_collection, 'test')

        uu = uuid.uuid4()
        coll.insert_one({'uuid': Binary(uu.bytes, 3)})
        self.assertEqual(1, coll.count())

        # Test UUIDLegacy queries.
        coll = db.get_collection("test",
                                 CodecOptions(uuid_representation=STANDARD))
        self.assertEqual(0, coll.find({'uuid': uu}).count())
        cur = coll.find({'uuid': UUIDLegacy(uu)})
        self.assertEqual(1, cur.count())
        retrieved = next(cur)
        self.assertEqual(uu, retrieved['uuid'])

        # Test regular UUID queries (using subtype 4).
        coll.insert_one({'uuid': uu})
        self.assertEqual(2, coll.count())
        cur = coll.find({'uuid': uu})
        self.assertEqual(1, cur.count())
        retrieved = next(cur)
        self.assertEqual(uu, retrieved['uuid'])

        # Test both.
        cur = coll.find({'uuid': {'$in': [uu, UUIDLegacy(uu)]}})
        self.assertEqual(2, cur.count())

    def test_pickle(self):
        """Binary and UUIDLegacy survive pickling and copying."""
        # Imported explicitly instead of relying on the star import from
        # bson.binary happening to re-export it.
        from bson.py3compat import PY3

        b1 = Binary(b'123', 2)

        # For testing backwards compatibility with pre-2.4 pymongo
        if PY3:
            p = (b"\x80\x03cbson.binary\nBinary\nq\x00C\x03123q\x01\x85q"
                 b"\x02\x81q\x03}q\x04X\x10\x00\x00\x00_Binary__subtypeq"
                 b"\x05K\x02sb.")
        else:
            p = (b"ccopy_reg\n_reconstructor\np0\n(cbson.binary\nBinary\np1\nc"
                 b"__builtin__\nstr\np2\nS'123'\np3\ntp4\nRp5\n(dp6\nS'_Binary"
                 b"__subtype'\np7\nI2\nsb.")

        # The pre-2.4 pickle is not checked on Python 3.0.
        if not sys.version.startswith('3.0'):
            self.assertEqual(b1, pickle.loads(p))

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            self.assertEqual(b1, pickle.loads(pickle.dumps(b1, proto)))

        uu = uuid.uuid4()
        uul = UUIDLegacy(uu)

        self.assertEqual(uul, copy.copy(uul))
        self.assertEqual(uul, copy.deepcopy(uul))

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            self.assertEqual(uul, pickle.loads(pickle.dumps(uul, proto)))


if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
