• R/O
  • HTTP
  • SSH
  • HTTPS

Commit

Tags
Keine Tags

Frequently used words (click to add to your profile)

java, c++, android, linux, c#, windows, objective-c, cocoa, 誰得, qt, python, php, ruby, game, gui, bathyscaphe, c, 計画中(planning stage), 翻訳, omegat, framework, twitter, dom, test, vb.net, directx, ゲームエンジン, btron, arduino, previewer

Emergent generative agents


Commit MetaInfo

Revision: 5b55807030216ac952c2c35395ff243cba572994 (tree)
Zeit: 2023-05-20 04:04:24
Autor: Corbin <cds@corb...>
Committer: Corbin

Log Message

Prototype choices, create subtagging for messages.

These go hand-in-hand; we want to be able to choose not just to reply on
IRC, but to reply on a specific channel, so we need to use the channel as
subtag data.

The choice mechanism is clunky; I would very much like to just use
ReLLM, and maybe that'll be important in the future if/when we migrate
back to HF with their brand-new RWKV integration. They don't support
quantized models, though...

Ändern Zusammenfassung

Diff

--- a/src/agent.py
+++ b/src/agent.py
@@ -6,18 +6,19 @@ import json
66 import os.path
77 import re
88 import random
9+from string import ascii_uppercase
910 import sys
1011 from threading import Lock
1112
1213 from twisted.internet import reactor
1314 from twisted.internet.defer import DeferredLock, succeed
1415 from twisted.internet.protocol import ClientFactory
15-from twisted.internet.task import LoopingCall
16+from twisted.internet.task import LoopingCall, deferLater
1617 from twisted.internet.threads import deferToThread
1718 from twisted.words.protocols.irc import IRCClient
1819
1920 from common import irc_line, Timer, SentenceIndex, breakAt
20-from gens.mawrkov import MawrkovGen
21+from gens.mawrkov import MawrkovGen, force
2122 from gens.trans import SentenceEmbed
2223
2324 build_traits = " + ".join
@@ -43,55 +44,85 @@ prologues = {
4344 "clock": "Then I checked the time:",
4445 "lbi": "Then I tried to interpret what just happened:",
4546 "thoughts": "Then I thought to myself:",
46- "irc": "Then I chatted on IRC:",
47+ "choice": "Then I chose what to do next:",
48+ "irc": "Then I chatted on IRC, channel {subtag}:",
4749 }
4850
4951 class Mind:
5052 currentTag = None
5153 logits = state = None
5254
53- def __init__(self): self.lock = DeferredLock()
55+ def __init__(self, name):
56+ # Exclusively for debugging!
57+ self.name = name
58+ self.lock = DeferredLock()
5459
55- def switchTag(self, tag):
56- if tag == self.currentTag: return succeed(None)
60+ def switchTag(self, tag, subtag):
61+ if (tag, subtag) == self.currentTag: return succeed(None)
5762 else:
58- self.currentTag = tag
63+ self.currentTag = tag, subtag
5964 # Double newlines are added here.
60- return deferToThread(self.write, "\n" + prologues[tag])
65+ return deferToThread(self.write, "\n" +
66+ prologues[tag].format(subtag=subtag))
6167
62- def overhear(self, tag, s):
68+ def overhear(self, tag, subtag, s):
6369 def cb():
64- d = self.switchTag(tag)
70+ d = self.switchTag(tag, subtag)
6571 d.addCallback(lambda _: deferToThread(self.write, s))
6672 return d
6773 return self.lock.run(cb)
6874
69- def write(self, s):
75+ def writeRaw(self, s):
7076 with genLock:
71- # Newlines are added here.
72- self.logits, self.state = gen.feedForward(gen.tokenize(s + "\n"),
77+ self.logits, self.state = gen.feedForward(gen.tokenize(s),
7378 self.logits, self.state)
74- print(s)
79+ print(self.name, "~~", s.strip())
80+
81+ # Newlines are added here.
82+ def write(self, s): self.writeRaw(s + "\n")
7583
7684 def complete(self, s):
7785 with genLock:
7886 completion, self.logits, self.state = gen.complete(s, self.logits, self.state)
79- print("«", completion, "»")
87+ print(self.name, "«", completion.strip(), "»")
8088 return completion
8189
82- def infer(self, tag, prefix):
90+ def infer(self, tag, subtag, prefix):
91+ def cb():
92+ d = self.switchTag(tag, subtag)
93+ d.addCallback(lambda _: print(self.name, "~?", prefix)
94+ or deferToThread(self.complete, prefix))
95+ return d
96+ return self.lock.run(cb)
97+
98+ def forceChoice(self, tag, subtag, prefix, options):
99+ tokens = [gen.tokenize(opt)[0] for opt in options]
83100 def cb():
84- print(prefix)
85- d = self.switchTag(tag)
86- d.addCallback(lambda _: deferToThread(self.complete, prefix))
101+ d = self.switchTag(tag, subtag)
102+ # Breaking the newline invariant...
103+ d.addCallback(lambda _: print(self.name, "~!", prefix)
104+ or deferToThread(self.writeRaw, prefix))
105+ # ...so that this inference happens before the newline...
106+ d.addCallback(lambda _: force(tokens, self.logits))
107+ d.addCallback(lambda t: gen.tokenizer.decode([t]))
108+
109+ @d.addCallback
110+ def cb2(s):
111+ # ...and now the invariant must be restored...
112+ d2 = deferToThread(self.write, s)
113+ # ...but we want to tell the caller what was chosen!
114+ d2.addCallback(lambda _: s)
115+ return d2
87116 return d
88117 return self.lock.run(cb)
89118
90119
91120 class Agent:
92121 listeners = ()
122+ subtag = None
93123 def broadcast(self, s):
94- for listener in self.listeners: listener.overhear(self.tag, s)
124+ for listener in self.listeners:
125+ listener.overhear(self.tag, self.subtag, s)
95126
96127 class Clock(Agent):
97128 tag = "clock"
@@ -102,45 +133,77 @@ class LeftBrainInterpreter(Agent):
102133 tag = "lbi"
103134 events = 0
104135 def __init__(self, mind): self.mind = mind
105- def overhear(self, tag, s):
136+ def overhear(self, tag, subtag, s):
137+ # XXX should skip thoughts, maybe?
106138 self.events += 1
107- if self.events >= 10:
139+ if self.events >= 20:
108140 self.events = 0
109- return self.mind.infer(self.tag, "")
141+ return self.mind.infer(self.tag, self.subtag, "")
142+
143+choices = {
144+ "irc": "Chat on IRC, channel {subtag}",
145+ "thoughts": "Think to myself for a moment",
146+}
147+def choicify(options):
148+ return " ".join(f"({c}) {o}" for c, o in zip(ascii_uppercase, options))
149+
150+class ChoiceMaker(Agent):
151+ tag = "choice"
152+ def __init__(self, mind):
153+ self.mind = mind
154+ self.possibilities = set()
155+ self.actions = {}
156+ def overhear(self, tag, subtag, s): self.possibilities.add((tag, subtag))
157+ def dispatch(self, tag, subtag):
158+ if tag in self.actions: return self.actions[tag](subtag)
159+ def idle(self):
160+ possibilities = [(tag, subtag)
161+ for tag, subtag in self.possibilities
162+ if tag in choices]
163+ if len(possibilities) == 0: return
164+ elif len(possibilities) == 1: return self.dispatch(*possibilities[0])
165+
166+ options = [choices[tag].format(subtag=subtag)
167+ for tag, subtag in possibilities]
168+ prompt = choicify(options.values())
169+ self.broadcast(prompt)
170+ d = self.mind.forceChoice(self.tag, self.subtag, "My choice: ",
171+ ascii_uppercase[:len(options)])
172+ d.addCallback(lambda s: self.dispatch(*possibilities[chr(ord(s[0]) -
173+ ord('A'))]))
174+ return d
110175
111176 class ChainOfThoughts(Agent):
112177 tag = "thoughts"
113- def __init__(self, mind, index, seed):
178+ def __init__(self, mind, index, seed, idle):
114179 self.mind = mind
115180 self.index = index
116181 self.recentThoughts = deque([seed], maxlen=5)
182+ self.idle = idle
117183
118- def go(self):
119- cb = self.reflect if random.choice([0, 1]) else self.cogitate
120- return cb()
184+ def go(self): return random.choice([self.reflect, self.cogitate])()
121185
122- def overhear(self, tag, s):
123- self.addRelatedThoughts(tag + ": " + s)
186+ def overhear(self, tag, subtag, s):
187+ prefix = f"{tag} ({subtag}): " if subtag else tag + ": "
188+ self.addRelatedThoughts(prefix + s)
124189
125190 def addRelatedThoughts(self, s):
126191 thoughts = self.index.search(s, 2)
192+ new = 0
127193 for thought in thoughts:
128194 if thought not in self.recentThoughts:
195+ new += 1
129196 self.recentThoughts.append(thought)
130197 self.broadcast(thought)
198+ return new
131199
132- def cogitate(self): self.addRelatedThoughts(self.recentThoughts[-1])
200+ def cogitate(self):
201+ new = self.addRelatedThoughts(self.recentThoughts[-1])
202+ if not new: return deferLater(reactor, 0.0, self.idle)
133203 def reflect(self):
134- d = self.mind.infer(self.tag, "")
135-
136- @d.addCallback
137- def cb(s):
138- if not s.strip():
139- self.broadcast(random.choice([
140- "Head empty; no thoughts.",
141- "So bored.",
142- "Zoned out.",
143- ]))
204+ d = self.mind.infer(self.tag, self.subtag, "")
205+ d.addCallback(lambda s: s.strip() or self.idle())
206+ return d
144207
145208
146209 IRC_LINE_HEAD = re.compile(r"\d{1,2}:\d{1,2}:\d{1,2} <")
@@ -155,40 +218,45 @@ class IRCAgent(Agent, IRCClient):
155218 self.nickname = title_to_nick(title)
156219 self.startingChannels = startingChannels
157220
221+ def broadcastTagged(self, subtag, s):
222+ for listener in self.listeners: listener.overhear(self.tag, subtag, s)
223+
158224 def prefix(self, channel):
159225 return f"{datetime.now():%H:%M:%S} {channel} <{self.nickname}>"
160226
161227 def userJoined(self, user, channel):
162- self.broadcast(f"{user} joins {channel}")
228+ self.broadcastTagged(channel, f"{user} joins {channel}")
163229
164230 def userLeft(self, user, channel):
165- self.broadcast(f"{user} leaves {channel}")
231+ self.broadcastTagged(channel, f"{user} leaves {channel}")
166232
167233 def userQuit(self, user, channel):
168- self.broadcast(f"{user} quits {channel}")
234+ self.broadcastTagged(channel, f"{user} quits {channel}")
169235
170236 def topicUpdated(self, user, channel, topic):
171- self.broadcast(f"Topic for {channel} is now: {topic}")
237+ self.broadcastTagged(channel, f"Topic for {channel} is now: {topic}")
172238
173239 def signedOn(self):
174240 for channel in self.startingChannels: self.join(channel)
175241
176242 def privmsg(self, user, channel, line):
177243 user = user.split("!", 1)[0]
178- self.broadcast(irc_line(datetime.now(), channel, user, line))
179- if self.nickname in line:
180- d = self.mind.infer("irc", self.prefix(channel))
244+ self.broadcastTagged(channel, irc_line(datetime.now(), channel, user, line))
245+ # XXX should eventually be not needed!
246+ # if self.nickname in line: self.speakInChannel(channel)
181247
182- @d.addCallback
183- def cb(s):
184- line = breakIRCLine(s).strip()
185- self.msg(channel, line)
248+ def speakInChannel(self, channel):
249+ d = self.mind.infer("irc", channel, self.prefix(channel))
250+
251+ @d.addCallback
252+ def cb(s): self.msg(channel, breakIRCLine(s).strip())
186253
187254 class IRCFactory(ClientFactory):
188255 protocol = IRCAgent
189- def __init__(self, mind, listeners, title, startingChannels):
256+ def __init__(self, mind, cm, listeners, title, startingChannels):
190257 super(IRCFactory, self).__init__()
191258 self.mind = mind
259+ self.cm = cm
192260 self.listeners = listeners
193261 self.title = title
194262 self.startingChannels = startingChannels
@@ -196,12 +264,13 @@ class IRCFactory(ClientFactory):
196264 protocol = self.protocol(self.mind, self.title, self.startingChannels)
197265 protocol.factory = self
198266 protocol.listeners = self.listeners
267+ self.cm.actions["irc"] = protocol.speakInChannel
199268 return protocol
200269
201270 def go():
202271 print("~ Starting tasks…")
203272 clock = Clock()
204- LoopingCall(clock.go).start(300.0, now=False)
273+ LoopingCall(clock.go).start(60 * 10, now=False)
205274
206275 for logpath in sys.argv[2:]:
207276 character = load_character(logpath)
@@ -210,19 +279,21 @@ def go():
210279 thoughtPath = os.path.join(logpath, "thoughts.txt")
211280 thoughtIndex = SentenceIndex.fromPath(thoughtPath, embedder)
212281
213- mind = Mind()
282+ mind = Mind(title_to_nick(title))
214283 with Timer("initial warmup"):
215284 mind.logits, mind.state = gen.feedForward(gen.tokenize(firstStatement), None, None)
216285
217286 lbi = LeftBrainInterpreter(mind)
218287 clock.listeners += lbi, mind
219288
220- thoughts = ChainOfThoughts(mind, thoughtIndex, firstStatement)
289+ cm = ChoiceMaker(mind)
290+
291+ thoughts = ChainOfThoughts(mind, thoughtIndex, firstStatement, cm.idle)
221292 thoughts.listeners = lbi, mind
222- LoopingCall(thoughts.go).start(120.0, now=False)
293+ LoopingCall(thoughts.go).start(60, now=False)
223294
224295 print("~ Thought index:", thoughtIndex.size(), "thoughts")
225- factory = IRCFactory(mind, (thoughts, lbi, mind),
296+ factory = IRCFactory(mind, cm, (thoughts, lbi, cm, mind),
226297 title,
227298 character["startingChannels"])
228299 print("~ Connecting factory for:", title)
--- a/src/gens/mawrkov.py
+++ b/src/gens/mawrkov.py
@@ -27,8 +27,7 @@ TEMPERATURE = 0.8
2727 TOP_P = 0.8
2828
2929 class MawrkovGen:
30- # XXX might be wrong
31- model_name = "The Pile (14B params, 4-bit quantized)"
30+ model_name = "The Pile"
3231 model_arch = "RWKV"
3332 def __init__(self, model_path, max_new_tokens):
3433 self.max_new_tokens = max_new_tokens
@@ -57,3 +56,8 @@ class MawrkovGen:
5756 logits, state = self.feedForward([token], logits, state)
5857 if "\n" in self.tokenizer.decode([token]): break
5958 return self.tokenizer.decode(tokens).split("\n", 1)[0], logits, state
59+
60+def force(options, logits):
61+ # +10 is hopefully not too much.
62+ biases = {opt: 10 for opt in options}
63+ return sampling.sample_logits(logits, TEMPERATURE, TOP_P, biases)