import { NgTemplateOutlet } from '@angular/common';
import {
	ChangeDetectionStrategy,
	Component,
	computed,
	contentChild,
	DestroyRef,
	Directive,
	effect,
	inject,
	input,
	signal,
	TemplateRef,
	untracked,
} from '@angular/core';
import RAPIER, { ColliderHandle, EventQueue, PhysicsHooks, Rotation, Vector, World } from '@dimforge/rapier3d-compat';
import { injectStore, pick, vector3 } from 'angular-three';
import { mergeInputs } from 'ngxtension/inject-inputs';
import * as THREE from 'three';
import { NgtrDebug } from './debug';
import { NgtrFrameStepper } from './frame-stepper';
import { _matrix4, _position, _rotation, _scale } from './shared';
import {
	NgtrColliderStateMap,
	NgtrCollisionPayload,
	NgtrCollisionSource,
	NgtrEventMap,
	NgtrFilterContactPairCallback,
	NgtrFilterIntersectionPairCallback,
	NgtrPhysicsOptions,
	NgtrRigidBodyStateMap,
	NgtrWorldStepCallback,
} from './types';
import { createSingletonProxy, rapierQuaternionToQuaternion } from './utils';

/**
 * Baseline physics configuration. `NgtrPhysics` merges user-supplied `options`
 * on top of these values via `mergeInputs`.
 */
const defaultOptions: NgtrPhysicsOptions = {
	// World gravity vector (x, y, z)
	gravity: [0, -9.81, 0],

	// Rapier integration parameters, pushed onto the world by NgtrPhysics whenever options change
	allowedLinearError: 1e-3,
	numSolverIterations: 4,
	numInternalPgsIterations: 1,
	predictionDistance: 2e-3,
	minIslandSize: 128,
	maxCcdSubsteps: 1,
	contactNaturalFrequency: 30,
	lengthUnit: 1,

	// Collider type used for automatic collider generation
	colliders: 'cuboid',

	// Frame-stepper mode and simulation timing ('vary' is also accepted for timeStep)
	updateLoop: 'follow',
	interpolate: true,
	paused: false,
	timeStep: 1 / 60,

	// Whether <ngtr-debug /> is rendered
	debug: false,
};

/**
 * Marks an `<ng-template>` inside `<ngtr-physics>` as the content to render when the
 * Rapier WASM module cannot be loaded. The template is given the failure message
 * through its `error` context variable.
 *
 * @example
 * ```html
 * <ngtr-physics>
 *   <ng-template>
 *     <!-- Physics scene content -->
 *   </ng-template>
 *   <ng-template rapierFallback let-error="error">
 *     <p>Failed to load physics: {{ error }}</p>
 *   </ng-template>
 * </ngtr-physics>
 * ```
 */
@Directive({ selector: 'ng-template[rapierFallback]' })
export class NgtrPhysicsFallback {
	/**
	 * Compile-time template context guard: tells Angular's type checker that the
	 * template context exposes `error` as a `string`. It performs no runtime check.
	 */
	static ngTemplateContextGuard(_directive: NgtrPhysicsFallback, context: unknown): context is { error: string } {
		return true;
	}
}

/**
 * Main physics component that creates and manages a Rapier physics world.
 * Wrap your 3D scene content in this component to enable physics simulation.
 *
 * The component lazily loads the Rapier WASM module and provides the physics
 * context to all child components. If loading fails and a `rapierFallback`
 * template is provided, that template is rendered with the error message instead.
 *
 * @example
 * ```html
 * <ngtr-physics [options]="{ gravity: [0, -9.81, 0], debug: true }">
 *   <ng-template>
 *     <ngt-object3D rigidBody>
 *       <ngt-mesh>
 *         <ngt-box-geometry />
 *         <ngt-mesh-standard-material />
 *       </ngt-mesh>
 *     </ngt-object3D>
 *   </ng-template>
 * </ngtr-physics>
 * ```
 */
@Component({
	selector: 'ngtr-physics',
	template: `
		@let _rapierError = rapierError();
		@let _fallbackContent = fallbackContent();

		@if (rapierConstruct()) {
			@if (debug()) {
				<ngtr-debug />
			}

			<ngtr-frame-stepper
				[ready]="ready()"
				[stepFn]="stepFn"
				[type]="updateLoop()"
				[updatePriority]="updatePriority()"
			/>

			<ng-container [ngTemplateOutlet]="content()" />
		} @else if (_rapierError && _fallbackContent) {
			<ng-container [ngTemplateOutlet]="_fallbackContent" [ngTemplateOutletContext]="{ error: _rapierError }" />
		}
	`,
	changeDetection: ChangeDetectionStrategy.OnPush,
	imports: [NgtrDebug, NgtrFrameStepper, NgTemplateOutlet],
})
export class NgtrPhysics {
	/** Physics configuration options */
	options = input(defaultOptions, { transform: mergeInputs(defaultOptions) });

	protected content = contentChild.required(TemplateRef);
	protected fallbackContent = contentChild(NgtrPhysicsFallback, { read: TemplateRef });

	protected updatePriority = pick(this.options, 'updatePriority');
	protected updateLoop = pick(this.options, 'updateLoop');

	private numSolverIterations = pick(this.options, 'numSolverIterations');
	private numInternalPgsIterations = pick(this.options, 'numInternalPgsIterations');
	private allowedLinearError = pick(this.options, 'allowedLinearError');
	private minIslandSize = pick(this.options, 'minIslandSize');
	private maxCcdSubsteps = pick(this.options, 'maxCcdSubsteps');
	private predictionDistance = pick(this.options, 'predictionDistance');
	private contactNaturalFrequency = pick(this.options, 'contactNaturalFrequency');
	private lengthUnit = pick(this.options, 'lengthUnit');
	private timeStep = pick(this.options, 'timeStep');
	private interpolate = pick(this.options, 'interpolate');

	/** Whether the physics simulation is paused */
	paused = pick(this.options, 'paused');
	protected debug = pick(this.options, 'debug');
	/** The default collider type for automatic collider generation */
	colliders = pick(this.options, 'colliders');

	private vGravity = vector3(this.options, 'gravity');

	private store = injectStore();

	protected rapierConstruct = signal<typeof RAPIER | null>(null);
	protected rapierError = signal<string | null>(null);
	/** The loaded Rapier module, null if not yet loaded */
	rapier = this.rapierConstruct.asReadonly();

	protected ready = computed(() => !!this.rapier());

	/**
	 * Step callback handed to the frame stepper. It is a stable field rather than
	 * `step.bind(this)` in the template, which would produce a new function (and thus
	 * a changed input for the frame stepper) on every change-detection pass.
	 */
	protected stepFn = (delta: number) => this.step(delta);

	/** Singleton proxy to the Rapier physics world */
	worldSingleton = computed(() => {
		const rapier = this.rapier();
		if (!rapier) return null;
		// Gravity is read untracked so later gravity changes update the existing world
		// (see updateWorldEffect) instead of re-creating it.
		return createSingletonProxy<World>(() => new rapier.World(untracked(this.vGravity)));
	});

	/** Map of rigid body states indexed by handle */
	rigidBodyStates: NgtrRigidBodyStateMap = new Map();
	/** Map of collider states indexed by handle */
	colliderStates: NgtrColliderStateMap = new Map();
	/** Map of rigid body event handlers indexed by handle */
	rigidBodyEvents: NgtrEventMap = new Map();
	/** Map of collider event handlers indexed by handle */
	colliderEvents: NgtrEventMap = new Map();
	/** Callbacks to run before each physics step */
	beforeStepCallbacks = new Set<NgtrWorldStepCallback>();
	/** Callbacks to run after each physics step */
	afterStepCallbacks = new Set<NgtrWorldStepCallback>();
	/** Callbacks to filter contact pairs */
	filterContactPairCallbacks = new Set<NgtrFilterContactPairCallback>();
	/** Callbacks to filter intersection pairs */
	filterIntersectionPairCallbacks = new Set<NgtrFilterIntersectionPairCallback>();

	/**
	 * Physics hooks passed to `world.step` only when at least one filter callback is registered.
	 * - `filterContactPair`: the first callback returning a non-null result wins; otherwise `null`.
	 * - `filterIntersectionPair`: `false` if any callback returns `false`; otherwise `true`.
	 */
	private hooks: PhysicsHooks = {
		filterContactPair: (...args) => {
			for (const callback of this.filterContactPairCallbacks) {
				const result = callback(...args);
				if (result !== null) return result;
			}
			return null;
		},
		filterIntersectionPair: (...args) => {
			for (const callback of this.filterIntersectionPairCallbacks) {
				const result = callback(...args);
				if (result === false) return false;
			}
			return true;
		},
	};

	/** Event queue created once Rapier has loaded; it is drained manually after every frame. */
	private eventQueue = computed(() => {
		const rapier = this.rapier();
		if (!rapier) return null;
		return new EventQueue(false);
	});

	/**
	 * Fixed-timestep bookkeeping: unconsumed simulated time, plus per-body transforms captured
	 * before the most recent step (used as the interpolation start point).
	 */
	private steppingState: {
		accumulator: number;
		previousState: Record<number, { position: Vector; rotation: Rotation }>;
	} = { accumulator: 0, previousState: {} };

	constructor() {
		// Load and initialize the Rapier WASM module; the scene template renders once this resolves.
		import('@dimforge/rapier3d-compat')
			.then((rapier) => rapier.init().then(() => rapier))
			.then(this.rapierConstruct.set.bind(this.rapierConstruct))
			.catch((err: unknown) => {
				console.error(`[NGT] Failed to load rapier3d-compat`, err);
				// `err` can be any thrown value (even null/undefined), so never dereference it blindly.
				this.rapierError.set(err instanceof Error ? err.message : String(err));
			});

		effect(() => {
			this.updateWorldEffect();
		});

		inject(DestroyRef).onDestroy(() => {
			const world = this.worldSingleton();
			if (world) {
				world.proxy.free();
				world.reset();
			}
			// The event queue is also backed by WASM memory; release it together with the world.
			this.eventQueue()?.free();
		});
	}

	/**
	 * Steps the physics simulation forward by the given delta time.
	 * This is called automatically by the frame stepper, but can be called manually
	 * if you need custom control over the simulation timing. Does nothing while paused.
	 *
	 * @param delta - Time in seconds since the last step
	 */
	step(delta: number) {
		if (!this.paused()) {
			this.internalStep(delta);
		}
	}

	/** Pushes gravity and integration parameters from `options` onto the (lazily created) world. */
	private updateWorldEffect() {
		const world = this.worldSingleton();
		if (!world) return;

		world.proxy.gravity = this.vGravity();
		world.proxy.integrationParameters.numSolverIterations = this.numSolverIterations();
		world.proxy.integrationParameters.numInternalPgsIterations = this.numInternalPgsIterations();
		world.proxy.integrationParameters.normalizedAllowedLinearError = this.allowedLinearError();
		world.proxy.integrationParameters.minIslandSize = this.minIslandSize();
		world.proxy.integrationParameters.maxCcdSubsteps = this.maxCcdSubsteps();
		world.proxy.integrationParameters.normalizedPredictionDistance = this.predictionDistance();
		world.proxy.integrationParameters.contact_natural_frequency = this.contactNaturalFrequency();
		world.proxy.lengthUnit = this.lengthUnit();
	}

	/**
	 * Advances the world (variable or fixed timestep), syncs Three.js objects to their rigid
	 * bodies, dispatches queued collision/contact-force events and invalidates the store while
	 * any rigid body is active. Only called from `step`, i.e. never while paused.
	 */
	private internalStep(delta: number) {
		const worldSingleton = this.worldSingleton();
		if (!worldSingleton) return;

		const eventQueue = this.eventQueue();
		if (!eventQueue) return;

		const world = worldSingleton.proxy;
		const timeStep = this.timeStep();
		const interpolate = this.interpolate();

		// Decide once per frame whether the timestep is variable, so the string check isn't repeated.
		const timeStepVariable = timeStep === 'vary';

		/**
		 * Fixed timeStep simulation progression.
		 * @see https://gafferongames.com/post/fix_your_timestep/
		 */
		// Never simulate more than 0.5s in one frame (e.g. after a long pause between frames).
		const clampedDelta = THREE.MathUtils.clamp(delta, 0, 0.5);

		const stepWorld = (innerDelta: number) => {
			this.beforeStepCallbacks.forEach((callback) => {
				callback(world);
			});

			world.timestep = innerDelta;
			// Only pay for the hook round-trip when someone actually registered a filter.
			const hasHooks = this.filterContactPairCallbacks.size > 0 || this.filterIntersectionPairCallbacks.size > 0;
			world.step(eventQueue, hasHooks ? this.hooks : undefined);

			this.afterStepCallbacks.forEach((callback) => {
				callback(world);
			});
		};

		if (timeStepVariable) {
			stepWorld(clampedDelta);
		} else {
			// Accumulate real time and consume it in fixed-size steps.
			this.steppingState.accumulator += clampedDelta;

			while (this.steppingState.accumulator >= timeStep) {
				// Snapshot transforms before each step. After the loop this holds the state preceding
				// the most recent step, which is needed for accurate interpolation when the world
				// steps more than once per frame.
				if (interpolate) {
					this.steppingState.previousState = {};
					world.forEachRigidBody((body) => {
						this.steppingState.previousState[body.handle] = {
							position: body.translation(),
							rotation: body.rotation(),
						};
					});
				}

				stepWorld(timeStep);
				this.steppingState.accumulator -= timeStep;
			}
		}

		// Fraction of a fixed step left in the accumulator; 1 means "snap to the current state".
		const interpolationAlpha = timeStepVariable || !interpolate ? 1 : this.steppingState.accumulator / timeStep;

		this.syncRigidBodyObjects(world, interpolationAlpha);
		this.dispatchCollisionEvents(world, eventQueue);
		this.dispatchContactForceEvents(eventQueue);

		// NOTE(review): presumably requests another rendered frame while bodies are still moving.
		world.forEachActiveRigidBody(() => {
			this.store.snapshot.invalidate();
		});
	}

	/**
	 * Fires onSleep/onWake transitions and copies simulated rigid-body transforms onto their
	 * Three.js objects (interpolating meshes from the previous tick by `interpolationAlpha`).
	 */
	private syncRigidBodyObjects(world: World, interpolationAlpha: number) {
		this.rigidBodyStates.forEach((state, handle) => {
			const rigidBody = world.getRigidBody(handle);
			// The body may no longer exist in the world; check before calling any method on it.
			if (!rigidBody) return;

			const events = this.rigidBodyEvents.get(handle);
			if (events?.onSleep || events?.onWake) {
				if (rigidBody.isSleeping() && !state.isSleeping) events?.onSleep?.();
				if (!rigidBody.isSleeping() && state.isSleeping) events?.onWake?.();
				state.isSleeping = rigidBody.isSleeping();
			}

			// Sleeping non-instanced bodies don't need an update; states without setMatrix can't receive one.
			if ((rigidBody.isSleeping() && !('isInstancedMesh' in state.object)) || !state.setMatrix) {
				return;
			}

			// Current simulated transform (compose only reads x/y/z, hence the Vector3 cast)
			const translation = rigidBody.translation() as THREE.Vector3;
			const rotation = rigidBody.rotation();

			const previousState = this.steppingState.previousState[handle];
			if (previousState) {
				// Bring the previous simulated transform into the space given by invertedWorldMatrix
				_matrix4
					.compose(
						previousState.position as THREE.Vector3,
						rapierQuaternionToQuaternion(previousState.rotation),
						state.scale,
					)
					.premultiply(state.invertedWorldMatrix)
					.decompose(_position, _rotation, _scale);

				// Start this frame's interpolation from the previous tick's transform
				if (state.meshType === 'mesh') {
					state.object.position.copy(_position);
					state.object.quaternion.copy(_rotation);
				}
			}

			// Same transformation for the current simulated transform
			_matrix4
				.compose(translation, rapierQuaternionToQuaternion(rotation), state.scale)
				.premultiply(state.invertedWorldMatrix)
				.decompose(_position, _rotation, _scale);

			if (state.meshType === 'instancedMesh') {
				state.setMatrix(_matrix4);
			} else {
				state.object.position.lerp(_position, interpolationAlpha);
				state.object.quaternion.slerp(_rotation, interpolationAlpha);
			}
		});
	}

	/**
	 * Drains collision start/stop events and forwards them as collision and sensor-intersection
	 * enter/exit events to both the rigid-body and collider handlers of each participant.
	 */
	private dispatchCollisionEvents(world: World, eventQueue: EventQueue) {
		eventQueue.drainCollisionEvents((handle1, handle2, started) => {
			const source1 = this.getSourceFromColliderHandle(handle1);
			const source2 = this.getSourceFromColliderHandle(handle2);

			// Skip pairs whose colliders can no longer be resolved
			if (!source1?.collider.object || !source2?.collider.object) {
				return;
			}

			const collisionPayload1 = this.getCollisionPayloadFromSource(source1, source2);
			const collisionPayload2 = this.getCollisionPayloadFromSource(source2, source1);

			// Collision events
			if (started) {
				world.contactPair(source1.collider.object, source2.collider.object, (manifold, flipped) => {
					source1.rigidBody.events?.onCollisionEnter?.({ ...collisionPayload1, manifold, flipped });
					source2.rigidBody.events?.onCollisionEnter?.({ ...collisionPayload2, manifold, flipped });
					source1.collider.events?.onCollisionEnter?.({ ...collisionPayload1, manifold, flipped });
					source2.collider.events?.onCollisionEnter?.({ ...collisionPayload2, manifold, flipped });
				});
			} else {
				source1.rigidBody.events?.onCollisionExit?.(collisionPayload1);
				source2.rigidBody.events?.onCollisionExit?.(collisionPayload2);
				source1.collider.events?.onCollisionExit?.(collisionPayload1);
				source2.collider.events?.onCollisionExit?.(collisionPayload2);
			}

			// Sensor intersections
			if (started) {
				if (world.intersectionPair(source1.collider.object, source2.collider.object)) {
					source1.rigidBody.events?.onIntersectionEnter?.(collisionPayload1);
					source2.rigidBody.events?.onIntersectionEnter?.(collisionPayload2);
					source1.collider.events?.onIntersectionEnter?.(collisionPayload1);
					source2.collider.events?.onIntersectionEnter?.(collisionPayload2);
				}
			} else {
				source1.rigidBody.events?.onIntersectionExit?.(collisionPayload1);
				source2.rigidBody.events?.onIntersectionExit?.(collisionPayload2);
				source1.collider.events?.onIntersectionExit?.(collisionPayload1);
				source2.collider.events?.onIntersectionExit?.(collisionPayload2);
			}
		});
	}

	/** Drains contact-force events and forwards them to rigid-body and collider `onContactForce` handlers. */
	private dispatchContactForceEvents(eventQueue: EventQueue) {
		eventQueue.drainContactForceEvents((event) => {
			const source1 = this.getSourceFromColliderHandle(event.collider1());
			const source2 = this.getSourceFromColliderHandle(event.collider2());

			// Skip pairs whose colliders can no longer be resolved
			if (!source1?.collider.object || !source2?.collider.object) {
				return;
			}

			const collisionPayload1 = this.getCollisionPayloadFromSource(source1, source2);
			const collisionPayload2 = this.getCollisionPayloadFromSource(source2, source1);

			// Each handler receives its own freshly-read force values.
			const withForces = (payload: NgtrCollisionPayload) => ({
				...payload,
				totalForce: event.totalForce(),
				totalForceMagnitude: event.totalForceMagnitude(),
				maxForceDirection: event.maxForceDirection(),
				maxForceMagnitude: event.maxForceMagnitude(),
			});

			source1.rigidBody.events?.onContactForce?.(withForces(collisionPayload1));
			source2.rigidBody.events?.onContactForce?.(withForces(collisionPayload2));
			source1.collider.events?.onContactForce?.(withForces(collisionPayload1));
			source2.collider.events?.onContactForce?.(withForces(collisionPayload2));
		});
	}

	/**
	 * Resolves a collider handle to its collider, parent rigid body, and their registered
	 * events/states. Returns undefined before the world exists; `collider.object` may be
	 * undefined if the collider is no longer in the world.
	 */
	private getSourceFromColliderHandle(handle: ColliderHandle) {
		const world = this.worldSingleton();
		if (!world) return;

		const collider = world.proxy.getCollider(handle);
		const colEvents = this.colliderEvents.get(handle);
		const colliderState = this.colliderStates.get(handle);

		// Guard: the collider may have been removed between the step and event draining.
		const rigidBodyHandle = collider?.parent()?.handle;
		const rigidBody = rigidBodyHandle !== undefined ? world.proxy.getRigidBody(rigidBodyHandle) : undefined;
		const rigidBodyEvents =
			rigidBody && rigidBodyHandle !== undefined ? this.rigidBodyEvents.get(rigidBodyHandle) : undefined;
		const rigidBodyState = rigidBodyHandle !== undefined ? this.rigidBodyStates.get(rigidBodyHandle) : undefined;

		return {
			collider: { object: collider, events: colEvents, state: colliderState },
			rigidBody: { object: rigidBody, events: rigidBodyEvents, state: rigidBodyState },
		} as NgtrCollisionSource;
	}

	/** Builds the payload delivered to `target`'s handlers, describing both `target` and `other`. */
	private getCollisionPayloadFromSource(
		target: NgtrCollisionSource,
		other: NgtrCollisionSource,
	): NgtrCollisionPayload {
		return {
			target: {
				rigidBody: target.rigidBody.object,
				collider: target.collider.object,
				colliderObject: target.collider.state?.object,
				rigidBodyObject: target.rigidBody.state?.object,
			},
			other: {
				rigidBody: other.rigidBody.object,
				collider: other.collider.object,
				colliderObject: other.collider.state?.object,
				rigidBodyObject: other.rigidBody.state?.object,
			},
		};
	}
}
