123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- using System.Collections.Generic;
- using System.Linq;
- using System.Threading;
- using System.Threading.Tasks;
- using LLM.Editor.Client;
- using LLM.Editor.Data;
- using UnityEditor;
- using UnityEngine;
namespace LLM.Editor.Core
{
    /// <summary>
    /// Retrieves previously logged interactions whose prompts are similar to a new prompt,
    /// ranked by cosine similarity between prompt embeddings.
    /// </summary>
    // NOTE(review): [InitializeOnLoad] makes the editor run the type's static constructor on load;
    // this class declares no static constructor, so the attribute currently has no effect.
    [InitializeOnLoad]
    public static class MemoryRetriever
    {
        /// <summary>
        /// Embeds <paramref name="newUserPrompt"/> and returns up to <paramref name="topK"/> logged
        /// interaction records whose stored prompt embeddings are most similar to it.
        /// </summary>
        /// <param name="newUserPrompt">Prompt to find related memories for. Null or whitespace yields no results.</param>
        /// <param name="cancellationToken">
        /// Forwarded to the embedding request. A cancelled call returns an empty list without logging an error.
        /// </param>
        /// <param name="topK">Maximum number of records to return. Values &lt;= 0 yield an empty list without calling the API.</param>
        /// <returns>
        /// Records ordered by descending cosine similarity. Never null. The list is empty when:
        /// no API client is available, the embedding could not be generated, the call was cancelled,
        /// or no stored record has an embedding of matching dimension.
        /// </returns>
        public static async Task<List<InteractionRecord>> GetRelevantMemories(string newUserPrompt, CancellationToken cancellationToken, int topK = 3)
        {
            // Nothing could be returned, so skip the embedding request entirely.
            if (topK <= 0 || string.IsNullOrWhiteSpace(newUserPrompt))
            {
                return new List<InteractionRecord>();
            }

            var apiClient = ApiClientFactory.GetClient();
            if (apiClient == null)
            {
                Debug.LogError("[MemoryRetriever] Could not get API client to generate embedding.");
                return new List<InteractionRecord>();
            }

            var newUserEmbedding = await EmbeddingHelper.GetEmbedding(newUserPrompt, apiClient.GetAuthToken, cancellationToken);
            if (newUserEmbedding == null)
            {
                // This can happen if the request was cancelled, which is not an error.
                if (!cancellationToken.IsCancellationRequested)
                {
                    Debug.LogError("[MemoryRetriever] Could not generate embedding for the new prompt.");
                }

                return new List<InteractionRecord>();
            }

            // The caller may have cancelled while the request was completing; skip the scoring work in that case.
            if (cancellationToken.IsCancellationRequested)
            {
                return new List<InteractionRecord>();
            }

            // Viewed through IReadOnlyList so the dimension check below works whatever concrete
            // collection type the embedding helper returns.
            IReadOnlyList<float> queryVector = newUserEmbedding;

            var allRecords = MemoryLogger.GetRecords();
            if (allRecords == null)
            {
                return new List<InteractionRecord>();
            }

            var scoredRecords = new List<(InteractionRecord record, float score)>();
            var mismatchedCount = 0;
            foreach (var record in allRecords)
            {
                if (record.PromptEmbedding == null || record.PromptEmbedding.Length == 0)
                {
                    continue;
                }

                // An embedding of a different dimension cannot be compared with the query.
                // Scoring it 0 would let it outrank genuinely dissimilar records (cosine similarity goes down to -1)
                // or fill top-K slots, so exclude it and report the total once below.
                if (record.PromptEmbedding.Length != queryVector.Count)
                {
                    mismatchedCount++;
                    continue;
                }

                var score = CosineSimilarity(queryVector, record.PromptEmbedding);
                scoredRecords.Add((record, score));
            }

            if (mismatchedCount > 0)
            {
                Debug.LogWarning($"[MemoryRetriever] Skipped {mismatchedCount} record(s) whose embedding dimension does not match the current prompt embedding ({queryVector.Count}).");
            }

            return scoredRecords
                .OrderByDescending(t => t.score)
                .Take(topK)
                .Select(t => t.record)
                .ToList();
        }

        /// <summary>
        /// Computes the cosine similarity of two equal-length vectors (mathematically within [-1, 1]).
        /// </summary>
        /// <param name="v1">First vector.</param>
        /// <param name="v2">Second vector; must have the same length as <paramref name="v1"/>.</param>
        /// <returns>The similarity, or 0 when the lengths differ or either vector has zero magnitude.</returns>
        private static float CosineSimilarity(IReadOnlyList<float> v1, IReadOnlyList<float> v2)
        {
            if (v1.Count != v2.Count)
            {
                Debug.LogError("[MemoryRetriever] Vector dimensions do not match for cosine similarity.");
                return 0;
            }

            // Accumulate in double to limit rounding error when summing many float products.
            // Plain multiplication replaces Mathf.Pow(x, 2), which is needlessly general for squaring.
            double dotProduct = 0.0;
            double norm1 = 0.0;
            double norm2 = 0.0;
            for (var i = 0; i < v1.Count; i++)
            {
                double a = v1[i];
                double b = v2[i];
                dotProduct += a * b;
                norm1 += a * a;
                norm2 += b * b;
            }

            if (norm1 == 0 || norm2 == 0)
            {
                return 0;
            }

            return (float)(dotProduct / (System.Math.Sqrt(norm1) * System.Math.Sqrt(norm2)));
        }
    }
}
|