from typing import List, Optional, Literal
import boto3
from enum import Enum
from . import pprint, console

from goose.config import GooseGlobalConfig


def create_boto3_session(gconf: GooseGlobalConfig):
    """Build a boto3 session from the AWS section of the global config.

    Args:
        gconf: Global configuration; its ``aws`` attribute supplies the access
            token, the access secret (unwrapped via ``get_secret_value()``)
            and the default region.

    Returns:
        A ``boto3.Session`` created with those credentials and that region.
    """
    aws_conf = gconf.aws
    session_kwargs = {
        'aws_access_key_id': aws_conf.access_token,
        'aws_secret_access_key': aws_conf.access_secret.get_secret_value(),
        'region_name': aws_conf.default_region,
    }
    return boto3.Session(**session_kwargs)


def get_key_if_exists(aws_keys: Optional[List], key_name: str) -> Optional[str]:
    """Look up a tag value in an AWS-style tag list.

    Args:
        aws_keys: List of ``{'Key': ..., 'Value': ...}`` dicts, or ``None``
            (boto3 reports ``None`` for resources that have no tags).
        key_name: Tag key to search for (exact, case-sensitive match).

    Returns:
        The value of the first tag whose key equals ``key_name``, or ``None``
        when there is no such tag or no tags at all.
    """
    # `or []` makes a missing tag list behave like an empty one.
    for tag in aws_keys or []:
        if tag['Key'] == key_name:
            return tag['Value']
    return None


def list_ec2_and_eips(session: boto3.Session):
    """List every EC2 instance together with its associated Elastic IP.

    Args:
        session: boto3 session used to create the EC2 resource and client.

    Returns:
        A list with one dict per instance, in listing order, holding the keys
        ``instance_id``, ``associated_ip``, ``running`` (the instance state
        name), ``project``, ``name``, ``hostname`` (values of the ``Project``,
        ``Name`` and ``Hostname`` tags, ``None`` when absent) and
        ``instance_type``. ``associated_ip`` stays ``None`` when no address is
        found or the address lookup fails.
    """
    ec2_resource = session.resource('ec2')
    ec2_client = session.client('ec2')

    hosts = []

    # An empty filter list selects instances in every state, not only running ones.
    for instance in ec2_resource.instances.filter(Filters=[]):
        # Untagged instances expose `tags` as None; treat that as "no tags"
        # so one untagged machine does not abort the whole listing.
        tags = instance.tags or []

        host = {
            'instance_id': instance.id,
            'associated_ip': None,
            'running': instance.state['Name'],
            'project': get_key_if_exists(tags, 'Project'),
            'name': get_key_if_exists(tags, 'Name'),
            'hostname': get_key_if_exists(tags, 'Hostname'),
            'instance_type': instance.instance_type,
        }
        hosts.append(host)

        # Best effort: a failed address lookup is reported but does not stop
        # the listing; the instance simply keeps associated_ip=None.
        try:
            addresses = ec2_client.describe_addresses(Filters=[
                {'Name': 'instance-id', 'Values': [instance.id]}
            ])
            for address in (addresses or {}).get('Addresses', []):
                if 'PublicIp' in address:
                    # With several addresses the last one listed wins.
                    host['associated_ip'] = address['PublicIp']
        except Exception as e:
            console.print(f"Error fetching Elastic IP for instance {instance.id}: {e}")

    return hosts


def get_machine_with_name(machine_name: str, session: boto3.Session):
    """Find the first EC2 instance whose ``Name`` tag matches ``machine_name``.

    The comparison is case-insensitive.

    Args:
        machine_name: Name to look for.
        session: boto3 session passed on to ``list_ec2_and_eips``.

    Returns:
        The host dict produced by ``list_ec2_and_eips`` for the first match,
        or ``None`` when no instance carries that name.
    """
    wanted = machine_name.lower()

    for host in list_ec2_and_eips(session):
        host_name = host['name']
        # Instances without a Name tag have name=None; they can never match
        # and must not be lower-cased.
        if host_name is not None and host_name.lower() == wanted:
            return host

    return None


def set_machine_power_state(machine_id: str, power_state: Literal['start', 'stop'], session: boto3.Session):
    """Start or stop an EC2 instance, depending on its current state.

    Args:
        machine_id: ID of the EC2 instance.
        power_state: Requested operation, ``'start'`` or ``'stop'``.
        session: boto3 session used to create the EC2 resource and client.

    Returns:
        ``True`` when the request was issued or the instance was already in
        the requested state; ``False`` when the instance is neither
        ``running`` nor ``stopped``, in which case nothing is done.

    Raises:
        ValueError: If ``power_state`` is not ``'start'`` or ``'stop'``.
    """
    # Validate first so an invalid request never touches AWS.
    if power_state not in ('start', 'stop'):
        raise ValueError(f'"{power_state}" is not a valid power operation on machine id:{machine_id}')

    ec2_resource = session.resource('ec2')
    ec2_client = session.client('ec2')

    current_state = ec2_resource.Instance(machine_id).state['Name']

    if current_state == 'running':
        if power_state == 'start':
            console.print('[yellow]instance is already running!')
        else:
            ec2_client.stop_instances(InstanceIds=[machine_id])
        return True

    if current_state == 'stopped':
        if power_state == 'stop':
            console.print('[yellow]instance is already stopped!')
        else:
            ec2_client.start_instances(InstanceIds=[machine_id])
        return True

    # Any other state: no request is sent, and that is reported explicitly
    # instead of falling off the end of the function.
    return False


def get_machine_power_state(machine_id: str, session: boto3.Session):
    """Return the current state name of an EC2 instance.

    Args:
        machine_id: ID of the EC2 instance.
        session: boto3 session used to create the EC2 resource.

    Returns:
        The ``Name`` entry of the instance's ``state`` mapping.
    """
    instance = session.resource('ec2').Instance(machine_id)
    state_name = instance.state['Name']
    return state_name
    

def list_all_hosted_zones(session: boto3.Session):
    """Collect every Route 53 hosted zone, following pagination markers.

    Args:
        session: boto3 session used to create the Route 53 client.

    Returns:
        A list of dicts with the keys ``Id``, ``Name`` (trailing dots
        stripped) and ``ResourceRecordSetCount``, one per hosted zone.
    """
    client = session.client('route53')
    zones = []
    next_marker = None

    while True:
        # Only send Marker once a previous page has handed one back.
        request_args = {'Marker': next_marker} if next_marker else {}
        page = client.list_hosted_zones(**request_args)

        zones.extend(
            {
                'Id': zone['Id'],
                'Name': zone['Name'].rstrip('.'),
                'ResourceRecordSetCount': zone['ResourceRecordSetCount'],
            }
            for zone in page['HostedZones']
        )

        if not page.get('IsTruncated', False):
            return zones
        next_marker = page['NextMarker']


def get_zone_id_with_name(name: str, session: boto3.Session) -> Optional[str]:
    """Return the ID of the first hosted zone whose name matches ``name``.

    The comparison ignores letter case and a trailing dot, so ``example.com``,
    ``Example.com`` and ``example.com.`` all refer to the same zone.

    Args:
        name: Zone name to look for.
        session: boto3 session passed on to ``list_all_hosted_zones``.

    Returns:
        The ``Id`` of the first matching zone, or ``None`` when none matches.
    """
    # list_all_hosted_zones already strips the trailing dot from zone names,
    # so the query is normalised the same way before comparing.
    wanted = name.rstrip('.').lower()

    for zone in list_all_hosted_zones(session):
        if zone['Name'].lower() == wanted:
            return zone['Id']

    return None



def list_apex_records(hosted_zone_id: str, session: boto3.Session):
    """Collect the A and AAAA record sets of a hosted zone.

    Note: despite the name, no check against the zone's apex name is made;
    every record set of type ``A`` or ``AAAA`` in the zone is returned.

    Args:
        hosted_zone_id: ID of the Route 53 hosted zone to read.
        session: boto3 session used to create the Route 53 client.

    Returns:
        A list of the matching record-set dicts, in the order they are
        returned across all response pages.
    """
    route53_client = session.client('route53')

    address_records = []
    request_args = {'HostedZoneId': hosted_zone_id}

    while True:
        page = route53_client.list_resource_record_sets(**request_args)

        address_records.extend(
            record_set
            for record_set in page['ResourceRecordSets']
            if record_set['Type'] in ('A', 'AAAA')
        )

        if not page['IsTruncated']:
            break

        # Resume from the position reported by the previous page.
        request_args = {
            'HostedZoneId': hosted_zone_id,
            'StartRecordName': page['NextRecordName'],
            'StartRecordType': page['NextRecordType'],
        }
        # The identifier is not always present in the response; sending
        # None for it is rejected by the client's parameter validation,
        # so the argument is only included when there is a value.
        next_identifier = page.get('NextRecordIdentifier')
        if next_identifier is not None:
            request_args['StartRecordIdentifier'] = next_identifier

    return address_records


def tag_ec2_instance(instance_id: str, tag_key: str, tag_value: str, session: boto3.Session):
    """Attach a single key/value tag to an EC2 instance.

    Args:
        instance_id: ID of the EC2 instance to tag.
        tag_key: Tag key.
        tag_value: Tag value.
        session: boto3 session used to create the EC2 resource.
    """
    new_tag = {'Key': tag_key, 'Value': tag_value}
    target = session.resource('ec2').Instance(instance_id)
    target.create_tags(Tags=[new_tag])


def untag_ec2_instance(instance_id: str, tag_key: str, tag_value: str, session: boto3.Session):
    """Request deletion of a single key/value tag from an EC2 instance.

    Both the key and the value are sent in the delete request.

    Args:
        instance_id: ID of the EC2 instance to untag.
        tag_key: Key of the tag to delete.
        tag_value: Value sent alongside the key.
        session: boto3 session used to create the EC2 resource.
    """
    doomed_tag = {'Key': tag_key, 'Value': tag_value}
    target = session.resource('ec2').Instance(instance_id)
    target.delete_tags(Tags=[doomed_tag])


def create_elastic_ip(instance_id: str, instance_name: str, session: boto3.Session):
    """Allocate a VPC Elastic IP, associate it with an instance, and name it.

    Args:
        instance_id: ID of the EC2 instance to associate the address with.
        instance_name: Used to build the address's ``Name`` tag,
            ``<instance_name>-eip``.
        session: boto3 session used to create the EC2 client.

    Returns:
        The ``allocate_address`` response for the new address.

    Raises:
        Exception: Whatever the EC2 client raises. When the association step
            fails, the freshly allocated address is released before the
            error is re-raised.
    """
    ec2_client = session.client('ec2')

    # Allocate a new Elastic IP
    allocation = ec2_client.allocate_address(Domain='vpc')
    console.print('Elastic IP allocated:', allocation['PublicIp'])
    allocation_id = allocation['AllocationId']

    # Associate the allocated Elastic IP with the specified EC2 instance
    try:
        association = ec2_client.associate_address(
            AllocationId=allocation_id,
            InstanceId=instance_id,
        )
    except Exception:
        # Without this the new address would stay allocated but unattached,
        # and the caller has no handle to clean it up.
        ec2_client.release_address(AllocationId=allocation_id)
        raise

    console.print('Elastic IP associated with instance:', association['AssociationId'])

    # Give the address a recognisable Name tag derived from the instance name.
    ec2_client.create_tags(
        Resources=[allocation_id],
        Tags=[
            {
                'Key': 'Name',
                'Value': f'{instance_name}-eip',
            },
        ],
    )

    return allocation


def apply_dns_changes(hosted_zone_id: str, changes: List, comment: str, session: boto3.Session):
    """Submit a batch of record changes to a Route 53 hosted zone.

    Args:
        hosted_zone_id: ID of the hosted zone to modify.
        changes: List of change dicts, sent as the batch's ``Changes``.
        comment: Free-text comment, sent as the batch's ``Comment``.
        session: boto3 session used to create the Route 53 client.

    Returns:
        The ``change_resource_record_sets`` response.
    """
    # Use the caller's session so the configured credentials and region
    # apply, rather than boto3's module-level default client.
    client = session.client('route53')

    return client.change_resource_record_sets(
        HostedZoneId=hosted_zone_id,
        ChangeBatch={
            'Comment': comment,
            'Changes': changes,
        },
    )

