ソースを参照

ajout de memoization

pull/5/head
Figg 6ヶ月前
コミット
21a75b1e5b

+ 3
- 3
million/analyze/find_holes.py ファイルの表示

@@ -4,7 +4,7 @@ from typing import List
4 4
 from million.model.hole import Hole
5 5
 from million.model.message import Message
6 6
 from million.model.sequence import Sequence
7
-import million.analyze.message_evaluation as msg_ev
7
+import million.analyze.message_evaluation as msg_val
8 8
 
9 9
 
10 10
 def compute_sequences(messages: List[Message], accepted_max: int = 1_000_000) -> List[Sequence]:
@@ -12,9 +12,9 @@ def compute_sequences(messages: List[Message], accepted_max: int = 1_000_000) ->
12 12
     current = Sequence(start_message=messages[0])
13 13
     
14 14
     for message in messages[1:]:
15
-        if msg_ev.compute(message) > accepted_max: continue
15
+        if msg_val.get(message) > accepted_max: continue
16 16
 
17
-        if msg_ev.compute(message) == current.end() + 1:
17
+        if msg_val.get(message) == current.end() + 1:
18 18
             current.end_message = message
19 19
         else:
20 20
             sequences.append(current)

+ 10
- 2
million/analyze/message_evaluation.py ファイルの表示

@@ -1,9 +1,12 @@
1 1
 from math import floor
2
+from typing import Dict
2 3
 from million.model.message import Message
3 4
 
5
+memoization: Dict[Message, int] = {}
6
+
4 7
 # TODO WIP
5 8
 # - DNS to resolve audio, gif, pictures with counts
6
-def compute(msg: Message) -> int:
9
+def __compute__(msg: Message) -> int:
7 10
     """ Returns the estimated value counted in this message
8 11
     """
9 12
     value = None
@@ -15,4 +18,9 @@ def compute(msg: Message) -> int:
15 18
     except Exception as e:
16 19
         raise ValueError(
17 20
             f"Message {cleaned_content} does not contain a number ({e})")
18
-    return value
21
+    
22
+    memoization[msg] = value
23
+    return value
24
+
25
+def get(msg: Message) -> int:
26
+    return memoization.get(msg, __compute__(msg))

+ 3
- 0
million/model/message.py ファイルの表示

@@ -49,3 +49,6 @@ class Message(BaseModel):
49 49
         dt = datetime.fromtimestamp(self.timestamp_ms / 1000)
50 50
         dt_str = dt.strftime("%d/%m/%Y, %H:%M:%S")
51 51
         return f"{self.sender_name}({dt_str}) : {self.content}"
52
+
53
+    def __hash__(self) -> int:
54
+        return hash(self.sender_name + str(self.timestamp_ms))

+ 3
- 3
million/model/sequence.py ファイルの表示

@@ -4,7 +4,7 @@ from pydantic import BaseModel
4 4
 import pydantic
5 5
 
6 6
 from million.model.message import Message
7
-import million.analyze.message_evaluation as msg_ev
7
+import million.analyze.message_evaluation as msg_val
8 8
 
9 9
 
10 10
 class Sequence(BaseModel):
@@ -16,10 +16,10 @@ class Sequence(BaseModel):
16 16
         return v or values['start_message'] 
17 17
 
18 18
     def start(self) -> int:
19
-        return msg_ev.compute(self.start_message)
19
+        return msg_val.get(self.start_message)
20 20
     
21 21
     def end(self) -> int:
22
-        return msg_ev.compute(self.end_message)
22
+        return msg_val.get(self.end_message)
23 23
     
24 24
     def length(self) -> int:
25 25
         return self.end() - self.start() + 1

読み込み中…
キャンセル
保存